feat: implement memtable package
This commit is contained in:
parent
b7fb76fd54
commit
b1f9ed6740
26
TODO.md
26
TODO.md
@ -35,21 +35,21 @@ This document outlines the implementation tasks for the Go Storage Engine, organ
|
|||||||
|
|
||||||
## Phase B: In-Memory Layer
|
## Phase B: In-Memory Layer
|
||||||
|
|
||||||
- [ ] Implement MemTable
|
- [✓] Implement MemTable
|
||||||
- [ ] Create skip list data structure aligned to 64-byte cache lines
|
- [✓] Create skip list data structure aligned to 64-byte cache lines
|
||||||
- [ ] Add key/value insertion and lookup operations
|
- [✓] Add key/value insertion and lookup operations
|
||||||
- [ ] Implement sorted key iteration
|
- [✓] Implement sorted key iteration
|
||||||
- [ ] Add size tracking for flush threshold detection
|
- [✓] Add size tracking for flush threshold detection
|
||||||
|
|
||||||
- [ ] Connect WAL replay to MemTable
|
- [✓] Connect WAL replay to MemTable
|
||||||
- [ ] Create recovery logic to rebuild MemTable from WAL
|
- [✓] Create recovery logic to rebuild MemTable from WAL
|
||||||
- [ ] Implement consistent snapshot reads during recovery
|
- [✓] Implement consistent snapshot reads during recovery
|
||||||
- [ ] Handle errors during replay with appropriate fallbacks
|
- [✓] Handle errors during replay with appropriate fallbacks
|
||||||
|
|
||||||
- [ ] Test concurrent read/write scenarios
|
- [✓] Test concurrent read/write scenarios
|
||||||
- [ ] Verify reader isolation during writes
|
- [✓] Verify reader isolation during writes
|
||||||
- [ ] Test snapshot consistency guarantees
|
- [✓] Test snapshot consistency guarantees
|
||||||
- [ ] Benchmark read/write performance under load
|
- [✓] Benchmark read/write performance under load
|
||||||
|
|
||||||
## Phase C: Persistent Storage
|
## Phase C: Persistent Storage
|
||||||
|
|
||||||
|
132
pkg/memtable/bench_test.go
Normal file
132
pkg/memtable/bench_test.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkSkipListInsert(b *testing.B) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Create random keys ahead of time
|
||||||
|
keys := make([][]byte, b.N)
|
||||||
|
values := make([][]byte, b.N)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
keys[i] = []byte(fmt.Sprintf("key-%d", i))
|
||||||
|
values[i] = []byte(fmt.Sprintf("value-%d", i))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
e := newEntry(keys[i], values[i], TypeValue, uint64(i))
|
||||||
|
sl.Insert(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSkipListFind(b *testing.B) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Insert entries first
|
||||||
|
const numEntries = 100000
|
||||||
|
keys := make([][]byte, numEntries)
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
key := []byte(fmt.Sprintf("key-%d", i))
|
||||||
|
value := []byte(fmt.Sprintf("value-%d", i))
|
||||||
|
keys[i] = key
|
||||||
|
sl.Insert(newEntry(key, value, TypeValue, uint64(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create random keys for lookup
|
||||||
|
lookupKeys := make([][]byte, b.N)
|
||||||
|
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
idx := r.Intn(numEntries)
|
||||||
|
lookupKeys[i] = keys[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
sl.Find(lookupKeys[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMemTablePut(b *testing.B) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
key := []byte("key-" + strconv.Itoa(i))
|
||||||
|
value := []byte("value-" + strconv.Itoa(i))
|
||||||
|
mt.Put(key, value, uint64(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMemTableGet(b *testing.B) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Insert entries first
|
||||||
|
const numEntries = 100000
|
||||||
|
keys := make([][]byte, numEntries)
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
key := []byte(fmt.Sprintf("key-%d", i))
|
||||||
|
value := []byte(fmt.Sprintf("value-%d", i))
|
||||||
|
keys[i] = key
|
||||||
|
mt.Put(key, value, uint64(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create random keys for lookup
|
||||||
|
lookupKeys := make([][]byte, b.N)
|
||||||
|
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
idx := r.Intn(numEntries)
|
||||||
|
lookupKeys[i] = keys[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mt.Get(lookupKeys[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMemPoolGet(b *testing.B) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
cfg.MemTableSize = 1024 * 1024 * 32 // 32MB for benchmark
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Create multiple memtables with entries
|
||||||
|
const entriesPerTable = 50000
|
||||||
|
const numTables = 3
|
||||||
|
keys := make([][]byte, entriesPerTable*numTables)
|
||||||
|
|
||||||
|
// Fill tables
|
||||||
|
for t := 0; t < numTables; t++ {
|
||||||
|
// Fill a table
|
||||||
|
for i := 0; i < entriesPerTable; i++ {
|
||||||
|
idx := t*entriesPerTable + i
|
||||||
|
key := []byte(fmt.Sprintf("key-%d", idx))
|
||||||
|
value := []byte(fmt.Sprintf("value-%d", idx))
|
||||||
|
keys[idx] = key
|
||||||
|
pool.Put(key, value, uint64(idx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to a new memtable (except for last one)
|
||||||
|
if t < numTables-1 {
|
||||||
|
pool.SwitchToNewMemTable()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create random keys for lookup
|
||||||
|
lookupKeys := make([][]byte, b.N)
|
||||||
|
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
idx := r.Intn(entriesPerTable * numTables)
|
||||||
|
lookupKeys[i] = keys[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pool.Get(lookupKeys[i])
|
||||||
|
}
|
||||||
|
}
|
166
pkg/memtable/mempool.go
Normal file
166
pkg/memtable/mempool.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemTablePool manages a pool of MemTables
|
||||||
|
// It maintains one active MemTable and a set of immutable MemTables
|
||||||
|
type MemTablePool struct {
|
||||||
|
cfg *config.Config
|
||||||
|
active *MemTable
|
||||||
|
immutables []*MemTable
|
||||||
|
maxAge time.Duration
|
||||||
|
maxSize int64
|
||||||
|
totalSize int64
|
||||||
|
flushPending atomic.Bool
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemTablePool creates a new MemTable pool
|
||||||
|
func NewMemTablePool(cfg *config.Config) *MemTablePool {
|
||||||
|
return &MemTablePool{
|
||||||
|
cfg: cfg,
|
||||||
|
active: NewMemTable(),
|
||||||
|
immutables: make([]*MemTable, 0, cfg.MaxMemTables-1),
|
||||||
|
maxAge: time.Duration(cfg.MaxMemTableAge) * time.Second,
|
||||||
|
maxSize: cfg.MemTableSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair to the active MemTable
|
||||||
|
func (p *MemTablePool) Put(key, value []byte, seqNum uint64) {
|
||||||
|
p.mu.RLock()
|
||||||
|
p.active.Put(key, value, seqNum)
|
||||||
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check if we need to flush after this write
|
||||||
|
p.checkFlushConditions()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete marks a key as deleted in the active MemTable
|
||||||
|
func (p *MemTablePool) Delete(key []byte, seqNum uint64) {
|
||||||
|
p.mu.RLock()
|
||||||
|
p.active.Delete(key, seqNum)
|
||||||
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check if we need to flush after this write
|
||||||
|
p.checkFlushConditions()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves the value for a key from all MemTables
|
||||||
|
// Checks the active MemTable first, then the immutables in reverse order
|
||||||
|
func (p *MemTablePool) Get(key []byte) ([]byte, bool) {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check active table first
|
||||||
|
if value, found := p.active.Get(key); found {
|
||||||
|
return value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check immutable tables in reverse order (newest first)
|
||||||
|
for i := len(p.immutables) - 1; i >= 0; i-- {
|
||||||
|
if value, found := p.immutables[i].Get(key); found {
|
||||||
|
return value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImmutableCount returns the number of immutable MemTables
|
||||||
|
func (p *MemTablePool) ImmutableCount() int {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
return len(p.immutables)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkFlushConditions checks if we need to flush the active MemTable
|
||||||
|
func (p *MemTablePool) checkFlushConditions() {
|
||||||
|
needsFlush := false
|
||||||
|
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
// Skip if a flush is already pending
|
||||||
|
if p.flushPending.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check size condition
|
||||||
|
if p.active.ApproximateSize() >= p.maxSize {
|
||||||
|
needsFlush = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check age condition
|
||||||
|
if p.maxAge > 0 && p.active.Age() > p.maxAge.Seconds() {
|
||||||
|
needsFlush = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as needing flush if conditions met
|
||||||
|
if needsFlush {
|
||||||
|
p.flushPending.Store(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchToNewMemTable makes the active MemTable immutable and creates a new active one
|
||||||
|
// Returns the immutable MemTable that needs to be flushed
|
||||||
|
func (p *MemTablePool) SwitchToNewMemTable() *MemTable {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
// Reset the flush pending flag
|
||||||
|
p.flushPending.Store(false)
|
||||||
|
|
||||||
|
// Make the current active table immutable
|
||||||
|
oldActive := p.active
|
||||||
|
oldActive.SetImmutable()
|
||||||
|
|
||||||
|
// Create a new active table
|
||||||
|
p.active = NewMemTable()
|
||||||
|
|
||||||
|
// Add the old table to the immutables list
|
||||||
|
p.immutables = append(p.immutables, oldActive)
|
||||||
|
|
||||||
|
// Return the table that needs to be flushed
|
||||||
|
return oldActive
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetImmutablesForFlush returns a list of immutable MemTables ready for flushing
|
||||||
|
// and removes them from the pool
|
||||||
|
func (p *MemTablePool) GetImmutablesForFlush() []*MemTable {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
result := p.immutables
|
||||||
|
p.immutables = make([]*MemTable, 0, p.cfg.MaxMemTables-1)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFlushNeeded returns true if a flush is needed
|
||||||
|
func (p *MemTablePool) IsFlushNeeded() bool {
|
||||||
|
return p.flushPending.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextSequenceNumber returns the next sequence number to use
|
||||||
|
func (p *MemTablePool) GetNextSequenceNumber() uint64 {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
return p.active.GetNextSequenceNumber()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemTables returns all MemTables (active and immutable)
|
||||||
|
func (p *MemTablePool) GetMemTables() []*MemTable {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]*MemTable, 0, len(p.immutables)+1)
|
||||||
|
result = append(result, p.active)
|
||||||
|
result = append(result, p.immutables...)
|
||||||
|
return result
|
||||||
|
}
|
225
pkg/memtable/mempool_test.go
Normal file
225
pkg/memtable/mempool_test.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestConfig() *config.Config {
|
||||||
|
cfg := config.NewDefaultConfig("/tmp/db")
|
||||||
|
cfg.MemTableSize = 1024 // Small size for testing
|
||||||
|
cfg.MaxMemTableAge = 1 // 1 second
|
||||||
|
cfg.MaxMemTables = 4 // Allow up to 4 memtables
|
||||||
|
cfg.MemTablePoolCap = 4 // Pool capacity
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolBasicOperations(t *testing.T) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Test Put and Get
|
||||||
|
pool.Put([]byte("key1"), []byte("value1"), 1)
|
||||||
|
|
||||||
|
value, found := pool.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key1, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value1" {
|
||||||
|
t.Errorf("expected value1, got %s", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Delete
|
||||||
|
pool.Delete([]byte("key1"), 2)
|
||||||
|
|
||||||
|
value, found = pool.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected tombstone to be found for key1")
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
t.Errorf("expected nil value for deleted key, got %v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolSwitchMemTable(t *testing.T) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Add data to the active memtable
|
||||||
|
pool.Put([]byte("key1"), []byte("value1"), 1)
|
||||||
|
|
||||||
|
// Switch to a new memtable
|
||||||
|
old := pool.SwitchToNewMemTable()
|
||||||
|
if !old.IsImmutable() {
|
||||||
|
t.Errorf("expected switched memtable to be immutable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the data is in the old table
|
||||||
|
value, found := old.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key1 in old table, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value1" {
|
||||||
|
t.Errorf("expected value1 in old table, got %s", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the immutable count is correct
|
||||||
|
if count := pool.ImmutableCount(); count != 1 {
|
||||||
|
t.Errorf("expected immutable count to be 1, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add data to the new active memtable
|
||||||
|
pool.Put([]byte("key2"), []byte("value2"), 2)
|
||||||
|
|
||||||
|
// Verify we can still retrieve data from both tables
|
||||||
|
value, found = pool.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key1 through pool, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value1" {
|
||||||
|
t.Errorf("expected value1 through pool, got %s", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
value, found = pool.Get([]byte("key2"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key2 through pool, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value2" {
|
||||||
|
t.Errorf("expected value2 through pool, got %s", string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolFlushConditions(t *testing.T) {
|
||||||
|
// Create a config with small thresholds for testing
|
||||||
|
cfg := createTestConfig()
|
||||||
|
cfg.MemTableSize = 100 // Very small size to trigger flush
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Initially no flush should be needed
|
||||||
|
if pool.IsFlushNeeded() {
|
||||||
|
t.Errorf("expected no flush needed initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add enough data to trigger a size-based flush
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
key := []byte{byte(i)}
|
||||||
|
value := make([]byte, 20) // 20 bytes per value
|
||||||
|
pool.Put(key, value, uint64(i+1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should trigger a flush
|
||||||
|
if !pool.IsFlushNeeded() {
|
||||||
|
t.Errorf("expected flush needed after reaching size threshold")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to a new memtable
|
||||||
|
old := pool.SwitchToNewMemTable()
|
||||||
|
if !old.IsImmutable() {
|
||||||
|
t.Errorf("expected old memtable to be immutable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The flush pending flag should be reset
|
||||||
|
if pool.IsFlushNeeded() {
|
||||||
|
t.Errorf("expected flush pending to be reset after switch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now test age-based flushing
|
||||||
|
// Wait for the age threshold to trigger
|
||||||
|
time.Sleep(1200 * time.Millisecond) // Just over 1 second
|
||||||
|
|
||||||
|
// Add a small amount of data to check conditions
|
||||||
|
pool.Put([]byte("trigger"), []byte("check"), 100)
|
||||||
|
|
||||||
|
// Should trigger an age-based flush
|
||||||
|
if !pool.IsFlushNeeded() {
|
||||||
|
t.Errorf("expected flush needed after reaching age threshold")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolGetImmutablesForFlush(t *testing.T) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Switch memtables a few times to accumulate immutables
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
pool.Put([]byte{byte(i)}, []byte{byte(i)}, uint64(i+1))
|
||||||
|
pool.SwitchToNewMemTable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have 3 immutable memtables
|
||||||
|
if count := pool.ImmutableCount(); count != 3 {
|
||||||
|
t.Errorf("expected 3 immutable memtables, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get immutables for flush
|
||||||
|
immutables := pool.GetImmutablesForFlush()
|
||||||
|
|
||||||
|
// Should get all 3 immutables
|
||||||
|
if len(immutables) != 3 {
|
||||||
|
t.Errorf("expected to get 3 immutables for flush, got %d", len(immutables))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The pool should now have 0 immutables
|
||||||
|
if count := pool.ImmutableCount(); count != 0 {
|
||||||
|
t.Errorf("expected 0 immutable memtables after flush, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolGetMemTables(t *testing.T) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Initially should have just the active memtable
|
||||||
|
tables := pool.GetMemTables()
|
||||||
|
if len(tables) != 1 {
|
||||||
|
t.Errorf("expected 1 memtable initially, got %d", len(tables))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add an immutable table
|
||||||
|
pool.Put([]byte("key"), []byte("value"), 1)
|
||||||
|
pool.SwitchToNewMemTable()
|
||||||
|
|
||||||
|
// Now should have 2 memtables (active + 1 immutable)
|
||||||
|
tables = pool.GetMemTables()
|
||||||
|
if len(tables) != 2 {
|
||||||
|
t.Errorf("expected 2 memtables after switch, got %d", len(tables))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The active table should be first
|
||||||
|
if tables[0].IsImmutable() {
|
||||||
|
t.Errorf("expected first table to be active (not immutable)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The second table should be immutable
|
||||||
|
if !tables[1].IsImmutable() {
|
||||||
|
t.Errorf("expected second table to be immutable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemPoolGetNextSequenceNumber(t *testing.T) {
|
||||||
|
cfg := createTestConfig()
|
||||||
|
pool := NewMemTablePool(cfg)
|
||||||
|
|
||||||
|
// Initial sequence number should be 0
|
||||||
|
if seq := pool.GetNextSequenceNumber(); seq != 0 {
|
||||||
|
t.Errorf("expected initial sequence number to be 0, got %d", seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add entries with sequence numbers
|
||||||
|
pool.Put([]byte("key"), []byte("value"), 5)
|
||||||
|
|
||||||
|
// Next sequence number should be 6
|
||||||
|
if seq := pool.GetNextSequenceNumber(); seq != 6 {
|
||||||
|
t.Errorf("expected sequence number to be 6, got %d", seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to a new memtable
|
||||||
|
pool.SwitchToNewMemTable()
|
||||||
|
|
||||||
|
// Sequence number should reset for the new table
|
||||||
|
if seq := pool.GetNextSequenceNumber(); seq != 0 {
|
||||||
|
t.Errorf("expected sequence number to reset to 0, got %d", seq)
|
||||||
|
}
|
||||||
|
}
|
155
pkg/memtable/memtable.go
Normal file
155
pkg/memtable/memtable.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemTable is an in-memory table that stores key-value pairs
|
||||||
|
// It is implemented using a skip list for efficient inserts and lookups
|
||||||
|
type MemTable struct {
|
||||||
|
skipList *SkipList
|
||||||
|
nextSeqNum uint64
|
||||||
|
creationTime time.Time
|
||||||
|
immutable atomic.Bool
|
||||||
|
size int64
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemTable creates a new memory table
|
||||||
|
func NewMemTable() *MemTable {
|
||||||
|
return &MemTable{
|
||||||
|
skipList: NewSkipList(),
|
||||||
|
creationTime: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair to the MemTable
|
||||||
|
func (m *MemTable) Put(key, value []byte, seqNum uint64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.immutable.Load() {
|
||||||
|
// Don't modify immutable memtables
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e := newEntry(key, value, TypeValue, seqNum)
|
||||||
|
m.skipList.Insert(e)
|
||||||
|
|
||||||
|
// Update maximum sequence number
|
||||||
|
if seqNum > m.nextSeqNum {
|
||||||
|
m.nextSeqNum = seqNum + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete marks a key as deleted in the MemTable
|
||||||
|
func (m *MemTable) Delete(key []byte, seqNum uint64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.immutable.Load() {
|
||||||
|
// Don't modify immutable memtables
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e := newEntry(key, nil, TypeDeletion, seqNum)
|
||||||
|
m.skipList.Insert(e)
|
||||||
|
|
||||||
|
// Update maximum sequence number
|
||||||
|
if seqNum > m.nextSeqNum {
|
||||||
|
m.nextSeqNum = seqNum + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves the value associated with the given key
|
||||||
|
// Returns (nil, true) if the key exists but has been deleted
|
||||||
|
// Returns (nil, false) if the key does not exist
|
||||||
|
// Returns (value, true) if the key exists and has a value
|
||||||
|
func (m *MemTable) Get(key []byte) ([]byte, bool) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
e := m.skipList.Find(key)
|
||||||
|
if e == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a deletion marker
|
||||||
|
if e.valueType == TypeDeletion {
|
||||||
|
return nil, true // Key exists but was deleted
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks if the key exists in the MemTable
|
||||||
|
func (m *MemTable) Contains(key []byte) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
return m.skipList.Find(key) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApproximateSize returns the approximate size of the MemTable in bytes
|
||||||
|
func (m *MemTable) ApproximateSize() int64 {
|
||||||
|
return m.skipList.ApproximateSize()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetImmutable marks the MemTable as immutable
|
||||||
|
// After this is called, no more modifications are allowed
|
||||||
|
func (m *MemTable) SetImmutable() {
|
||||||
|
m.immutable.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsImmutable returns whether the MemTable is immutable
|
||||||
|
func (m *MemTable) IsImmutable() bool {
|
||||||
|
return m.immutable.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Age returns the age of the MemTable in seconds
|
||||||
|
func (m *MemTable) Age() float64 {
|
||||||
|
return time.Since(m.creationTime).Seconds()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIterator returns an iterator for the MemTable
|
||||||
|
func (m *MemTable) NewIterator() *Iterator {
|
||||||
|
return m.skipList.NewIterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextSequenceNumber returns the next sequence number to use
|
||||||
|
func (m *MemTable) GetNextSequenceNumber() uint64 {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.nextSeqNum
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessWALEntry processes a WAL entry and applies it to the MemTable
|
||||||
|
func (m *MemTable) ProcessWALEntry(entry *wal.Entry) error {
|
||||||
|
switch entry.Type {
|
||||||
|
case wal.OpTypePut:
|
||||||
|
m.Put(entry.Key, entry.Value, entry.SequenceNumber)
|
||||||
|
case wal.OpTypeDelete:
|
||||||
|
m.Delete(entry.Key, entry.SequenceNumber)
|
||||||
|
case wal.OpTypeBatch:
|
||||||
|
// Process batch operations
|
||||||
|
batch, err := wal.DecodeBatch(entry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, op := range batch.Operations {
|
||||||
|
seqNum := batch.Seq + uint64(i)
|
||||||
|
switch op.Type {
|
||||||
|
case wal.OpTypePut:
|
||||||
|
m.Put(op.Key, op.Value, seqNum)
|
||||||
|
case wal.OpTypeDelete:
|
||||||
|
m.Delete(op.Key, seqNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
202
pkg/memtable/memtable_test.go
Normal file
202
pkg/memtable/memtable_test.go
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMemTableBasicOperations(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Test Put and Get
|
||||||
|
mt.Put([]byte("key1"), []byte("value1"), 1)
|
||||||
|
|
||||||
|
value, found := mt.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key1, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value1" {
|
||||||
|
t.Errorf("expected value1, got %s", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test not found
|
||||||
|
_, found = mt.Get([]byte("nonexistent"))
|
||||||
|
if found {
|
||||||
|
t.Errorf("expected key 'nonexistent' to not be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Delete
|
||||||
|
mt.Delete([]byte("key1"), 2)
|
||||||
|
|
||||||
|
value, found = mt.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected tombstone to be found for key1")
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
t.Errorf("expected nil value for deleted key, got %v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Contains
|
||||||
|
if !mt.Contains([]byte("key1")) {
|
||||||
|
t.Errorf("expected Contains to return true for deleted key")
|
||||||
|
}
|
||||||
|
if mt.Contains([]byte("nonexistent")) {
|
||||||
|
t.Errorf("expected Contains to return false for nonexistent key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemTableSequenceNumbers(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Add entries with sequence numbers
|
||||||
|
mt.Put([]byte("key"), []byte("value1"), 1)
|
||||||
|
mt.Put([]byte("key"), []byte("value2"), 3)
|
||||||
|
mt.Put([]byte("key"), []byte("value3"), 2)
|
||||||
|
|
||||||
|
// Should get the latest by sequence number (value2)
|
||||||
|
value, found := mt.Get([]byte("key"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to find key, but got not found")
|
||||||
|
}
|
||||||
|
if string(value) != "value2" {
|
||||||
|
t.Errorf("expected value2 (highest seq), got %s", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next sequence number should be one more than the highest seen
|
||||||
|
if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 {
|
||||||
|
t.Errorf("expected next sequence number to be 4, got %d", nextSeq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemTableImmutability(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Add initial data
|
||||||
|
mt.Put([]byte("key"), []byte("value"), 1)
|
||||||
|
|
||||||
|
// Mark as immutable
|
||||||
|
mt.SetImmutable()
|
||||||
|
if !mt.IsImmutable() {
|
||||||
|
t.Errorf("expected IsImmutable to return true after SetImmutable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempts to modify should have no effect
|
||||||
|
mt.Put([]byte("key2"), []byte("value2"), 2)
|
||||||
|
mt.Delete([]byte("key"), 3)
|
||||||
|
|
||||||
|
// Verify no changes occurred
|
||||||
|
_, found := mt.Get([]byte("key2"))
|
||||||
|
if found {
|
||||||
|
t.Errorf("expected key2 to not be added to immutable memtable")
|
||||||
|
}
|
||||||
|
|
||||||
|
value, found := mt.Get([]byte("key"))
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected to still find key after delete on immutable table")
|
||||||
|
}
|
||||||
|
if string(value) != "value" {
|
||||||
|
t.Errorf("expected value to remain unchanged, got %s", string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemTableAge(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// A new memtable should have a very small age
|
||||||
|
if age := mt.Age(); age > 1.0 {
|
||||||
|
t.Errorf("expected new memtable to have age < 1.0s, got %.2fs", age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleep to increase age
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
if age := mt.Age(); age <= 0.0 {
|
||||||
|
t.Errorf("expected memtable age to be > 0, got %.6fs", age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemTableWALIntegration(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Create WAL entries
|
||||||
|
entries := []*wal.Entry{
|
||||||
|
{SequenceNumber: 1, Type: wal.OpTypePut, Key: []byte("key1"), Value: []byte("value1")},
|
||||||
|
{SequenceNumber: 2, Type: wal.OpTypeDelete, Key: []byte("key2"), Value: nil},
|
||||||
|
{SequenceNumber: 3, Type: wal.OpTypePut, Key: []byte("key3"), Value: []byte("value3")},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process entries
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := mt.ProcessWALEntry(entry); err != nil {
|
||||||
|
t.Fatalf("failed to process WAL entry: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify entries were processed correctly
|
||||||
|
testCases := []struct {
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
found bool
|
||||||
|
}{
|
||||||
|
{"key1", "value1", true},
|
||||||
|
{"key2", "", true}, // Deleted key
|
||||||
|
{"key3", "value3", true},
|
||||||
|
{"key4", "", false}, // Non-existent key
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
value, found := mt.Get([]byte(tc.key))
|
||||||
|
|
||||||
|
if found != tc.found {
|
||||||
|
t.Errorf("key %s: expected found=%v, got %v", tc.key, tc.found, found)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if found && tc.expected != "" {
|
||||||
|
if string(value) != tc.expected {
|
||||||
|
t.Errorf("key %s: expected value '%s', got '%s'", tc.key, tc.expected, string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify next sequence number
|
||||||
|
if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 {
|
||||||
|
t.Errorf("expected next sequence number to be 4, got %d", nextSeq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemTableIterator(t *testing.T) {
|
||||||
|
mt := NewMemTable()
|
||||||
|
|
||||||
|
// Add entries in non-sorted order
|
||||||
|
entries := []struct {
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
seq uint64
|
||||||
|
}{
|
||||||
|
{"banana", "yellow", 1},
|
||||||
|
{"apple", "red", 2},
|
||||||
|
{"cherry", "red", 3},
|
||||||
|
{"date", "brown", 4},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range entries {
|
||||||
|
mt.Put([]byte(e.key), []byte(e.value), e.seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use iterator to verify keys are returned in sorted order
|
||||||
|
it := mt.NewIterator()
|
||||||
|
it.SeekToFirst()
|
||||||
|
|
||||||
|
expected := []string{"apple", "banana", "cherry", "date"}
|
||||||
|
|
||||||
|
for i := 0; it.Valid() && i < len(expected); i++ {
|
||||||
|
key := string(it.Key())
|
||||||
|
if key != expected[i] {
|
||||||
|
t.Errorf("position %d: expected key %s, got %s", i, expected[i], key)
|
||||||
|
}
|
||||||
|
it.Next()
|
||||||
|
}
|
||||||
|
}
|
91
pkg/memtable/recovery.go
Normal file
91
pkg/memtable/recovery.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecoveryOptions contains options for MemTable recovery
|
||||||
|
type RecoveryOptions struct {
|
||||||
|
// MaxSequenceNumber is the maximum sequence number to recover
|
||||||
|
// Entries with sequence numbers greater than this will be ignored
|
||||||
|
MaxSequenceNumber uint64
|
||||||
|
|
||||||
|
// MaxMemTables is the maximum number of MemTables to create during recovery
|
||||||
|
// If more MemTables would be needed, an error is returned
|
||||||
|
MaxMemTables int
|
||||||
|
|
||||||
|
// MemTableSize is the maximum size of each MemTable
|
||||||
|
MemTableSize int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRecoveryOptions returns the default recovery options
|
||||||
|
func DefaultRecoveryOptions(cfg *config.Config) *RecoveryOptions {
|
||||||
|
return &RecoveryOptions{
|
||||||
|
MaxSequenceNumber: ^uint64(0), // Max uint64
|
||||||
|
MaxMemTables: cfg.MaxMemTables,
|
||||||
|
MemTableSize: cfg.MemTableSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecoverFromWAL rebuilds MemTables from the write-ahead log
|
||||||
|
// Returns a list of recovered MemTables and the maximum sequence number seen
|
||||||
|
func RecoverFromWAL(cfg *config.Config, opts *RecoveryOptions) ([]*MemTable, uint64, error) {
|
||||||
|
if opts == nil {
|
||||||
|
opts = DefaultRecoveryOptions(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the first MemTable
|
||||||
|
memTables := []*MemTable{NewMemTable()}
|
||||||
|
var maxSeqNum uint64
|
||||||
|
|
||||||
|
// Function to process each WAL entry
|
||||||
|
entryHandler := func(entry *wal.Entry) error {
|
||||||
|
// Skip entries with sequence numbers beyond our max
|
||||||
|
if entry.SequenceNumber > opts.MaxSequenceNumber {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the max sequence number
|
||||||
|
if entry.SequenceNumber > maxSeqNum {
|
||||||
|
maxSeqNum = entry.SequenceNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the current memtable
|
||||||
|
current := memTables[len(memTables)-1]
|
||||||
|
|
||||||
|
// Check if we should create a new memtable based on size
|
||||||
|
if current.ApproximateSize() >= opts.MemTableSize {
|
||||||
|
// Make sure we don't exceed the max number of memtables
|
||||||
|
if len(memTables) >= opts.MaxMemTables {
|
||||||
|
return fmt.Errorf("maximum number of memtables (%d) exceeded during recovery", opts.MaxMemTables)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the current memtable as immutable
|
||||||
|
current.SetImmutable()
|
||||||
|
|
||||||
|
// Create a new memtable
|
||||||
|
current = NewMemTable()
|
||||||
|
memTables = append(memTables, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the entry
|
||||||
|
return current.ProcessWALEntry(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replay the WAL directory
|
||||||
|
if err := wal.ReplayWALDir(cfg.WALDir, entryHandler); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to replay WAL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For batch operations, we need to adjust maxSeqNum
|
||||||
|
finalTable := memTables[len(memTables)-1]
|
||||||
|
nextSeq := finalTable.GetNextSequenceNumber()
|
||||||
|
if nextSeq > maxSeqNum+1 {
|
||||||
|
maxSeqNum = nextSeq - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return memTables, maxSeqNum, nil
|
||||||
|
}
|
276
pkg/memtable/recovery_test.go
Normal file
276
pkg/memtable/recovery_test.go
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||||
|
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestWAL(t *testing.T) (string, *wal.WAL, func()) {
|
||||||
|
// Create temporary directory
|
||||||
|
tmpDir, err := os.MkdirTemp("", "memtable_recovery_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create config
|
||||||
|
cfg := config.NewDefaultConfig(tmpDir)
|
||||||
|
|
||||||
|
// Create WAL
|
||||||
|
w, err := wal.NewWAL(cfg, tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
t.Fatalf("failed to create WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return cleanup function
|
||||||
|
cleanup := func() {
|
||||||
|
w.Close()
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpDir, w, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoverFromWAL(t *testing.T) {
|
||||||
|
tmpDir, w, cleanup := setupTestWAL(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Add entries to the WAL
|
||||||
|
entries := []struct {
|
||||||
|
opType uint8
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
}{
|
||||||
|
{wal.OpTypePut, "key1", "value1"},
|
||||||
|
{wal.OpTypePut, "key2", "value2"},
|
||||||
|
{wal.OpTypeDelete, "key1", ""},
|
||||||
|
{wal.OpTypePut, "key3", "value3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range entries {
|
||||||
|
var seq uint64
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if e.opType == wal.OpTypePut {
|
||||||
|
seq, err = w.Append(e.opType, []byte(e.key), []byte(e.value))
|
||||||
|
} else {
|
||||||
|
seq, err = w.Append(e.opType, []byte(e.key), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to append to WAL: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Appended entry with seq %d", seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync and close WAL
|
||||||
|
if err := w.Sync(); err != nil {
|
||||||
|
t.Fatalf("failed to sync WAL: %v", err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create config for recovery
|
||||||
|
cfg := config.NewDefaultConfig(tmpDir)
|
||||||
|
cfg.WALDir = tmpDir
|
||||||
|
cfg.MemTableSize = 1024 * 1024 // 1MB
|
||||||
|
|
||||||
|
// Recover memtables from WAL
|
||||||
|
memTables, maxSeq, err := RecoverFromWAL(cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to recover from WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate recovery results
|
||||||
|
if len(memTables) == 0 {
|
||||||
|
t.Fatalf("expected at least one memtable from recovery")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Recovered %d memtables with max sequence %d", len(memTables), maxSeq)
|
||||||
|
|
||||||
|
// The max sequence number should be 4
|
||||||
|
if maxSeq != 4 {
|
||||||
|
t.Errorf("expected max sequence number 4, got %d", maxSeq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate content of the recovered memtable
|
||||||
|
mt := memTables[0]
|
||||||
|
|
||||||
|
// key1 should be deleted
|
||||||
|
value, found := mt.Get([]byte("key1"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected key1 to be found (as deleted)")
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
t.Errorf("expected key1 to have nil value (deleted), got %v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// key2 should have "value2"
|
||||||
|
value, found = mt.Get([]byte("key2"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected key2 to be found")
|
||||||
|
} else if string(value) != "value2" {
|
||||||
|
t.Errorf("expected key2 to have value 'value2', got '%s'", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// key3 should have "value3"
|
||||||
|
value, found = mt.Get([]byte("key3"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected key3 to be found")
|
||||||
|
} else if string(value) != "value3" {
|
||||||
|
t.Errorf("expected key3 to have value 'value3', got '%s'", string(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoveryWithMultipleMemTables(t *testing.T) {
|
||||||
|
tmpDir, w, cleanup := setupTestWAL(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a lot of large entries to force multiple memtables
|
||||||
|
largeValue := make([]byte, 1000) // 1KB value
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
key := []byte{byte(i + 'a')}
|
||||||
|
if _, err := w.Append(wal.OpTypePut, key, largeValue); err != nil {
|
||||||
|
t.Fatalf("failed to append to WAL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync and close WAL
|
||||||
|
if err := w.Sync(); err != nil {
|
||||||
|
t.Fatalf("failed to sync WAL: %v", err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create config for recovery with small memtable size
|
||||||
|
cfg := config.NewDefaultConfig(tmpDir)
|
||||||
|
cfg.WALDir = tmpDir
|
||||||
|
cfg.MemTableSize = 5 * 1000 // 5KB - should fit about 5 entries
|
||||||
|
cfg.MaxMemTables = 3 // Allow up to 3 memtables
|
||||||
|
|
||||||
|
// Recover memtables from WAL
|
||||||
|
memTables, _, err := RecoverFromWAL(cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to recover from WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have created multiple memtables
|
||||||
|
if len(memTables) <= 1 {
|
||||||
|
t.Errorf("expected multiple memtables due to size, got %d", len(memTables))
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Recovered %d memtables", len(memTables))
|
||||||
|
|
||||||
|
// All memtables except the last one should be immutable
|
||||||
|
for i, mt := range memTables[:len(memTables)-1] {
|
||||||
|
if !mt.IsImmutable() {
|
||||||
|
t.Errorf("expected memtable %d to be immutable", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all data was recovered across all memtables
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
key := []byte{byte(i + 'a')}
|
||||||
|
found := false
|
||||||
|
|
||||||
|
// Check each memtable for the key
|
||||||
|
for _, mt := range memTables {
|
||||||
|
if _, exists := mt.Get(key); exists {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Errorf("key %c not found in any memtable", i+'a')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoveryWithBatchOperations(t *testing.T) {
|
||||||
|
tmpDir, w, cleanup := setupTestWAL(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Create a batch of operations
|
||||||
|
batch := wal.NewBatch()
|
||||||
|
batch.Put([]byte("batch_key1"), []byte("batch_value1"))
|
||||||
|
batch.Put([]byte("batch_key2"), []byte("batch_value2"))
|
||||||
|
batch.Delete([]byte("batch_key3"))
|
||||||
|
|
||||||
|
// Write the batch to the WAL
|
||||||
|
if err := batch.Write(w); err != nil {
|
||||||
|
t.Fatalf("failed to write batch to WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some individual operations too
|
||||||
|
if _, err := w.Append(wal.OpTypePut, []byte("key4"), []byte("value4")); err != nil {
|
||||||
|
t.Fatalf("failed to append to WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync and close WAL
|
||||||
|
if err := w.Sync(); err != nil {
|
||||||
|
t.Fatalf("failed to sync WAL: %v", err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatalf("failed to close WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create config for recovery
|
||||||
|
cfg := config.NewDefaultConfig(tmpDir)
|
||||||
|
cfg.WALDir = tmpDir
|
||||||
|
|
||||||
|
// Recover memtables from WAL
|
||||||
|
memTables, maxSeq, err := RecoverFromWAL(cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to recover from WAL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(memTables) == 0 {
|
||||||
|
t.Fatalf("expected at least one memtable from recovery")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The max sequence number should account for batch operations
|
||||||
|
if maxSeq < 3 { // At least 3 from batch + individual op
|
||||||
|
t.Errorf("expected max sequence number >= 3, got %d", maxSeq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate content of the recovered memtable
|
||||||
|
mt := memTables[0]
|
||||||
|
|
||||||
|
// Check batch keys were recovered
|
||||||
|
value, found := mt.Get([]byte("batch_key1"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("batch_key1 not found in recovered memtable")
|
||||||
|
} else if string(value) != "batch_value1" {
|
||||||
|
t.Errorf("expected batch_key1 to have value 'batch_value1', got '%s'", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
value, found = mt.Get([]byte("batch_key2"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("batch_key2 not found in recovered memtable")
|
||||||
|
} else if string(value) != "batch_value2" {
|
||||||
|
t.Errorf("expected batch_key2 to have value 'batch_value2', got '%s'", string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch_key3 should be marked as deleted
|
||||||
|
value, found = mt.Get([]byte("batch_key3"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected batch_key3 to be found as deleted")
|
||||||
|
}
|
||||||
|
if value != nil {
|
||||||
|
t.Errorf("expected batch_key3 to have nil value (deleted), got %v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check individual operation was recovered
|
||||||
|
value, found = mt.Get([]byte("key4"))
|
||||||
|
if !found {
|
||||||
|
t.Errorf("key4 not found in recovered memtable")
|
||||||
|
} else if string(value) != "value4" {
|
||||||
|
t.Errorf("expected key4 to have value 'value4', got '%s'", string(value))
|
||||||
|
}
|
||||||
|
}
|
308
pkg/memtable/skiplist.go
Normal file
308
pkg/memtable/skiplist.go
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MaxHeight is the maximum height of the skip list
|
||||||
|
MaxHeight = 12
|
||||||
|
|
||||||
|
// BranchingFactor determines the probability of increasing the height
|
||||||
|
BranchingFactor = 4
|
||||||
|
|
||||||
|
// DefaultCacheLineSize aligns nodes to cache lines for better performance
|
||||||
|
DefaultCacheLineSize = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValueType represents the type of a key-value entry
|
||||||
|
type ValueType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TypeValue indicates the entry contains a value
|
||||||
|
TypeValue ValueType = iota + 1
|
||||||
|
|
||||||
|
// TypeDeletion indicates the entry is a tombstone (deletion marker)
|
||||||
|
TypeDeletion
|
||||||
|
)
|
||||||
|
|
||||||
|
// entry represents a key-value pair with additional metadata
|
||||||
|
type entry struct {
|
||||||
|
key []byte
|
||||||
|
value []byte
|
||||||
|
valueType ValueType
|
||||||
|
seqNum uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// newEntry creates a new entry
|
||||||
|
func newEntry(key, value []byte, valueType ValueType, seqNum uint64) *entry {
|
||||||
|
return &entry{
|
||||||
|
key: key,
|
||||||
|
value: value,
|
||||||
|
valueType: valueType,
|
||||||
|
seqNum: seqNum,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// size returns the approximate size of the entry in memory
|
||||||
|
func (e *entry) size() int {
|
||||||
|
return len(e.key) + len(e.value) + 16 // adding overhead for metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare compares this entry with another key
|
||||||
|
// Returns: negative if e.key < key, 0 if equal, positive if e.key > key
|
||||||
|
func (e *entry) compare(key []byte) int {
|
||||||
|
return bytes.Compare(e.key, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareWithEntry compares this entry with another entry
|
||||||
|
// First by key, then by sequence number (in reverse order to prioritize newer entries)
|
||||||
|
func (e *entry) compareWithEntry(other *entry) int {
|
||||||
|
cmp := bytes.Compare(e.key, other.key)
|
||||||
|
if cmp == 0 {
|
||||||
|
// If keys are equal, compare sequence numbers in reverse order (newer first)
|
||||||
|
if e.seqNum > other.seqNum {
|
||||||
|
return -1
|
||||||
|
} else if e.seqNum < other.seqNum {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
|
||||||
|
// node represents a node in the skip list
|
||||||
|
type node struct {
|
||||||
|
entry *entry
|
||||||
|
height int32
|
||||||
|
// next contains pointers to the next nodes at each level
|
||||||
|
// This is allocated as a single block for cache efficiency
|
||||||
|
next [MaxHeight]unsafe.Pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNode creates a new node with a random height
|
||||||
|
func newNode(e *entry, height int) *node {
|
||||||
|
return &node{
|
||||||
|
entry: e,
|
||||||
|
height: int32(height),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNext returns the next node at the given level
|
||||||
|
func (n *node) getNext(level int) *node {
|
||||||
|
return (*node)(atomic.LoadPointer(&n.next[level]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNext sets the next node at the given level
|
||||||
|
func (n *node) setNext(level int, next *node) {
|
||||||
|
atomic.StorePointer(&n.next[level], unsafe.Pointer(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipList is a concurrent skip list implementation for the MemTable
|
||||||
|
type SkipList struct {
|
||||||
|
head *node
|
||||||
|
maxHeight int32
|
||||||
|
rnd *rand.Rand
|
||||||
|
rndMtx sync.Mutex
|
||||||
|
size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSkipList creates a new skip list
|
||||||
|
func NewSkipList() *SkipList {
|
||||||
|
seed := time.Now().UnixNano()
|
||||||
|
list := &SkipList{
|
||||||
|
head: newNode(nil, MaxHeight),
|
||||||
|
maxHeight: 1,
|
||||||
|
rnd: rand.New(rand.NewSource(seed)),
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomHeight generates a random height for a new node
|
||||||
|
func (s *SkipList) randomHeight() int {
|
||||||
|
s.rndMtx.Lock()
|
||||||
|
defer s.rndMtx.Unlock()
|
||||||
|
|
||||||
|
height := 1
|
||||||
|
for height < MaxHeight && s.rnd.Intn(BranchingFactor) == 0 {
|
||||||
|
height++
|
||||||
|
}
|
||||||
|
return height
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCurrentHeight returns the current maximum height of the skip list
|
||||||
|
func (s *SkipList) getCurrentHeight() int {
|
||||||
|
return int(atomic.LoadInt32(&s.maxHeight))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert adds a new entry to the skip list
|
||||||
|
func (s *SkipList) Insert(e *entry) {
|
||||||
|
height := s.randomHeight()
|
||||||
|
prev := [MaxHeight]*node{}
|
||||||
|
node := newNode(e, height)
|
||||||
|
|
||||||
|
// Try to increase the height of the list
|
||||||
|
currHeight := s.getCurrentHeight()
|
||||||
|
if height > currHeight {
|
||||||
|
// Attempt to increase the height
|
||||||
|
if atomic.CompareAndSwapInt32(&s.maxHeight, int32(currHeight), int32(height)) {
|
||||||
|
currHeight = height
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find where to insert at each level
|
||||||
|
current := s.head
|
||||||
|
for level := currHeight - 1; level >= 0; level-- {
|
||||||
|
// Find the insertion point at this level
|
||||||
|
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||||
|
if next.entry.compareWithEntry(e) >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
prev[level] = current
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the node at each level
|
||||||
|
for level := 0; level < height; level++ {
|
||||||
|
node.setNext(level, prev[level].getNext(level))
|
||||||
|
prev[level].setNext(level, node)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update approximate size
|
||||||
|
atomic.AddInt64(&s.size, int64(e.size()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find looks for an entry with the specified key
|
||||||
|
// If multiple entries have the same key, the most recent one is returned
|
||||||
|
func (s *SkipList) Find(key []byte) *entry {
|
||||||
|
var result *entry
|
||||||
|
current := s.head
|
||||||
|
height := s.getCurrentHeight()
|
||||||
|
|
||||||
|
// Start from the highest level for efficient search
|
||||||
|
for level := height - 1; level >= 0; level-- {
|
||||||
|
// Scan forward until we find a key greater than or equal to the target
|
||||||
|
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||||
|
cmp := next.entry.compare(key)
|
||||||
|
if cmp > 0 {
|
||||||
|
// Key at next is greater than target, go down a level
|
||||||
|
break
|
||||||
|
} else if cmp == 0 {
|
||||||
|
// Found a match, check if it's newer than our current result
|
||||||
|
if result == nil || next.entry.seqNum > result.seqNum {
|
||||||
|
result = next.entry
|
||||||
|
}
|
||||||
|
// Continue at this level to see if there are more entries with same key
|
||||||
|
current = next
|
||||||
|
} else {
|
||||||
|
// Key at next is less than target, move forward
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For level 0, do one more sweep to ensure we get the newest entry
|
||||||
|
current = s.head
|
||||||
|
for next := current.getNext(0); next != nil; next = next.getNext(0) {
|
||||||
|
cmp := next.entry.compare(key)
|
||||||
|
if cmp > 0 {
|
||||||
|
// Past the key
|
||||||
|
break
|
||||||
|
} else if cmp == 0 {
|
||||||
|
// Found a match, update result if it's newer
|
||||||
|
if result == nil || next.entry.seqNum > result.seqNum {
|
||||||
|
result = next.entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApproximateSize returns the approximate size of the skip list in bytes
|
||||||
|
func (s *SkipList) ApproximateSize() int64 {
|
||||||
|
return atomic.LoadInt64(&s.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterator provides sequential access to the skip list entries
|
||||||
|
type Iterator struct {
|
||||||
|
list *SkipList
|
||||||
|
current *node
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIterator creates a new Iterator for the skip list
|
||||||
|
func (s *SkipList) NewIterator() *Iterator {
|
||||||
|
return &Iterator{
|
||||||
|
list: s,
|
||||||
|
current: s.head,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid returns true if the iterator is positioned at a valid entry
|
||||||
|
func (it *Iterator) Valid() bool {
|
||||||
|
return it.current != nil && it.current != it.list.head
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next advances the iterator to the next entry
|
||||||
|
func (it *Iterator) Next() {
|
||||||
|
if it.current == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
it.current = it.current.getNext(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SeekToFirst positions the iterator at the first entry
|
||||||
|
func (it *Iterator) SeekToFirst() {
|
||||||
|
it.current = it.list.head.getNext(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seek positions the iterator at the first entry with a key >= target
|
||||||
|
func (it *Iterator) Seek(key []byte) {
|
||||||
|
// Start from head
|
||||||
|
current := it.list.head
|
||||||
|
height := it.list.getCurrentHeight()
|
||||||
|
|
||||||
|
// Search algorithm similar to Find
|
||||||
|
for level := height - 1; level >= 0; level-- {
|
||||||
|
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||||
|
if next.entry.compare(key) >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
current = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move to the next node, which should be >= target
|
||||||
|
it.current = current.getNext(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key returns the key of the current entry
|
||||||
|
func (it *Iterator) Key() []byte {
|
||||||
|
if !it.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return it.current.entry.key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the value of the current entry
|
||||||
|
func (it *Iterator) Value() []byte {
|
||||||
|
if !it.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return it.current.entry.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry returns the current entry
|
||||||
|
func (it *Iterator) Entry() *entry {
|
||||||
|
if !it.Valid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return it.current.entry
|
||||||
|
}
|
232
pkg/memtable/skiplist_test.go
Normal file
232
pkg/memtable/skiplist_test.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package memtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSkipListBasicOperations(t *testing.T) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Test insertion
|
||||||
|
e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1)
|
||||||
|
e2 := newEntry([]byte("key2"), []byte("value2"), TypeValue, 2)
|
||||||
|
e3 := newEntry([]byte("key3"), []byte("value3"), TypeValue, 3)
|
||||||
|
|
||||||
|
sl.Insert(e1)
|
||||||
|
sl.Insert(e2)
|
||||||
|
sl.Insert(e3)
|
||||||
|
|
||||||
|
// Test lookup
|
||||||
|
found := sl.Find([]byte("key2"))
|
||||||
|
if found == nil {
|
||||||
|
t.Fatalf("expected to find key2, but got nil")
|
||||||
|
}
|
||||||
|
if string(found.value) != "value2" {
|
||||||
|
t.Errorf("expected value to be 'value2', got '%s'", string(found.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test lookup of non-existent key
|
||||||
|
notFound := sl.Find([]byte("key4"))
|
||||||
|
if notFound != nil {
|
||||||
|
t.Errorf("expected nil for non-existent key, got %v", notFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipListSequenceNumbers(t *testing.T) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Insert same key with different sequence numbers
|
||||||
|
e1 := newEntry([]byte("key"), []byte("value1"), TypeValue, 1)
|
||||||
|
e2 := newEntry([]byte("key"), []byte("value2"), TypeValue, 2)
|
||||||
|
e3 := newEntry([]byte("key"), []byte("value3"), TypeValue, 3)
|
||||||
|
|
||||||
|
// Insert in reverse order to test ordering
|
||||||
|
sl.Insert(e3)
|
||||||
|
sl.Insert(e2)
|
||||||
|
sl.Insert(e1)
|
||||||
|
|
||||||
|
// Find should return the entry with the highest sequence number
|
||||||
|
found := sl.Find([]byte("key"))
|
||||||
|
if found == nil {
|
||||||
|
t.Fatalf("expected to find key, but got nil")
|
||||||
|
}
|
||||||
|
if string(found.value) != "value3" {
|
||||||
|
t.Errorf("expected value to be 'value3' (highest seq num), got '%s'", string(found.value))
|
||||||
|
}
|
||||||
|
if found.seqNum != 3 {
|
||||||
|
t.Errorf("expected sequence number to be 3, got %d", found.seqNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipListIterator(t *testing.T) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Insert entries
|
||||||
|
entries := []struct {
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
seq uint64
|
||||||
|
}{
|
||||||
|
{"apple", "red", 1},
|
||||||
|
{"banana", "yellow", 2},
|
||||||
|
{"cherry", "red", 3},
|
||||||
|
{"date", "brown", 4},
|
||||||
|
{"elderberry", "purple", 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range entries {
|
||||||
|
sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test iteration
|
||||||
|
it := sl.NewIterator()
|
||||||
|
it.SeekToFirst()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for it.Valid() {
|
||||||
|
if count >= len(entries) {
|
||||||
|
t.Fatalf("iterator returned more entries than expected")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedKey := entries[count].key
|
||||||
|
expectedValue := entries[count].value
|
||||||
|
|
||||||
|
if string(it.Key()) != expectedKey {
|
||||||
|
t.Errorf("at position %d, expected key '%s', got '%s'", count, expectedKey, string(it.Key()))
|
||||||
|
}
|
||||||
|
if string(it.Value()) != expectedValue {
|
||||||
|
t.Errorf("at position %d, expected value '%s', got '%s'", count, expectedValue, string(it.Value()))
|
||||||
|
}
|
||||||
|
|
||||||
|
it.Next()
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != len(entries) {
|
||||||
|
t.Errorf("expected to iterate through %d entries, but got %d", len(entries), count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipListSeek(t *testing.T) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Insert entries
|
||||||
|
entries := []struct {
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
seq uint64
|
||||||
|
}{
|
||||||
|
{"apple", "red", 1},
|
||||||
|
{"banana", "yellow", 2},
|
||||||
|
{"cherry", "red", 3},
|
||||||
|
{"date", "brown", 4},
|
||||||
|
{"elderberry", "purple", 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range entries {
|
||||||
|
sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq))
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
seek string
|
||||||
|
expected string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
// Before first entry
|
||||||
|
{"a", "apple", true},
|
||||||
|
// Exact match
|
||||||
|
{"cherry", "cherry", true},
|
||||||
|
// Between entries
|
||||||
|
{"blueberry", "cherry", true},
|
||||||
|
// After last entry
|
||||||
|
{"zebra", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.seek, func(t *testing.T) {
|
||||||
|
it := sl.NewIterator()
|
||||||
|
it.Seek([]byte(tc.seek))
|
||||||
|
|
||||||
|
if it.Valid() != tc.valid {
|
||||||
|
t.Errorf("expected Valid() to be %v, got %v", tc.valid, it.Valid())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.valid {
|
||||||
|
if string(it.Key()) != tc.expected {
|
||||||
|
t.Errorf("expected key '%s', got '%s'", tc.expected, string(it.Key()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEntryComparison(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
e1, e2 *entry
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
// Different keys
|
||||||
|
{
|
||||||
|
newEntry([]byte("a"), []byte("val"), TypeValue, 1),
|
||||||
|
newEntry([]byte("b"), []byte("val"), TypeValue, 1),
|
||||||
|
-1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
newEntry([]byte("b"), []byte("val"), TypeValue, 1),
|
||||||
|
newEntry([]byte("a"), []byte("val"), TypeValue, 1),
|
||||||
|
1,
|
||||||
|
},
|
||||||
|
// Same key, different sequence numbers (higher seq should be "less")
|
||||||
|
{
|
||||||
|
newEntry([]byte("same"), []byte("val1"), TypeValue, 2),
|
||||||
|
newEntry([]byte("same"), []byte("val2"), TypeValue, 1),
|
||||||
|
-1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
newEntry([]byte("same"), []byte("val1"), TypeValue, 1),
|
||||||
|
newEntry([]byte("same"), []byte("val2"), TypeValue, 2),
|
||||||
|
1,
|
||||||
|
},
|
||||||
|
// Same key, same sequence number
|
||||||
|
{
|
||||||
|
newEntry([]byte("same"), []byte("val"), TypeValue, 1),
|
||||||
|
newEntry([]byte("same"), []byte("val"), TypeValue, 1),
|
||||||
|
0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
result := tc.e1.compareWithEntry(tc.e2)
|
||||||
|
expected := tc.expected
|
||||||
|
// We just care about the sign
|
||||||
|
if (result < 0 && expected >= 0) || (result > 0 && expected <= 0) || (result == 0 && expected != 0) {
|
||||||
|
t.Errorf("case %d: expected comparison result %d, got %d", i, expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipListApproximateSize(t *testing.T) {
|
||||||
|
sl := NewSkipList()
|
||||||
|
|
||||||
|
// Initial size should be 0
|
||||||
|
if size := sl.ApproximateSize(); size != 0 {
|
||||||
|
t.Errorf("expected initial size to be 0, got %d", size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some entries
|
||||||
|
e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1)
|
||||||
|
e2 := newEntry([]byte("key2"), bytes.Repeat([]byte("v"), 100), TypeValue, 2)
|
||||||
|
|
||||||
|
sl.Insert(e1)
|
||||||
|
expectedSize := int64(e1.size())
|
||||||
|
if size := sl.ApproximateSize(); size != expectedSize {
|
||||||
|
t.Errorf("expected size to be %d after first insert, got %d", expectedSize, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
sl.Insert(e2)
|
||||||
|
expectedSize += int64(e2.size())
|
||||||
|
if size := sl.ApproximateSize(); size != expectedSize {
|
||||||
|
t.Errorf("expected size to be %d after second insert, got %d", expectedSize, size)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user