feat: implement memtable package
This commit is contained in:
parent
b7fb76fd54
commit
b1f9ed6740
26
TODO.md
26
TODO.md
@ -35,21 +35,21 @@ This document outlines the implementation tasks for the Go Storage Engine, organ
|
||||
|
||||
## Phase B: In-Memory Layer
|
||||
|
||||
- [ ] Implement MemTable
|
||||
- [ ] Create skip list data structure aligned to 64-byte cache lines
|
||||
- [ ] Add key/value insertion and lookup operations
|
||||
- [ ] Implement sorted key iteration
|
||||
- [ ] Add size tracking for flush threshold detection
|
||||
- [✓] Implement MemTable
|
||||
- [✓] Create skip list data structure aligned to 64-byte cache lines
|
||||
- [✓] Add key/value insertion and lookup operations
|
||||
- [✓] Implement sorted key iteration
|
||||
- [✓] Add size tracking for flush threshold detection
|
||||
|
||||
- [ ] Connect WAL replay to MemTable
|
||||
- [ ] Create recovery logic to rebuild MemTable from WAL
|
||||
- [ ] Implement consistent snapshot reads during recovery
|
||||
- [ ] Handle errors during replay with appropriate fallbacks
|
||||
- [✓] Connect WAL replay to MemTable
|
||||
- [✓] Create recovery logic to rebuild MemTable from WAL
|
||||
- [✓] Implement consistent snapshot reads during recovery
|
||||
- [✓] Handle errors during replay with appropriate fallbacks
|
||||
|
||||
- [ ] Test concurrent read/write scenarios
|
||||
- [ ] Verify reader isolation during writes
|
||||
- [ ] Test snapshot consistency guarantees
|
||||
- [ ] Benchmark read/write performance under load
|
||||
- [✓] Test concurrent read/write scenarios
|
||||
- [✓] Verify reader isolation during writes
|
||||
- [✓] Test snapshot consistency guarantees
|
||||
- [✓] Benchmark read/write performance under load
|
||||
|
||||
## Phase C: Persistent Storage
|
||||
|
||||
|
132
pkg/memtable/bench_test.go
Normal file
132
pkg/memtable/bench_test.go
Normal file
@ -0,0 +1,132 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkSkipListInsert(b *testing.B) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Create random keys ahead of time
|
||||
keys := make([][]byte, b.N)
|
||||
values := make([][]byte, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
keys[i] = []byte(fmt.Sprintf("key-%d", i))
|
||||
values[i] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e := newEntry(keys[i], values[i], TypeValue, uint64(i))
|
||||
sl.Insert(e)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSkipListFind(b *testing.B) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Insert entries first
|
||||
const numEntries = 100000
|
||||
keys := make([][]byte, numEntries)
|
||||
for i := 0; i < numEntries; i++ {
|
||||
key := []byte(fmt.Sprintf("key-%d", i))
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
keys[i] = key
|
||||
sl.Insert(newEntry(key, value, TypeValue, uint64(i)))
|
||||
}
|
||||
|
||||
// Create random keys for lookup
|
||||
lookupKeys := make([][]byte, b.N)
|
||||
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||
for i := 0; i < b.N; i++ {
|
||||
idx := r.Intn(numEntries)
|
||||
lookupKeys[i] = keys[idx]
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sl.Find(lookupKeys[i])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMemTablePut(b *testing.B) {
|
||||
mt := NewMemTable()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := []byte("key-" + strconv.Itoa(i))
|
||||
value := []byte("value-" + strconv.Itoa(i))
|
||||
mt.Put(key, value, uint64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMemTableGet(b *testing.B) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Insert entries first
|
||||
const numEntries = 100000
|
||||
keys := make([][]byte, numEntries)
|
||||
for i := 0; i < numEntries; i++ {
|
||||
key := []byte(fmt.Sprintf("key-%d", i))
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
keys[i] = key
|
||||
mt.Put(key, value, uint64(i))
|
||||
}
|
||||
|
||||
// Create random keys for lookup
|
||||
lookupKeys := make([][]byte, b.N)
|
||||
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||
for i := 0; i < b.N; i++ {
|
||||
idx := r.Intn(numEntries)
|
||||
lookupKeys[i] = keys[idx]
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mt.Get(lookupKeys[i])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMemPoolGet(b *testing.B) {
|
||||
cfg := createTestConfig()
|
||||
cfg.MemTableSize = 1024 * 1024 * 32 // 32MB for benchmark
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Create multiple memtables with entries
|
||||
const entriesPerTable = 50000
|
||||
const numTables = 3
|
||||
keys := make([][]byte, entriesPerTable*numTables)
|
||||
|
||||
// Fill tables
|
||||
for t := 0; t < numTables; t++ {
|
||||
// Fill a table
|
||||
for i := 0; i < entriesPerTable; i++ {
|
||||
idx := t*entriesPerTable + i
|
||||
key := []byte(fmt.Sprintf("key-%d", idx))
|
||||
value := []byte(fmt.Sprintf("value-%d", idx))
|
||||
keys[idx] = key
|
||||
pool.Put(key, value, uint64(idx))
|
||||
}
|
||||
|
||||
// Switch to a new memtable (except for last one)
|
||||
if t < numTables-1 {
|
||||
pool.SwitchToNewMemTable()
|
||||
}
|
||||
}
|
||||
|
||||
// Create random keys for lookup
|
||||
lookupKeys := make([][]byte, b.N)
|
||||
r := rand.New(rand.NewSource(42)) // Use fixed seed for reproducibility
|
||||
for i := 0; i < b.N; i++ {
|
||||
idx := r.Intn(entriesPerTable * numTables)
|
||||
lookupKeys[i] = keys[idx]
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.Get(lookupKeys[i])
|
||||
}
|
||||
}
|
166
pkg/memtable/mempool.go
Normal file
166
pkg/memtable/mempool.go
Normal file
@ -0,0 +1,166 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
)
|
||||
|
||||
// MemTablePool manages a pool of MemTables
|
||||
// It maintains one active MemTable and a set of immutable MemTables
|
||||
type MemTablePool struct {
|
||||
cfg *config.Config
|
||||
active *MemTable
|
||||
immutables []*MemTable
|
||||
maxAge time.Duration
|
||||
maxSize int64
|
||||
totalSize int64
|
||||
flushPending atomic.Bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMemTablePool creates a new MemTable pool
|
||||
func NewMemTablePool(cfg *config.Config) *MemTablePool {
|
||||
return &MemTablePool{
|
||||
cfg: cfg,
|
||||
active: NewMemTable(),
|
||||
immutables: make([]*MemTable, 0, cfg.MaxMemTables-1),
|
||||
maxAge: time.Duration(cfg.MaxMemTableAge) * time.Second,
|
||||
maxSize: cfg.MemTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Put adds a key-value pair to the active MemTable
|
||||
func (p *MemTablePool) Put(key, value []byte, seqNum uint64) {
|
||||
p.mu.RLock()
|
||||
p.active.Put(key, value, seqNum)
|
||||
p.mu.RUnlock()
|
||||
|
||||
// Check if we need to flush after this write
|
||||
p.checkFlushConditions()
|
||||
}
|
||||
|
||||
// Delete marks a key as deleted in the active MemTable
|
||||
func (p *MemTablePool) Delete(key []byte, seqNum uint64) {
|
||||
p.mu.RLock()
|
||||
p.active.Delete(key, seqNum)
|
||||
p.mu.RUnlock()
|
||||
|
||||
// Check if we need to flush after this write
|
||||
p.checkFlushConditions()
|
||||
}
|
||||
|
||||
// Get retrieves the value for a key from all MemTables
|
||||
// Checks the active MemTable first, then the immutables in reverse order
|
||||
func (p *MemTablePool) Get(key []byte) ([]byte, bool) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// Check active table first
|
||||
if value, found := p.active.Get(key); found {
|
||||
return value, true
|
||||
}
|
||||
|
||||
// Check immutable tables in reverse order (newest first)
|
||||
for i := len(p.immutables) - 1; i >= 0; i-- {
|
||||
if value, found := p.immutables[i].Get(key); found {
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ImmutableCount returns the number of immutable MemTables
|
||||
func (p *MemTablePool) ImmutableCount() int {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return len(p.immutables)
|
||||
}
|
||||
|
||||
// checkFlushConditions checks if we need to flush the active MemTable
|
||||
func (p *MemTablePool) checkFlushConditions() {
|
||||
needsFlush := false
|
||||
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// Skip if a flush is already pending
|
||||
if p.flushPending.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check size condition
|
||||
if p.active.ApproximateSize() >= p.maxSize {
|
||||
needsFlush = true
|
||||
}
|
||||
|
||||
// Check age condition
|
||||
if p.maxAge > 0 && p.active.Age() > p.maxAge.Seconds() {
|
||||
needsFlush = true
|
||||
}
|
||||
|
||||
// Mark as needing flush if conditions met
|
||||
if needsFlush {
|
||||
p.flushPending.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
// SwitchToNewMemTable makes the active MemTable immutable and creates a new active one
|
||||
// Returns the immutable MemTable that needs to be flushed
|
||||
func (p *MemTablePool) SwitchToNewMemTable() *MemTable {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Reset the flush pending flag
|
||||
p.flushPending.Store(false)
|
||||
|
||||
// Make the current active table immutable
|
||||
oldActive := p.active
|
||||
oldActive.SetImmutable()
|
||||
|
||||
// Create a new active table
|
||||
p.active = NewMemTable()
|
||||
|
||||
// Add the old table to the immutables list
|
||||
p.immutables = append(p.immutables, oldActive)
|
||||
|
||||
// Return the table that needs to be flushed
|
||||
return oldActive
|
||||
}
|
||||
|
||||
// GetImmutablesForFlush returns a list of immutable MemTables ready for flushing
|
||||
// and removes them from the pool
|
||||
func (p *MemTablePool) GetImmutablesForFlush() []*MemTable {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
result := p.immutables
|
||||
p.immutables = make([]*MemTable, 0, p.cfg.MaxMemTables-1)
|
||||
return result
|
||||
}
|
||||
|
||||
// IsFlushNeeded returns true if a flush is needed
|
||||
func (p *MemTablePool) IsFlushNeeded() bool {
|
||||
return p.flushPending.Load()
|
||||
}
|
||||
|
||||
// GetNextSequenceNumber returns the next sequence number to use
|
||||
func (p *MemTablePool) GetNextSequenceNumber() uint64 {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.active.GetNextSequenceNumber()
|
||||
}
|
||||
|
||||
// GetMemTables returns all MemTables (active and immutable)
|
||||
func (p *MemTablePool) GetMemTables() []*MemTable {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
result := make([]*MemTable, 0, len(p.immutables)+1)
|
||||
result = append(result, p.active)
|
||||
result = append(result, p.immutables...)
|
||||
return result
|
||||
}
|
225
pkg/memtable/mempool_test.go
Normal file
225
pkg/memtable/mempool_test.go
Normal file
@ -0,0 +1,225 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
)
|
||||
|
||||
func createTestConfig() *config.Config {
|
||||
cfg := config.NewDefaultConfig("/tmp/db")
|
||||
cfg.MemTableSize = 1024 // Small size for testing
|
||||
cfg.MaxMemTableAge = 1 // 1 second
|
||||
cfg.MaxMemTables = 4 // Allow up to 4 memtables
|
||||
cfg.MemTablePoolCap = 4 // Pool capacity
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestMemPoolBasicOperations(t *testing.T) {
|
||||
cfg := createTestConfig()
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Test Put and Get
|
||||
pool.Put([]byte("key1"), []byte("value1"), 1)
|
||||
|
||||
value, found := pool.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key1, but got not found")
|
||||
}
|
||||
if string(value) != "value1" {
|
||||
t.Errorf("expected value1, got %s", string(value))
|
||||
}
|
||||
|
||||
// Test Delete
|
||||
pool.Delete([]byte("key1"), 2)
|
||||
|
||||
value, found = pool.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected tombstone to be found for key1")
|
||||
}
|
||||
if value != nil {
|
||||
t.Errorf("expected nil value for deleted key, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemPoolSwitchMemTable(t *testing.T) {
|
||||
cfg := createTestConfig()
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Add data to the active memtable
|
||||
pool.Put([]byte("key1"), []byte("value1"), 1)
|
||||
|
||||
// Switch to a new memtable
|
||||
old := pool.SwitchToNewMemTable()
|
||||
if !old.IsImmutable() {
|
||||
t.Errorf("expected switched memtable to be immutable")
|
||||
}
|
||||
|
||||
// Verify the data is in the old table
|
||||
value, found := old.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key1 in old table, but got not found")
|
||||
}
|
||||
if string(value) != "value1" {
|
||||
t.Errorf("expected value1 in old table, got %s", string(value))
|
||||
}
|
||||
|
||||
// Verify the immutable count is correct
|
||||
if count := pool.ImmutableCount(); count != 1 {
|
||||
t.Errorf("expected immutable count to be 1, got %d", count)
|
||||
}
|
||||
|
||||
// Add data to the new active memtable
|
||||
pool.Put([]byte("key2"), []byte("value2"), 2)
|
||||
|
||||
// Verify we can still retrieve data from both tables
|
||||
value, found = pool.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key1 through pool, but got not found")
|
||||
}
|
||||
if string(value) != "value1" {
|
||||
t.Errorf("expected value1 through pool, got %s", string(value))
|
||||
}
|
||||
|
||||
value, found = pool.Get([]byte("key2"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key2 through pool, but got not found")
|
||||
}
|
||||
if string(value) != "value2" {
|
||||
t.Errorf("expected value2 through pool, got %s", string(value))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemPoolFlushConditions(t *testing.T) {
|
||||
// Create a config with small thresholds for testing
|
||||
cfg := createTestConfig()
|
||||
cfg.MemTableSize = 100 // Very small size to trigger flush
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Initially no flush should be needed
|
||||
if pool.IsFlushNeeded() {
|
||||
t.Errorf("expected no flush needed initially")
|
||||
}
|
||||
|
||||
// Add enough data to trigger a size-based flush
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte{byte(i)}
|
||||
value := make([]byte, 20) // 20 bytes per value
|
||||
pool.Put(key, value, uint64(i+1))
|
||||
}
|
||||
|
||||
// Should trigger a flush
|
||||
if !pool.IsFlushNeeded() {
|
||||
t.Errorf("expected flush needed after reaching size threshold")
|
||||
}
|
||||
|
||||
// Switch to a new memtable
|
||||
old := pool.SwitchToNewMemTable()
|
||||
if !old.IsImmutable() {
|
||||
t.Errorf("expected old memtable to be immutable")
|
||||
}
|
||||
|
||||
// The flush pending flag should be reset
|
||||
if pool.IsFlushNeeded() {
|
||||
t.Errorf("expected flush pending to be reset after switch")
|
||||
}
|
||||
|
||||
// Now test age-based flushing
|
||||
// Wait for the age threshold to trigger
|
||||
time.Sleep(1200 * time.Millisecond) // Just over 1 second
|
||||
|
||||
// Add a small amount of data to check conditions
|
||||
pool.Put([]byte("trigger"), []byte("check"), 100)
|
||||
|
||||
// Should trigger an age-based flush
|
||||
if !pool.IsFlushNeeded() {
|
||||
t.Errorf("expected flush needed after reaching age threshold")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemPoolGetImmutablesForFlush(t *testing.T) {
|
||||
cfg := createTestConfig()
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Switch memtables a few times to accumulate immutables
|
||||
for i := 0; i < 3; i++ {
|
||||
pool.Put([]byte{byte(i)}, []byte{byte(i)}, uint64(i+1))
|
||||
pool.SwitchToNewMemTable()
|
||||
}
|
||||
|
||||
// Should have 3 immutable memtables
|
||||
if count := pool.ImmutableCount(); count != 3 {
|
||||
t.Errorf("expected 3 immutable memtables, got %d", count)
|
||||
}
|
||||
|
||||
// Get immutables for flush
|
||||
immutables := pool.GetImmutablesForFlush()
|
||||
|
||||
// Should get all 3 immutables
|
||||
if len(immutables) != 3 {
|
||||
t.Errorf("expected to get 3 immutables for flush, got %d", len(immutables))
|
||||
}
|
||||
|
||||
// The pool should now have 0 immutables
|
||||
if count := pool.ImmutableCount(); count != 0 {
|
||||
t.Errorf("expected 0 immutable memtables after flush, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemPoolGetMemTables(t *testing.T) {
|
||||
cfg := createTestConfig()
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Initially should have just the active memtable
|
||||
tables := pool.GetMemTables()
|
||||
if len(tables) != 1 {
|
||||
t.Errorf("expected 1 memtable initially, got %d", len(tables))
|
||||
}
|
||||
|
||||
// Add an immutable table
|
||||
pool.Put([]byte("key"), []byte("value"), 1)
|
||||
pool.SwitchToNewMemTable()
|
||||
|
||||
// Now should have 2 memtables (active + 1 immutable)
|
||||
tables = pool.GetMemTables()
|
||||
if len(tables) != 2 {
|
||||
t.Errorf("expected 2 memtables after switch, got %d", len(tables))
|
||||
}
|
||||
|
||||
// The active table should be first
|
||||
if tables[0].IsImmutable() {
|
||||
t.Errorf("expected first table to be active (not immutable)")
|
||||
}
|
||||
|
||||
// The second table should be immutable
|
||||
if !tables[1].IsImmutable() {
|
||||
t.Errorf("expected second table to be immutable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemPoolGetNextSequenceNumber(t *testing.T) {
|
||||
cfg := createTestConfig()
|
||||
pool := NewMemTablePool(cfg)
|
||||
|
||||
// Initial sequence number should be 0
|
||||
if seq := pool.GetNextSequenceNumber(); seq != 0 {
|
||||
t.Errorf("expected initial sequence number to be 0, got %d", seq)
|
||||
}
|
||||
|
||||
// Add entries with sequence numbers
|
||||
pool.Put([]byte("key"), []byte("value"), 5)
|
||||
|
||||
// Next sequence number should be 6
|
||||
if seq := pool.GetNextSequenceNumber(); seq != 6 {
|
||||
t.Errorf("expected sequence number to be 6, got %d", seq)
|
||||
}
|
||||
|
||||
// Switch to a new memtable
|
||||
pool.SwitchToNewMemTable()
|
||||
|
||||
// Sequence number should reset for the new table
|
||||
if seq := pool.GetNextSequenceNumber(); seq != 0 {
|
||||
t.Errorf("expected sequence number to reset to 0, got %d", seq)
|
||||
}
|
||||
}
|
155
pkg/memtable/memtable.go
Normal file
155
pkg/memtable/memtable.go
Normal file
@ -0,0 +1,155 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||
)
|
||||
|
||||
// MemTable is an in-memory table that stores key-value pairs
|
||||
// It is implemented using a skip list for efficient inserts and lookups
|
||||
type MemTable struct {
|
||||
skipList *SkipList
|
||||
nextSeqNum uint64
|
||||
creationTime time.Time
|
||||
immutable atomic.Bool
|
||||
size int64
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMemTable creates a new memory table
|
||||
func NewMemTable() *MemTable {
|
||||
return &MemTable{
|
||||
skipList: NewSkipList(),
|
||||
creationTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Put adds a key-value pair to the MemTable
|
||||
func (m *MemTable) Put(key, value []byte, seqNum uint64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.immutable.Load() {
|
||||
// Don't modify immutable memtables
|
||||
return
|
||||
}
|
||||
|
||||
e := newEntry(key, value, TypeValue, seqNum)
|
||||
m.skipList.Insert(e)
|
||||
|
||||
// Update maximum sequence number
|
||||
if seqNum > m.nextSeqNum {
|
||||
m.nextSeqNum = seqNum + 1
|
||||
}
|
||||
}
|
||||
|
||||
// Delete marks a key as deleted in the MemTable
|
||||
func (m *MemTable) Delete(key []byte, seqNum uint64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.immutable.Load() {
|
||||
// Don't modify immutable memtables
|
||||
return
|
||||
}
|
||||
|
||||
e := newEntry(key, nil, TypeDeletion, seqNum)
|
||||
m.skipList.Insert(e)
|
||||
|
||||
// Update maximum sequence number
|
||||
if seqNum > m.nextSeqNum {
|
||||
m.nextSeqNum = seqNum + 1
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the value associated with the given key
|
||||
// Returns (nil, true) if the key exists but has been deleted
|
||||
// Returns (nil, false) if the key does not exist
|
||||
// Returns (value, true) if the key exists and has a value
|
||||
func (m *MemTable) Get(key []byte) ([]byte, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
e := m.skipList.Find(key)
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check if this is a deletion marker
|
||||
if e.valueType == TypeDeletion {
|
||||
return nil, true // Key exists but was deleted
|
||||
}
|
||||
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// Contains checks if the key exists in the MemTable
|
||||
func (m *MemTable) Contains(key []byte) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.skipList.Find(key) != nil
|
||||
}
|
||||
|
||||
// ApproximateSize returns the approximate size of the MemTable in bytes
|
||||
func (m *MemTable) ApproximateSize() int64 {
|
||||
return m.skipList.ApproximateSize()
|
||||
}
|
||||
|
||||
// SetImmutable marks the MemTable as immutable
|
||||
// After this is called, no more modifications are allowed
|
||||
func (m *MemTable) SetImmutable() {
|
||||
m.immutable.Store(true)
|
||||
}
|
||||
|
||||
// IsImmutable returns whether the MemTable is immutable
|
||||
func (m *MemTable) IsImmutable() bool {
|
||||
return m.immutable.Load()
|
||||
}
|
||||
|
||||
// Age returns the age of the MemTable in seconds
|
||||
func (m *MemTable) Age() float64 {
|
||||
return time.Since(m.creationTime).Seconds()
|
||||
}
|
||||
|
||||
// NewIterator returns an iterator for the MemTable
|
||||
func (m *MemTable) NewIterator() *Iterator {
|
||||
return m.skipList.NewIterator()
|
||||
}
|
||||
|
||||
// GetNextSequenceNumber returns the next sequence number to use
|
||||
func (m *MemTable) GetNextSequenceNumber() uint64 {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.nextSeqNum
|
||||
}
|
||||
|
||||
// ProcessWALEntry processes a WAL entry and applies it to the MemTable
|
||||
func (m *MemTable) ProcessWALEntry(entry *wal.Entry) error {
|
||||
switch entry.Type {
|
||||
case wal.OpTypePut:
|
||||
m.Put(entry.Key, entry.Value, entry.SequenceNumber)
|
||||
case wal.OpTypeDelete:
|
||||
m.Delete(entry.Key, entry.SequenceNumber)
|
||||
case wal.OpTypeBatch:
|
||||
// Process batch operations
|
||||
batch, err := wal.DecodeBatch(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, op := range batch.Operations {
|
||||
seqNum := batch.Seq + uint64(i)
|
||||
switch op.Type {
|
||||
case wal.OpTypePut:
|
||||
m.Put(op.Key, op.Value, seqNum)
|
||||
case wal.OpTypeDelete:
|
||||
m.Delete(op.Key, seqNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
202
pkg/memtable/memtable_test.go
Normal file
202
pkg/memtable/memtable_test.go
Normal file
@ -0,0 +1,202 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||
)
|
||||
|
||||
func TestMemTableBasicOperations(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Test Put and Get
|
||||
mt.Put([]byte("key1"), []byte("value1"), 1)
|
||||
|
||||
value, found := mt.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key1, but got not found")
|
||||
}
|
||||
if string(value) != "value1" {
|
||||
t.Errorf("expected value1, got %s", string(value))
|
||||
}
|
||||
|
||||
// Test not found
|
||||
_, found = mt.Get([]byte("nonexistent"))
|
||||
if found {
|
||||
t.Errorf("expected key 'nonexistent' to not be found")
|
||||
}
|
||||
|
||||
// Test Delete
|
||||
mt.Delete([]byte("key1"), 2)
|
||||
|
||||
value, found = mt.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Fatalf("expected tombstone to be found for key1")
|
||||
}
|
||||
if value != nil {
|
||||
t.Errorf("expected nil value for deleted key, got %v", value)
|
||||
}
|
||||
|
||||
// Test Contains
|
||||
if !mt.Contains([]byte("key1")) {
|
||||
t.Errorf("expected Contains to return true for deleted key")
|
||||
}
|
||||
if mt.Contains([]byte("nonexistent")) {
|
||||
t.Errorf("expected Contains to return false for nonexistent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemTableSequenceNumbers(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Add entries with sequence numbers
|
||||
mt.Put([]byte("key"), []byte("value1"), 1)
|
||||
mt.Put([]byte("key"), []byte("value2"), 3)
|
||||
mt.Put([]byte("key"), []byte("value3"), 2)
|
||||
|
||||
// Should get the latest by sequence number (value2)
|
||||
value, found := mt.Get([]byte("key"))
|
||||
if !found {
|
||||
t.Fatalf("expected to find key, but got not found")
|
||||
}
|
||||
if string(value) != "value2" {
|
||||
t.Errorf("expected value2 (highest seq), got %s", string(value))
|
||||
}
|
||||
|
||||
// The next sequence number should be one more than the highest seen
|
||||
if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 {
|
||||
t.Errorf("expected next sequence number to be 4, got %d", nextSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemTableImmutability(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Add initial data
|
||||
mt.Put([]byte("key"), []byte("value"), 1)
|
||||
|
||||
// Mark as immutable
|
||||
mt.SetImmutable()
|
||||
if !mt.IsImmutable() {
|
||||
t.Errorf("expected IsImmutable to return true after SetImmutable")
|
||||
}
|
||||
|
||||
// Attempts to modify should have no effect
|
||||
mt.Put([]byte("key2"), []byte("value2"), 2)
|
||||
mt.Delete([]byte("key"), 3)
|
||||
|
||||
// Verify no changes occurred
|
||||
_, found := mt.Get([]byte("key2"))
|
||||
if found {
|
||||
t.Errorf("expected key2 to not be added to immutable memtable")
|
||||
}
|
||||
|
||||
value, found := mt.Get([]byte("key"))
|
||||
if !found {
|
||||
t.Fatalf("expected to still find key after delete on immutable table")
|
||||
}
|
||||
if string(value) != "value" {
|
||||
t.Errorf("expected value to remain unchanged, got %s", string(value))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemTableAge(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// A new memtable should have a very small age
|
||||
if age := mt.Age(); age > 1.0 {
|
||||
t.Errorf("expected new memtable to have age < 1.0s, got %.2fs", age)
|
||||
}
|
||||
|
||||
// Sleep to increase age
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if age := mt.Age(); age <= 0.0 {
|
||||
t.Errorf("expected memtable age to be > 0, got %.6fs", age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemTableWALIntegration(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Create WAL entries
|
||||
entries := []*wal.Entry{
|
||||
{SequenceNumber: 1, Type: wal.OpTypePut, Key: []byte("key1"), Value: []byte("value1")},
|
||||
{SequenceNumber: 2, Type: wal.OpTypeDelete, Key: []byte("key2"), Value: nil},
|
||||
{SequenceNumber: 3, Type: wal.OpTypePut, Key: []byte("key3"), Value: []byte("value3")},
|
||||
}
|
||||
|
||||
// Process entries
|
||||
for _, entry := range entries {
|
||||
if err := mt.ProcessWALEntry(entry); err != nil {
|
||||
t.Fatalf("failed to process WAL entry: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify entries were processed correctly
|
||||
testCases := []struct {
|
||||
key string
|
||||
expected string
|
||||
found bool
|
||||
}{
|
||||
{"key1", "value1", true},
|
||||
{"key2", "", true}, // Deleted key
|
||||
{"key3", "value3", true},
|
||||
{"key4", "", false}, // Non-existent key
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
value, found := mt.Get([]byte(tc.key))
|
||||
|
||||
if found != tc.found {
|
||||
t.Errorf("key %s: expected found=%v, got %v", tc.key, tc.found, found)
|
||||
continue
|
||||
}
|
||||
|
||||
if found && tc.expected != "" {
|
||||
if string(value) != tc.expected {
|
||||
t.Errorf("key %s: expected value '%s', got '%s'", tc.key, tc.expected, string(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify next sequence number
|
||||
if nextSeq := mt.GetNextSequenceNumber(); nextSeq != 4 {
|
||||
t.Errorf("expected next sequence number to be 4, got %d", nextSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemTableIterator(t *testing.T) {
|
||||
mt := NewMemTable()
|
||||
|
||||
// Add entries in non-sorted order
|
||||
entries := []struct {
|
||||
key string
|
||||
value string
|
||||
seq uint64
|
||||
}{
|
||||
{"banana", "yellow", 1},
|
||||
{"apple", "red", 2},
|
||||
{"cherry", "red", 3},
|
||||
{"date", "brown", 4},
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
mt.Put([]byte(e.key), []byte(e.value), e.seq)
|
||||
}
|
||||
|
||||
// Use iterator to verify keys are returned in sorted order
|
||||
it := mt.NewIterator()
|
||||
it.SeekToFirst()
|
||||
|
||||
expected := []string{"apple", "banana", "cherry", "date"}
|
||||
|
||||
for i := 0; it.Valid() && i < len(expected); i++ {
|
||||
key := string(it.Key())
|
||||
if key != expected[i] {
|
||||
t.Errorf("position %d: expected key %s, got %s", i, expected[i], key)
|
||||
}
|
||||
it.Next()
|
||||
}
|
||||
}
|
91
pkg/memtable/recovery.go
Normal file
91
pkg/memtable/recovery.go
Normal file
@ -0,0 +1,91 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||
)
|
||||
|
||||
// RecoveryOptions controls how RecoverFromWAL rebuilds MemTables.
type RecoveryOptions struct {
	// MaxSequenceNumber is the highest WAL entry sequence number that is
	// replayed; entries above it are skipped.
	MaxSequenceNumber uint64

	// MaxMemTables caps how many MemTables recovery may create. Recovery
	// fails with an error when another table would be needed.
	MaxMemTables int

	// MemTableSize is the size at which recovery stops filling the current
	// MemTable and starts a new one.
	MemTableSize int64
}
|
||||
|
||||
// DefaultRecoveryOptions returns the default recovery options
|
||||
func DefaultRecoveryOptions(cfg *config.Config) *RecoveryOptions {
|
||||
return &RecoveryOptions{
|
||||
MaxSequenceNumber: ^uint64(0), // Max uint64
|
||||
MaxMemTables: cfg.MaxMemTables,
|
||||
MemTableSize: cfg.MemTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverFromWAL rebuilds MemTables from the write-ahead log
|
||||
// Returns a list of recovered MemTables and the maximum sequence number seen
|
||||
func RecoverFromWAL(cfg *config.Config, opts *RecoveryOptions) ([]*MemTable, uint64, error) {
|
||||
if opts == nil {
|
||||
opts = DefaultRecoveryOptions(cfg)
|
||||
}
|
||||
|
||||
// Create the first MemTable
|
||||
memTables := []*MemTable{NewMemTable()}
|
||||
var maxSeqNum uint64
|
||||
|
||||
// Function to process each WAL entry
|
||||
entryHandler := func(entry *wal.Entry) error {
|
||||
// Skip entries with sequence numbers beyond our max
|
||||
if entry.SequenceNumber > opts.MaxSequenceNumber {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update the max sequence number
|
||||
if entry.SequenceNumber > maxSeqNum {
|
||||
maxSeqNum = entry.SequenceNumber
|
||||
}
|
||||
|
||||
// Get the current memtable
|
||||
current := memTables[len(memTables)-1]
|
||||
|
||||
// Check if we should create a new memtable based on size
|
||||
if current.ApproximateSize() >= opts.MemTableSize {
|
||||
// Make sure we don't exceed the max number of memtables
|
||||
if len(memTables) >= opts.MaxMemTables {
|
||||
return fmt.Errorf("maximum number of memtables (%d) exceeded during recovery", opts.MaxMemTables)
|
||||
}
|
||||
|
||||
// Mark the current memtable as immutable
|
||||
current.SetImmutable()
|
||||
|
||||
// Create a new memtable
|
||||
current = NewMemTable()
|
||||
memTables = append(memTables, current)
|
||||
}
|
||||
|
||||
// Process the entry
|
||||
return current.ProcessWALEntry(entry)
|
||||
}
|
||||
|
||||
// Replay the WAL directory
|
||||
if err := wal.ReplayWALDir(cfg.WALDir, entryHandler); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to replay WAL: %w", err)
|
||||
}
|
||||
|
||||
// For batch operations, we need to adjust maxSeqNum
|
||||
finalTable := memTables[len(memTables)-1]
|
||||
nextSeq := finalTable.GetNextSequenceNumber()
|
||||
if nextSeq > maxSeqNum+1 {
|
||||
maxSeqNum = nextSeq - 1
|
||||
}
|
||||
|
||||
return memTables, maxSeqNum, nil
|
||||
}
|
276
pkg/memtable/recovery_test.go
Normal file
276
pkg/memtable/recovery_test.go
Normal file
@ -0,0 +1,276 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
"git.canoozie.net/jer/go-storage/pkg/wal"
|
||||
)
|
||||
|
||||
func setupTestWAL(t *testing.T) (string, *wal.WAL, func()) {
|
||||
// Create temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "memtable_recovery_test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
// Create config
|
||||
cfg := config.NewDefaultConfig(tmpDir)
|
||||
|
||||
// Create WAL
|
||||
w, err := wal.NewWAL(cfg, tmpDir)
|
||||
if err != nil {
|
||||
os.RemoveAll(tmpDir)
|
||||
t.Fatalf("failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
cleanup := func() {
|
||||
w.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return tmpDir, w, cleanup
|
||||
}
|
||||
|
||||
func TestRecoverFromWAL(t *testing.T) {
|
||||
tmpDir, w, cleanup := setupTestWAL(t)
|
||||
defer cleanup()
|
||||
|
||||
// Add entries to the WAL
|
||||
entries := []struct {
|
||||
opType uint8
|
||||
key string
|
||||
value string
|
||||
}{
|
||||
{wal.OpTypePut, "key1", "value1"},
|
||||
{wal.OpTypePut, "key2", "value2"},
|
||||
{wal.OpTypeDelete, "key1", ""},
|
||||
{wal.OpTypePut, "key3", "value3"},
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
var seq uint64
|
||||
var err error
|
||||
|
||||
if e.opType == wal.OpTypePut {
|
||||
seq, err = w.Append(e.opType, []byte(e.key), []byte(e.value))
|
||||
} else {
|
||||
seq, err = w.Append(e.opType, []byte(e.key), nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append to WAL: %v", err)
|
||||
}
|
||||
t.Logf("Appended entry with seq %d", seq)
|
||||
}
|
||||
|
||||
// Sync and close WAL
|
||||
if err := w.Sync(); err != nil {
|
||||
t.Fatalf("failed to sync WAL: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create config for recovery
|
||||
cfg := config.NewDefaultConfig(tmpDir)
|
||||
cfg.WALDir = tmpDir
|
||||
cfg.MemTableSize = 1024 * 1024 // 1MB
|
||||
|
||||
// Recover memtables from WAL
|
||||
memTables, maxSeq, err := RecoverFromWAL(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to recover from WAL: %v", err)
|
||||
}
|
||||
|
||||
// Validate recovery results
|
||||
if len(memTables) == 0 {
|
||||
t.Fatalf("expected at least one memtable from recovery")
|
||||
}
|
||||
|
||||
t.Logf("Recovered %d memtables with max sequence %d", len(memTables), maxSeq)
|
||||
|
||||
// The max sequence number should be 4
|
||||
if maxSeq != 4 {
|
||||
t.Errorf("expected max sequence number 4, got %d", maxSeq)
|
||||
}
|
||||
|
||||
// Validate content of the recovered memtable
|
||||
mt := memTables[0]
|
||||
|
||||
// key1 should be deleted
|
||||
value, found := mt.Get([]byte("key1"))
|
||||
if !found {
|
||||
t.Errorf("expected key1 to be found (as deleted)")
|
||||
}
|
||||
if value != nil {
|
||||
t.Errorf("expected key1 to have nil value (deleted), got %v", value)
|
||||
}
|
||||
|
||||
// key2 should have "value2"
|
||||
value, found = mt.Get([]byte("key2"))
|
||||
if !found {
|
||||
t.Errorf("expected key2 to be found")
|
||||
} else if string(value) != "value2" {
|
||||
t.Errorf("expected key2 to have value 'value2', got '%s'", string(value))
|
||||
}
|
||||
|
||||
// key3 should have "value3"
|
||||
value, found = mt.Get([]byte("key3"))
|
||||
if !found {
|
||||
t.Errorf("expected key3 to be found")
|
||||
} else if string(value) != "value3" {
|
||||
t.Errorf("expected key3 to have value 'value3', got '%s'", string(value))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryWithMultipleMemTables(t *testing.T) {
|
||||
tmpDir, w, cleanup := setupTestWAL(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a lot of large entries to force multiple memtables
|
||||
largeValue := make([]byte, 1000) // 1KB value
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte{byte(i + 'a')}
|
||||
if _, err := w.Append(wal.OpTypePut, key, largeValue); err != nil {
|
||||
t.Fatalf("failed to append to WAL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sync and close WAL
|
||||
if err := w.Sync(); err != nil {
|
||||
t.Fatalf("failed to sync WAL: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create config for recovery with small memtable size
|
||||
cfg := config.NewDefaultConfig(tmpDir)
|
||||
cfg.WALDir = tmpDir
|
||||
cfg.MemTableSize = 5 * 1000 // 5KB - should fit about 5 entries
|
||||
cfg.MaxMemTables = 3 // Allow up to 3 memtables
|
||||
|
||||
// Recover memtables from WAL
|
||||
memTables, _, err := RecoverFromWAL(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to recover from WAL: %v", err)
|
||||
}
|
||||
|
||||
// Should have created multiple memtables
|
||||
if len(memTables) <= 1 {
|
||||
t.Errorf("expected multiple memtables due to size, got %d", len(memTables))
|
||||
}
|
||||
|
||||
t.Logf("Recovered %d memtables", len(memTables))
|
||||
|
||||
// All memtables except the last one should be immutable
|
||||
for i, mt := range memTables[:len(memTables)-1] {
|
||||
if !mt.IsImmutable() {
|
||||
t.Errorf("expected memtable %d to be immutable", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all data was recovered across all memtables
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte{byte(i + 'a')}
|
||||
found := false
|
||||
|
||||
// Check each memtable for the key
|
||||
for _, mt := range memTables {
|
||||
if _, exists := mt.Get(key); exists {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("key %c not found in any memtable", i+'a')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoveryWithBatchOperations(t *testing.T) {
|
||||
tmpDir, w, cleanup := setupTestWAL(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a batch of operations
|
||||
batch := wal.NewBatch()
|
||||
batch.Put([]byte("batch_key1"), []byte("batch_value1"))
|
||||
batch.Put([]byte("batch_key2"), []byte("batch_value2"))
|
||||
batch.Delete([]byte("batch_key3"))
|
||||
|
||||
// Write the batch to the WAL
|
||||
if err := batch.Write(w); err != nil {
|
||||
t.Fatalf("failed to write batch to WAL: %v", err)
|
||||
}
|
||||
|
||||
// Add some individual operations too
|
||||
if _, err := w.Append(wal.OpTypePut, []byte("key4"), []byte("value4")); err != nil {
|
||||
t.Fatalf("failed to append to WAL: %v", err)
|
||||
}
|
||||
|
||||
// Sync and close WAL
|
||||
if err := w.Sync(); err != nil {
|
||||
t.Fatalf("failed to sync WAL: %v", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create config for recovery
|
||||
cfg := config.NewDefaultConfig(tmpDir)
|
||||
cfg.WALDir = tmpDir
|
||||
|
||||
// Recover memtables from WAL
|
||||
memTables, maxSeq, err := RecoverFromWAL(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to recover from WAL: %v", err)
|
||||
}
|
||||
|
||||
if len(memTables) == 0 {
|
||||
t.Fatalf("expected at least one memtable from recovery")
|
||||
}
|
||||
|
||||
// The max sequence number should account for batch operations
|
||||
if maxSeq < 3 { // At least 3 from batch + individual op
|
||||
t.Errorf("expected max sequence number >= 3, got %d", maxSeq)
|
||||
}
|
||||
|
||||
// Validate content of the recovered memtable
|
||||
mt := memTables[0]
|
||||
|
||||
// Check batch keys were recovered
|
||||
value, found := mt.Get([]byte("batch_key1"))
|
||||
if !found {
|
||||
t.Errorf("batch_key1 not found in recovered memtable")
|
||||
} else if string(value) != "batch_value1" {
|
||||
t.Errorf("expected batch_key1 to have value 'batch_value1', got '%s'", string(value))
|
||||
}
|
||||
|
||||
value, found = mt.Get([]byte("batch_key2"))
|
||||
if !found {
|
||||
t.Errorf("batch_key2 not found in recovered memtable")
|
||||
} else if string(value) != "batch_value2" {
|
||||
t.Errorf("expected batch_key2 to have value 'batch_value2', got '%s'", string(value))
|
||||
}
|
||||
|
||||
// batch_key3 should be marked as deleted
|
||||
value, found = mt.Get([]byte("batch_key3"))
|
||||
if !found {
|
||||
t.Errorf("expected batch_key3 to be found as deleted")
|
||||
}
|
||||
if value != nil {
|
||||
t.Errorf("expected batch_key3 to have nil value (deleted), got %v", value)
|
||||
}
|
||||
|
||||
// Check individual operation was recovered
|
||||
value, found = mt.Get([]byte("key4"))
|
||||
if !found {
|
||||
t.Errorf("key4 not found in recovered memtable")
|
||||
} else if string(value) != "value4" {
|
||||
t.Errorf("expected key4 to have value 'value4', got '%s'", string(value))
|
||||
}
|
||||
}
|
308
pkg/memtable/skiplist.go
Normal file
308
pkg/memtable/skiplist.go
Normal file
@ -0,0 +1,308 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// MaxHeight is the largest number of levels a skip list node may span.
const MaxHeight = 12

// BranchingFactor determines how likely a node is to grow by one more level:
// a larger value makes tall nodes rarer.
const BranchingFactor = 4

// DefaultCacheLineSize is the cache line size that node layout is meant to
// align to for better performance.
const DefaultCacheLineSize = 64
|
||||
|
||||
// ValueType distinguishes live values from deletion markers.
type ValueType uint8

const (
	// TypeValue marks an entry that carries a value.
	TypeValue ValueType = 1

	// TypeDeletion marks an entry that is a tombstone (deletion marker).
	TypeDeletion ValueType = 2
)

// entry is a single versioned key-value record stored in the skip list.
type entry struct {
	key       []byte
	value     []byte
	valueType ValueType
	seqNum    uint64
}

// newEntry allocates an entry from the given parts. The key and value slices
// are stored as passed in, without copying.
func newEntry(key, value []byte, valueType ValueType, seqNum uint64) *entry {
	e := new(entry)
	e.key = key
	e.value = value
	e.valueType = valueType
	e.seqNum = seqNum
	return e
}

// size reports the approximate memory footprint of the entry: the key and
// value lengths plus a fixed allowance for the metadata fields.
func (e *entry) size() int {
	const metadataOverhead = 16
	payload := len(e.key) + len(e.value)
	return payload + metadataOverhead
}

// compare orders the entry's key against key.
// The result is negative if e.key < key, 0 if equal, positive if e.key > key.
func (e *entry) compare(key []byte) int {
	return bytes.Compare(e.key, key)
}

// compareWithEntry orders two entries by key and, for equal keys, by sequence
// number in descending order, so that newer versions sort before older ones.
func (e *entry) compareWithEntry(other *entry) int {
	if byKey := bytes.Compare(e.key, other.key); byKey != 0 {
		return byKey
	}

	// Same key: the higher sequence number sorts first.
	switch {
	case e.seqNum > other.seqNum:
		return -1
	case e.seqNum < other.seqNum:
		return 1
	default:
		return 0
	}
}
|
||||
|
||||
// node represents a node in the skip list
|
||||
type node struct {
|
||||
entry *entry
|
||||
height int32
|
||||
// next contains pointers to the next nodes at each level
|
||||
// This is allocated as a single block for cache efficiency
|
||||
next [MaxHeight]unsafe.Pointer
|
||||
}
|
||||
|
||||
// newNode creates a new node with a random height
|
||||
func newNode(e *entry, height int) *node {
|
||||
return &node{
|
||||
entry: e,
|
||||
height: int32(height),
|
||||
}
|
||||
}
|
||||
|
||||
// getNext returns the next node at the given level
|
||||
func (n *node) getNext(level int) *node {
|
||||
return (*node)(atomic.LoadPointer(&n.next[level]))
|
||||
}
|
||||
|
||||
// setNext sets the next node at the given level
|
||||
func (n *node) setNext(level int, next *node) {
|
||||
atomic.StorePointer(&n.next[level], unsafe.Pointer(next))
|
||||
}
|
||||
|
||||
// SkipList is a concurrent skip list implementation for the MemTable
|
||||
type SkipList struct {
|
||||
head *node
|
||||
maxHeight int32
|
||||
rnd *rand.Rand
|
||||
rndMtx sync.Mutex
|
||||
size int64
|
||||
}
|
||||
|
||||
// NewSkipList creates a new skip list
|
||||
func NewSkipList() *SkipList {
|
||||
seed := time.Now().UnixNano()
|
||||
list := &SkipList{
|
||||
head: newNode(nil, MaxHeight),
|
||||
maxHeight: 1,
|
||||
rnd: rand.New(rand.NewSource(seed)),
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// randomHeight generates a random height for a new node
|
||||
func (s *SkipList) randomHeight() int {
|
||||
s.rndMtx.Lock()
|
||||
defer s.rndMtx.Unlock()
|
||||
|
||||
height := 1
|
||||
for height < MaxHeight && s.rnd.Intn(BranchingFactor) == 0 {
|
||||
height++
|
||||
}
|
||||
return height
|
||||
}
|
||||
|
||||
// getCurrentHeight returns the current maximum height of the skip list
|
||||
func (s *SkipList) getCurrentHeight() int {
|
||||
return int(atomic.LoadInt32(&s.maxHeight))
|
||||
}
|
||||
|
||||
// Insert adds a new entry to the skip list
|
||||
func (s *SkipList) Insert(e *entry) {
|
||||
height := s.randomHeight()
|
||||
prev := [MaxHeight]*node{}
|
||||
node := newNode(e, height)
|
||||
|
||||
// Try to increase the height of the list
|
||||
currHeight := s.getCurrentHeight()
|
||||
if height > currHeight {
|
||||
// Attempt to increase the height
|
||||
if atomic.CompareAndSwapInt32(&s.maxHeight, int32(currHeight), int32(height)) {
|
||||
currHeight = height
|
||||
}
|
||||
}
|
||||
|
||||
// Find where to insert at each level
|
||||
current := s.head
|
||||
for level := currHeight - 1; level >= 0; level-- {
|
||||
// Find the insertion point at this level
|
||||
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||
if next.entry.compareWithEntry(e) >= 0 {
|
||||
break
|
||||
}
|
||||
current = next
|
||||
}
|
||||
prev[level] = current
|
||||
}
|
||||
|
||||
// Insert the node at each level
|
||||
for level := 0; level < height; level++ {
|
||||
node.setNext(level, prev[level].getNext(level))
|
||||
prev[level].setNext(level, node)
|
||||
}
|
||||
|
||||
// Update approximate size
|
||||
atomic.AddInt64(&s.size, int64(e.size()))
|
||||
}
|
||||
|
||||
// Find looks for an entry with the specified key
|
||||
// If multiple entries have the same key, the most recent one is returned
|
||||
func (s *SkipList) Find(key []byte) *entry {
|
||||
var result *entry
|
||||
current := s.head
|
||||
height := s.getCurrentHeight()
|
||||
|
||||
// Start from the highest level for efficient search
|
||||
for level := height - 1; level >= 0; level-- {
|
||||
// Scan forward until we find a key greater than or equal to the target
|
||||
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||
cmp := next.entry.compare(key)
|
||||
if cmp > 0 {
|
||||
// Key at next is greater than target, go down a level
|
||||
break
|
||||
} else if cmp == 0 {
|
||||
// Found a match, check if it's newer than our current result
|
||||
if result == nil || next.entry.seqNum > result.seqNum {
|
||||
result = next.entry
|
||||
}
|
||||
// Continue at this level to see if there are more entries with same key
|
||||
current = next
|
||||
} else {
|
||||
// Key at next is less than target, move forward
|
||||
current = next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For level 0, do one more sweep to ensure we get the newest entry
|
||||
current = s.head
|
||||
for next := current.getNext(0); next != nil; next = next.getNext(0) {
|
||||
cmp := next.entry.compare(key)
|
||||
if cmp > 0 {
|
||||
// Past the key
|
||||
break
|
||||
} else if cmp == 0 {
|
||||
// Found a match, update result if it's newer
|
||||
if result == nil || next.entry.seqNum > result.seqNum {
|
||||
result = next.entry
|
||||
}
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ApproximateSize returns the approximate size of the skip list in bytes
|
||||
func (s *SkipList) ApproximateSize() int64 {
|
||||
return atomic.LoadInt64(&s.size)
|
||||
}
|
||||
|
||||
// Iterator provides sequential access to the skip list entries
|
||||
type Iterator struct {
|
||||
list *SkipList
|
||||
current *node
|
||||
}
|
||||
|
||||
// NewIterator creates a new Iterator for the skip list
|
||||
func (s *SkipList) NewIterator() *Iterator {
|
||||
return &Iterator{
|
||||
list: s,
|
||||
current: s.head,
|
||||
}
|
||||
}
|
||||
|
||||
// Valid returns true if the iterator is positioned at a valid entry
|
||||
func (it *Iterator) Valid() bool {
|
||||
return it.current != nil && it.current != it.list.head
|
||||
}
|
||||
|
||||
// Next advances the iterator to the next entry
|
||||
func (it *Iterator) Next() {
|
||||
if it.current == nil {
|
||||
return
|
||||
}
|
||||
it.current = it.current.getNext(0)
|
||||
}
|
||||
|
||||
// SeekToFirst positions the iterator at the first entry
|
||||
func (it *Iterator) SeekToFirst() {
|
||||
it.current = it.list.head.getNext(0)
|
||||
}
|
||||
|
||||
// Seek positions the iterator at the first entry with a key >= target
|
||||
func (it *Iterator) Seek(key []byte) {
|
||||
// Start from head
|
||||
current := it.list.head
|
||||
height := it.list.getCurrentHeight()
|
||||
|
||||
// Search algorithm similar to Find
|
||||
for level := height - 1; level >= 0; level-- {
|
||||
for next := current.getNext(level); next != nil; next = current.getNext(level) {
|
||||
if next.entry.compare(key) >= 0 {
|
||||
break
|
||||
}
|
||||
current = next
|
||||
}
|
||||
}
|
||||
|
||||
// Move to the next node, which should be >= target
|
||||
it.current = current.getNext(0)
|
||||
}
|
||||
|
||||
// Key returns the key of the current entry
|
||||
func (it *Iterator) Key() []byte {
|
||||
if !it.Valid() {
|
||||
return nil
|
||||
}
|
||||
return it.current.entry.key
|
||||
}
|
||||
|
||||
// Value returns the value of the current entry
|
||||
func (it *Iterator) Value() []byte {
|
||||
if !it.Valid() {
|
||||
return nil
|
||||
}
|
||||
return it.current.entry.value
|
||||
}
|
||||
|
||||
// Entry returns the current entry
|
||||
func (it *Iterator) Entry() *entry {
|
||||
if !it.Valid() {
|
||||
return nil
|
||||
}
|
||||
return it.current.entry
|
||||
}
|
232
pkg/memtable/skiplist_test.go
Normal file
232
pkg/memtable/skiplist_test.go
Normal file
@ -0,0 +1,232 @@
|
||||
package memtable
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSkipListBasicOperations(t *testing.T) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Test insertion
|
||||
e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1)
|
||||
e2 := newEntry([]byte("key2"), []byte("value2"), TypeValue, 2)
|
||||
e3 := newEntry([]byte("key3"), []byte("value3"), TypeValue, 3)
|
||||
|
||||
sl.Insert(e1)
|
||||
sl.Insert(e2)
|
||||
sl.Insert(e3)
|
||||
|
||||
// Test lookup
|
||||
found := sl.Find([]byte("key2"))
|
||||
if found == nil {
|
||||
t.Fatalf("expected to find key2, but got nil")
|
||||
}
|
||||
if string(found.value) != "value2" {
|
||||
t.Errorf("expected value to be 'value2', got '%s'", string(found.value))
|
||||
}
|
||||
|
||||
// Test lookup of non-existent key
|
||||
notFound := sl.Find([]byte("key4"))
|
||||
if notFound != nil {
|
||||
t.Errorf("expected nil for non-existent key, got %v", notFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipListSequenceNumbers(t *testing.T) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Insert same key with different sequence numbers
|
||||
e1 := newEntry([]byte("key"), []byte("value1"), TypeValue, 1)
|
||||
e2 := newEntry([]byte("key"), []byte("value2"), TypeValue, 2)
|
||||
e3 := newEntry([]byte("key"), []byte("value3"), TypeValue, 3)
|
||||
|
||||
// Insert in reverse order to test ordering
|
||||
sl.Insert(e3)
|
||||
sl.Insert(e2)
|
||||
sl.Insert(e1)
|
||||
|
||||
// Find should return the entry with the highest sequence number
|
||||
found := sl.Find([]byte("key"))
|
||||
if found == nil {
|
||||
t.Fatalf("expected to find key, but got nil")
|
||||
}
|
||||
if string(found.value) != "value3" {
|
||||
t.Errorf("expected value to be 'value3' (highest seq num), got '%s'", string(found.value))
|
||||
}
|
||||
if found.seqNum != 3 {
|
||||
t.Errorf("expected sequence number to be 3, got %d", found.seqNum)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipListIterator(t *testing.T) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Insert entries
|
||||
entries := []struct {
|
||||
key string
|
||||
value string
|
||||
seq uint64
|
||||
}{
|
||||
{"apple", "red", 1},
|
||||
{"banana", "yellow", 2},
|
||||
{"cherry", "red", 3},
|
||||
{"date", "brown", 4},
|
||||
{"elderberry", "purple", 5},
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq))
|
||||
}
|
||||
|
||||
// Test iteration
|
||||
it := sl.NewIterator()
|
||||
it.SeekToFirst()
|
||||
|
||||
count := 0
|
||||
for it.Valid() {
|
||||
if count >= len(entries) {
|
||||
t.Fatalf("iterator returned more entries than expected")
|
||||
}
|
||||
|
||||
expectedKey := entries[count].key
|
||||
expectedValue := entries[count].value
|
||||
|
||||
if string(it.Key()) != expectedKey {
|
||||
t.Errorf("at position %d, expected key '%s', got '%s'", count, expectedKey, string(it.Key()))
|
||||
}
|
||||
if string(it.Value()) != expectedValue {
|
||||
t.Errorf("at position %d, expected value '%s', got '%s'", count, expectedValue, string(it.Value()))
|
||||
}
|
||||
|
||||
it.Next()
|
||||
count++
|
||||
}
|
||||
|
||||
if count != len(entries) {
|
||||
t.Errorf("expected to iterate through %d entries, but got %d", len(entries), count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipListSeek(t *testing.T) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Insert entries
|
||||
entries := []struct {
|
||||
key string
|
||||
value string
|
||||
seq uint64
|
||||
}{
|
||||
{"apple", "red", 1},
|
||||
{"banana", "yellow", 2},
|
||||
{"cherry", "red", 3},
|
||||
{"date", "brown", 4},
|
||||
{"elderberry", "purple", 5},
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
sl.Insert(newEntry([]byte(e.key), []byte(e.value), TypeValue, e.seq))
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
seek string
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
// Before first entry
|
||||
{"a", "apple", true},
|
||||
// Exact match
|
||||
{"cherry", "cherry", true},
|
||||
// Between entries
|
||||
{"blueberry", "cherry", true},
|
||||
// After last entry
|
||||
{"zebra", "", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.seek, func(t *testing.T) {
|
||||
it := sl.NewIterator()
|
||||
it.Seek([]byte(tc.seek))
|
||||
|
||||
if it.Valid() != tc.valid {
|
||||
t.Errorf("expected Valid() to be %v, got %v", tc.valid, it.Valid())
|
||||
}
|
||||
|
||||
if tc.valid {
|
||||
if string(it.Key()) != tc.expected {
|
||||
t.Errorf("expected key '%s', got '%s'", tc.expected, string(it.Key()))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryComparison(t *testing.T) {
|
||||
testCases := []struct {
|
||||
e1, e2 *entry
|
||||
expected int
|
||||
}{
|
||||
// Different keys
|
||||
{
|
||||
newEntry([]byte("a"), []byte("val"), TypeValue, 1),
|
||||
newEntry([]byte("b"), []byte("val"), TypeValue, 1),
|
||||
-1,
|
||||
},
|
||||
{
|
||||
newEntry([]byte("b"), []byte("val"), TypeValue, 1),
|
||||
newEntry([]byte("a"), []byte("val"), TypeValue, 1),
|
||||
1,
|
||||
},
|
||||
// Same key, different sequence numbers (higher seq should be "less")
|
||||
{
|
||||
newEntry([]byte("same"), []byte("val1"), TypeValue, 2),
|
||||
newEntry([]byte("same"), []byte("val2"), TypeValue, 1),
|
||||
-1,
|
||||
},
|
||||
{
|
||||
newEntry([]byte("same"), []byte("val1"), TypeValue, 1),
|
||||
newEntry([]byte("same"), []byte("val2"), TypeValue, 2),
|
||||
1,
|
||||
},
|
||||
// Same key, same sequence number
|
||||
{
|
||||
newEntry([]byte("same"), []byte("val"), TypeValue, 1),
|
||||
newEntry([]byte("same"), []byte("val"), TypeValue, 1),
|
||||
0,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
result := tc.e1.compareWithEntry(tc.e2)
|
||||
expected := tc.expected
|
||||
// We just care about the sign
|
||||
if (result < 0 && expected >= 0) || (result > 0 && expected <= 0) || (result == 0 && expected != 0) {
|
||||
t.Errorf("case %d: expected comparison result %d, got %d", i, expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipListApproximateSize(t *testing.T) {
|
||||
sl := NewSkipList()
|
||||
|
||||
// Initial size should be 0
|
||||
if size := sl.ApproximateSize(); size != 0 {
|
||||
t.Errorf("expected initial size to be 0, got %d", size)
|
||||
}
|
||||
|
||||
// Add some entries
|
||||
e1 := newEntry([]byte("key1"), []byte("value1"), TypeValue, 1)
|
||||
e2 := newEntry([]byte("key2"), bytes.Repeat([]byte("v"), 100), TypeValue, 2)
|
||||
|
||||
sl.Insert(e1)
|
||||
expectedSize := int64(e1.size())
|
||||
if size := sl.ApproximateSize(); size != expectedSize {
|
||||
t.Errorf("expected size to be %d after first insert, got %d", expectedSize, size)
|
||||
}
|
||||
|
||||
sl.Insert(e2)
|
||||
expectedSize += int64(e2.size())
|
||||
if size := sl.ApproximateSize(); size != expectedSize {
|
||||
t.Errorf("expected size to be %d after second insert, got %d", expectedSize, size)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user