feat: implement config and WAL packages

- Implement config package with serializable configuration and manifest
- Implement WAL with durability guarantees and fragmentation support
- Fix mutex deadlock bug in WAL sync operations
- Add comprehensive tests for both packages
This commit is contained in:
Jeremy Tregunna 2025-04-19 14:51:17 -06:00
parent ee23a47a74
commit b03176f136
Signed by: jer
GPG Key ID: 1278B36BA6F5D5E4
10 changed files with 2484 additions and 0 deletions

32
CLAUDE.md Normal file
View File

@ -0,0 +1,32 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Build Commands
- Build: `go build ./...`
- Run tests: `go test ./...`
- Run single test: `go test ./pkg/path/to/package -run TestName`
- Benchmark: `go test ./pkg/path/to/package -bench .`
- Race detector: `go test -race ./...`
## Linting/Formatting
- Format code: `go fmt ./...`
- Static analysis: `go vet ./...`
- Install golangci-lint: `go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest`
- Run linter: `golangci-lint run`
## Code Style Guidelines
- Follow Go standard project layout in pkg/ and internal/ directories
- Use descriptive error types with context wrapping
- Implement single-writer architecture for write paths
- Allow concurrent reads via snapshots
- Use interfaces for component boundaries
- Follow idiomatic Go practices
- Add appropriate validation, especially for checksums
- All exported functions must have documentation comments
- For transaction management, use WAL for durability/atomicity
## Version Control
- Use git for version control
- All commit messages must follow the semantic commit format (e.g. `feat: ...`, `fix: ...`)
- All commit messages must not reference code being generated or co-authored by Claude

199
pkg/config/config.go Normal file
View File

@ -0,0 +1,199 @@
package config
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
)
// Manifest file naming and versioning.
const (
// DefaultManifestFileName is the file name, joined onto the database path,
// under which the manifest is stored.
DefaultManifestFileName = "MANIFEST"
// CurrentManifestVersion is the version stamped into newly created configs
// and manifest entries.
CurrentManifestVersion = 1
)
// Sentinel errors returned (usually wrapped) by this package.
var (
// ErrInvalidConfig is wrapped by Config.Validate when a field fails validation.
ErrInvalidConfig = errors.New("invalid configuration")
// ErrManifestNotFound is returned when the manifest file does not exist.
ErrManifestNotFound = errors.New("manifest not found")
// ErrInvalidManifest is wrapped when the manifest file cannot be decoded
// or contains no usable entry.
ErrInvalidManifest = errors.New("invalid manifest")
)
// SyncMode selects the WAL sync policy stored in Config.WALSyncMode.
// NOTE(review): the behaviour behind each mode is implemented by the WAL and
// is not visible here; the names suggest "never", "batched" and "every write"
// respectively — confirm against the WAL code.
type SyncMode int
// Available sync modes, numbered from zero in declaration order.
const (
SyncNone SyncMode = iota
SyncBatch
SyncImmediate
)
// Config holds the tunable settings of the storage engine. Exported fields are
// serialized to JSON using the tags below; the unexported mutex is not.
// Because it embeds a mutex, a Config must be passed by pointer, never copied.
type Config struct {
// Version is the configuration version; Validate requires it to be positive.
Version int `json:"version"`
// WAL configuration
// WALDir is the WAL directory; Validate rejects an empty value.
WALDir string `json:"wal_dir"`
WALSyncMode SyncMode `json:"wal_sync_mode"`
// WALSyncBytes defaults to 1MB; presumably the byte threshold between
// syncs in batch mode — TODO confirm against the WAL implementation.
WALSyncBytes int64 `json:"wal_sync_bytes"`
// MemTable configuration
// MemTableSize and MaxMemTables must both be positive to pass Validate.
MemTableSize int64 `json:"memtable_size"`
MaxMemTables int `json:"max_memtables"`
// MaxMemTableAge defaults to 600, annotated as 10 minutes, so the unit is
// presumably seconds — TODO confirm.
MaxMemTableAge int64 `json:"max_memtable_age"`
MemTablePoolCap int `json:"memtable_pool_cap"`
// SSTable configuration
// SSTDir is the SSTable directory; Validate rejects an empty value.
SSTDir string `json:"sst_dir"`
// SSTableBlockSize and SSTableIndexSize must be positive to pass Validate.
SSTableBlockSize int `json:"sstable_block_size"`
SSTableIndexSize int `json:"sstable_index_size"`
SSTableMaxSize int64 `json:"sstable_max_size"`
// SSTableRestartSize defaults to 16 (restart points every 16 keys).
SSTableRestartSize int `json:"sstable_restart_size"`
// Compaction configuration
// CompactionLevels must be positive and CompactionRatio greater than 1.0
// to pass Validate.
CompactionLevels int `json:"compaction_levels"`
CompactionRatio float64 `json:"compaction_ratio"`
CompactionThreads int `json:"compaction_threads"`
// CompactionInterval defaults to 30, annotated as 30 seconds.
CompactionInterval int64 `json:"compaction_interval"`
// mu is taken by Validate, SaveManifest (read) and Update (write).
mu sync.RWMutex
}
// NewDefaultConfig returns a Config populated with the recommended defaults.
// The WAL and SSTable directories are placed in the "wal" and "sst"
// subdirectories of dbPath.
func NewDefaultConfig(dbPath string) *Config {
	const (
		kib = 1024
		mib = 1024 * kib
	)

	cfg := new(Config)
	cfg.Version = CurrentManifestVersion

	// WAL defaults
	cfg.WALDir = filepath.Join(dbPath, "wal")
	cfg.WALSyncMode = SyncBatch
	cfg.WALSyncBytes = 1 * mib

	// MemTable defaults
	cfg.MemTableSize = 32 * mib
	cfg.MaxMemTables = 4
	cfg.MaxMemTableAge = 600 // 10 minutes
	cfg.MemTablePoolCap = 4

	// SSTable defaults
	cfg.SSTDir = filepath.Join(dbPath, "sst")
	cfg.SSTableBlockSize = 16 * kib
	cfg.SSTableIndexSize = 64 * kib
	cfg.SSTableMaxSize = 64 * mib
	cfg.SSTableRestartSize = 16 // restart points every 16 keys

	// Compaction defaults
	cfg.CompactionLevels = 7
	cfg.CompactionRatio = 10
	cfg.CompactionThreads = 2
	cfg.CompactionInterval = 30 // 30 seconds

	return cfg
}
// Validate reports the first invalid field it finds, wrapped in
// ErrInvalidConfig, or nil when every check passes. It holds the config's
// read lock for the duration of the checks.
func (c *Config) Validate() error {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// Cases are evaluated top to bottom, so the first failing rule wins.
	switch {
	case c.Version <= 0:
		return fmt.Errorf("%w: invalid version %d", ErrInvalidConfig, c.Version)
	case c.WALDir == "":
		return fmt.Errorf("%w: WAL directory not specified", ErrInvalidConfig)
	case c.SSTDir == "":
		return fmt.Errorf("%w: SSTable directory not specified", ErrInvalidConfig)
	case c.MemTableSize <= 0:
		return fmt.Errorf("%w: MemTable size must be positive", ErrInvalidConfig)
	case c.MaxMemTables <= 0:
		return fmt.Errorf("%w: Max MemTables must be positive", ErrInvalidConfig)
	case c.SSTableBlockSize <= 0:
		return fmt.Errorf("%w: SSTable block size must be positive", ErrInvalidConfig)
	case c.SSTableIndexSize <= 0:
		return fmt.Errorf("%w: SSTable index size must be positive", ErrInvalidConfig)
	case c.CompactionLevels <= 0:
		return fmt.Errorf("%w: Compaction levels must be positive", ErrInvalidConfig)
	case c.CompactionRatio <= 1.0:
		return fmt.Errorf("%w: Compaction ratio must be greater than 1.0", ErrInvalidConfig)
	}
	return nil
}
// LoadConfigFromManifest reads the manifest file under dbPath, decodes it as a
// single JSON Config object and validates the result.
//
// It returns ErrManifestNotFound when the file does not exist, an error
// wrapping ErrInvalidManifest when the JSON cannot be decoded, and the
// validation error when the decoded config is invalid.
//
// NOTE(review): Manifest.Save writes a JSON array of entries to the same file
// name, which this function cannot decode — confirm which format callers rely on.
func LoadConfigFromManifest(dbPath string) (*Config, error) {
	raw, err := os.ReadFile(filepath.Join(dbPath, DefaultManifestFileName))
	switch {
	case err == nil:
		// fall through to decoding
	case os.IsNotExist(err):
		return nil, ErrManifestNotFound
	default:
		return nil, fmt.Errorf("failed to read manifest: %w", err)
	}

	loaded := new(Config)
	if err := json.Unmarshal(raw, loaded); err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
	}
	if err := loaded.Validate(); err != nil {
		return nil, err
	}
	return loaded, nil
}
// SaveManifest validates the configuration and writes it as indented JSON to
// the manifest file inside dbPath, creating dbPath if necessary.
//
// The config is serialized while holding the read lock; validation then runs
// on a decoded copy of exactly those bytes after the lock has been released.
// Calling c.Validate() while still holding c.mu.RLock() would acquire the read
// lock recursively, which sync.RWMutex prohibits: if a writer (Update) queues
// up between the two RLock calls, the second RLock blocks behind the writer
// and the writer blocks behind the first RLock, deadlocking both goroutines.
//
// The JSON is written to a ".tmp" sibling first and then renamed over the
// manifest, so the manifest path is replaced in a single rename step.
//
// It returns the validation error for an invalid config, or a wrapped error
// when marshalling, directory creation, writing or renaming fails.
func (c *Config) SaveManifest(dbPath string) error {
	// Take a consistent snapshot of the fields; no lock is held during file I/O.
	c.mu.RLock()
	data, err := json.MarshalIndent(c, "", " ")
	c.mu.RUnlock()
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}

	// Validate what will actually be persisted, without re-entering c.mu.
	var snapshot Config
	if err := json.Unmarshal(data, &snapshot); err != nil {
		return fmt.Errorf("failed to decode config snapshot: %w", err)
	}
	if err := snapshot.Validate(); err != nil {
		return err
	}

	if err := os.MkdirAll(dbPath, 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
	tempPath := manifestPath + ".tmp"
	if err := os.WriteFile(tempPath, data, 0644); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}
	if err := os.Rename(tempPath, manifestPath); err != nil {
		return fmt.Errorf("failed to rename manifest: %w", err)
	}
	return nil
}
// Update runs fn against the configuration while holding the write lock, so
// the changes made by fn are applied as one unit with respect to Validate and
// SaveManifest.
//
// fn must not call methods on c that take the lock themselves (Validate,
// SaveManifest, Update): the mutex is not reentrant and the call would
// deadlock. The result is not validated here.
func (c *Config) Update(fn func(*Config)) {
c.mu.Lock()
defer c.mu.Unlock()
fn(c)
}

167
pkg/config/config_test.go Normal file
View File

@ -0,0 +1,167 @@
package config
import (
"os"
"path/filepath"
"testing"
)
// TestNewDefaultConfig checks the derived directories and a sample of the
// default values produced by NewDefaultConfig.
func TestNewDefaultConfig(t *testing.T) {
	const dbPath = "/tmp/testdb"
	cfg := NewDefaultConfig(dbPath)

	if got := cfg.Version; got != CurrentManifestVersion {
		t.Errorf("expected version %d, got %d", CurrentManifestVersion, got)
	}

	wantWAL := filepath.Join(dbPath, "wal")
	if cfg.WALDir != wantWAL {
		t.Errorf("expected WAL dir %s, got %s", wantWAL, cfg.WALDir)
	}

	wantSST := filepath.Join(dbPath, "sst")
	if cfg.SSTDir != wantSST {
		t.Errorf("expected SST dir %s, got %s", wantSST, cfg.SSTDir)
	}

	// Spot-check default values.
	if got := cfg.WALSyncMode; got != SyncBatch {
		t.Errorf("expected WAL sync mode %d, got %d", SyncBatch, got)
	}
	const wantMemTable = 32 * 1024 * 1024
	if got := cfg.MemTableSize; got != wantMemTable {
		t.Errorf("expected memtable size %d, got %d", wantMemTable, got)
	}
}
// TestConfigValidate verifies that the default config is valid and that each
// single-field corruption yields the expected validation message.
func TestConfigValidate(t *testing.T) {
	const dbPath = "/tmp/testdb"

	// The defaults must validate cleanly.
	if err := NewDefaultConfig(dbPath).Validate(); err != nil {
		t.Errorf("expected valid config, got error: %v", err)
	}

	type invalidCase struct {
		name     string
		mutate   func(*Config)
		expected string
	}
	cases := []invalidCase{
		{"invalid version", func(c *Config) { c.Version = 0 }, "invalid configuration: invalid version 0"},
		{"empty WAL dir", func(c *Config) { c.WALDir = "" }, "invalid configuration: WAL directory not specified"},
		{"empty SST dir", func(c *Config) { c.SSTDir = "" }, "invalid configuration: SSTable directory not specified"},
		{"zero memtable size", func(c *Config) { c.MemTableSize = 0 }, "invalid configuration: MemTable size must be positive"},
		{"negative max memtables", func(c *Config) { c.MaxMemTables = -1 }, "invalid configuration: Max MemTables must be positive"},
		{"zero block size", func(c *Config) { c.SSTableBlockSize = 0 }, "invalid configuration: SSTable block size must be positive"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// Each case starts from a fresh default config.
			broken := NewDefaultConfig(dbPath)
			tc.mutate(broken)

			err := broken.Validate()
			if err == nil {
				t.Fatal("expected error, got nil")
			}
			if got := err.Error(); got != tc.expected {
				t.Errorf("expected error %q, got %q", tc.expected, got)
			}
		})
	}
}
// TestConfigManifestSaveLoad round-trips a modified config through
// SaveManifest / LoadConfigFromManifest and checks the missing-file error.
func TestConfigManifestSaveLoad(t *testing.T) {
	dir, err := os.MkdirTemp("", "config_test")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(dir)

	// Save a config with two non-default values.
	saved := NewDefaultConfig(dir)
	saved.MemTableSize = 16 * 1024 * 1024 // 16MB
	saved.CompactionThreads = 4
	if err := saved.SaveManifest(dir); err != nil {
		t.Fatalf("failed to save manifest: %v", err)
	}

	// Load it back and compare the modified fields.
	loaded, err := LoadConfigFromManifest(dir)
	if err != nil {
		t.Fatalf("failed to load manifest: %v", err)
	}
	if want, got := saved.MemTableSize, loaded.MemTableSize; got != want {
		t.Errorf("expected memtable size %d, got %d", want, got)
	}
	if want, got := saved.CompactionThreads, loaded.CompactionThreads; got != want {
		t.Errorf("expected compaction threads %d, got %d", want, got)
	}

	// A directory without a manifest must report ErrManifestNotFound.
	missing := filepath.Join(dir, "nonexistent")
	if _, err = LoadConfigFromManifest(missing); err != ErrManifestNotFound {
		t.Errorf("expected ErrManifestNotFound, got %v", err)
	}
}
// TestConfigUpdate checks that changes made inside Update are visible on the
// config afterwards.
func TestConfigUpdate(t *testing.T) {
	const (
		wantSize   = 64 * 1024 * 1024 // 64MB
		wantTables = 8
	)
	cfg := NewDefaultConfig("/tmp/testdb")

	cfg.Update(func(c *Config) {
		c.MemTableSize = wantSize
		c.MaxMemTables = wantTables
	})

	if got := cfg.MemTableSize; got != wantSize {
		t.Errorf("expected memtable size %d, got %d", wantSize, got)
	}
	if got := cfg.MaxMemTables; got != wantTables {
		t.Errorf("expected max memtables %d, got %d", wantTables, got)
	}
}

214
pkg/config/manifest.go Normal file
View File

@ -0,0 +1,214 @@
package config
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"time"
)
// ManifestEntry is one record in the manifest history: a configuration plus
// the files registered while that entry was current.
type ManifestEntry struct {
// Timestamp is the creation time in Unix seconds (set from time.Now().Unix()).
Timestamp int64 `json:"timestamp"`
// Version is set to CurrentManifestVersion when the entry is created.
Version int `json:"version"`
// Config is the configuration for this entry.
Config *Config `json:"config"`
FileSystem map[string]int64 `json:"filesystem,omitempty"` // Map of file paths to sequence numbers
}
// Manifest tracks the history of manifest entries for a database directory.
// Only Entries is written to disk by Save; the other fields are in-memory state.
type Manifest struct {
// DBPath is the database directory in which the manifest file is stored.
DBPath string
// Entries is the entry history, oldest first.
Entries []ManifestEntry
// Current is the entry that the accessor and mutator methods operate on.
Current *ManifestEntry
// LastUpdate is set when the manifest is created, loaded or saved.
LastUpdate time.Time
// mu is taken by the Manifest methods around access to the fields above.
mu sync.RWMutex
}
// NewManifest creates an in-memory manifest for dbPath holding a single entry
// built from config. A nil config is replaced by NewDefaultConfig(dbPath).
// It returns the validation error if the config is invalid. Nothing is
// written to disk until Save is called.
func NewManifest(dbPath string, config *Config) (*Manifest, error) {
	if config == nil {
		config = NewDefaultConfig(dbPath)
	}
	if err := config.Validate(); err != nil {
		return nil, err
	}

	now := time.Now()
	m := &Manifest{
		DBPath: dbPath,
		Entries: []ManifestEntry{{
			Timestamp: now.Unix(),
			Version:   CurrentManifestVersion,
			Config:    config,
		}},
		LastUpdate: now,
	}
	// Current must alias the element stored in Entries. Pointing it at a
	// separate local copy would make AddFile/RemoveFile modify a struct that
	// Save never serializes, silently dropping registered files.
	m.Current = &m.Entries[0]
	return m, nil
}
// LoadManifest reads the manifest file under dbPath, decodes it as a JSON
// array of entries and makes the last entry current.
//
// It returns ErrManifestNotFound when the file does not exist, an error
// wrapping ErrInvalidManifest when the JSON cannot be decoded, the array is
// empty or the last entry carries no config, and the validation error when
// that config is invalid.
func LoadManifest(dbPath string) (*Manifest, error) {
	manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
	file, err := os.Open(manifestPath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, ErrManifestNotFound
		}
		return nil, fmt.Errorf("failed to open manifest: %w", err)
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		return nil, fmt.Errorf("failed to read manifest: %w", err)
	}

	var entries []ManifestEntry
	if err := json.Unmarshal(data, &entries); err != nil {
		return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
	}
	if len(entries) == 0 {
		return nil, fmt.Errorf("%w: no entries in manifest", ErrInvalidManifest)
	}

	current := &entries[len(entries)-1]
	// A missing or null "config" decodes to a nil pointer; calling Validate
	// on it would panic, so report it as a malformed manifest instead.
	if current.Config == nil {
		return nil, fmt.Errorf("%w: latest entry has no config", ErrInvalidManifest)
	}
	if err := current.Config.Validate(); err != nil {
		return nil, err
	}

	return &Manifest{
		DBPath:     dbPath,
		Entries:    entries,
		Current:    current,
		LastUpdate: time.Now(),
	}, nil
}
// Save validates the current config and writes the full entry history as
// indented JSON to the manifest file under DBPath. The data goes to a ".tmp"
// sibling first and is then renamed over the manifest. LastUpdate is
// refreshed on success.
func (m *Manifest) Save() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	err := m.Current.Config.Validate()
	if err != nil {
		return err
	}

	if err = os.MkdirAll(m.DBPath, 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	target := filepath.Join(m.DBPath, DefaultManifestFileName)
	staging := target + ".tmp"

	encoded, err := json.MarshalIndent(m.Entries, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal manifest: %w", err)
	}

	if err = os.WriteFile(staging, encoded, 0644); err != nil {
		return fmt.Errorf("failed to write manifest: %w", err)
	}
	if err = os.Rename(staging, target); err != nil {
		return fmt.Errorf("failed to rename manifest: %w", err)
	}

	m.LastUpdate = time.Now()
	return nil
}
// UpdateConfig appends a new manifest entry whose config is a copy of the
// current config with fn applied, and makes that entry current.
//
// fn receives a private copy (made via a JSON round trip), so the previous
// entry's config is left untouched. If the modified config fails validation
// the manifest is not changed and the validation error is returned.
//
// The files registered on the current entry are copied into the new entry;
// otherwise a configuration change would make GetFiles forget every file
// added so far.
func (m *Manifest) UpdateConfig(fn func(*Config)) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Create a copy of the current config.
	currentJSON, err := json.Marshal(m.Current.Config)
	if err != nil {
		return fmt.Errorf("failed to marshal current config: %w", err)
	}
	var newConfig Config
	if err := json.Unmarshal(currentJSON, &newConfig); err != nil {
		return fmt.Errorf("failed to unmarshal config: %w", err)
	}

	// Apply and validate the change before touching the manifest.
	fn(&newConfig)
	if err := newConfig.Validate(); err != nil {
		return err
	}

	entry := ManifestEntry{
		Timestamp: time.Now().Unix(),
		Version:   CurrentManifestVersion,
		Config:    &newConfig,
	}
	// Carry the tracked files forward as an independent map so later
	// AddFile/RemoveFile calls do not rewrite the previous entry's history.
	if len(m.Current.FileSystem) > 0 {
		entry.FileSystem = make(map[string]int64, len(m.Current.FileSystem))
		for path, seqNum := range m.Current.FileSystem {
			entry.FileSystem[path] = seqNum
		}
	}

	m.Entries = append(m.Entries, entry)
	// Re-point Current after append: the slice may have been reallocated.
	m.Current = &m.Entries[len(m.Entries)-1]
	return nil
}
// AddFile records path with sequence number seqNum on the current entry,
// creating the file map on first use. An existing mapping for path is
// overwritten. It always returns nil.
func (m *Manifest) AddFile(path string, seqNum int64) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	cur := m.Current
	if cur.FileSystem == nil {
		cur.FileSystem = map[string]int64{}
	}
	cur.FileSystem[path] = seqNum
	return nil
}
// RemoveFile drops path from the current entry's file map. Removing a path
// that is not registered is a no-op. It always returns nil.
func (m *Manifest) RemoveFile(path string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if files := m.Current.FileSystem; files != nil {
		delete(files, path)
	}
	return nil
}
// GetConfig returns the config of the current entry. The returned pointer is
// the manifest's own Config, not a copy, so changes made through it are
// visible to the manifest.
func (m *Manifest) GetConfig() *Config {
m.mu.RLock()
defer m.mu.RUnlock()
return m.Current.Config
}
// GetFiles returns the files registered on the current entry, mapped to their
// sequence numbers. The result is always a fresh, non-nil map, so callers may
// modify it without affecting the manifest.
func (m *Manifest) GetFiles() map[string]int64 {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// Ranging over a nil map yields nothing, so the empty case needs no
	// special handling.
	tracked := m.Current.FileSystem
	snapshot := make(map[string]int64, len(tracked))
	for name, seq := range tracked {
		snapshot[name] = seq
	}
	return snapshot
}

176
pkg/config/manifest_test.go Normal file
View File

@ -0,0 +1,176 @@
package config
import (
"os"
"testing"
)
// TestNewManifest checks the initial state of a manifest built from an
// explicit config.
func TestNewManifest(t *testing.T) {
	const dbPath = "/tmp/testdb"
	cfg := NewDefaultConfig(dbPath)

	m, err := NewManifest(dbPath, cfg)
	if err != nil {
		t.Fatalf("failed to create manifest: %v", err)
	}

	if m.DBPath != dbPath {
		t.Errorf("expected DBPath %s, got %s", dbPath, m.DBPath)
	}
	if n := len(m.Entries); n != 1 {
		t.Errorf("expected 1 entry, got %d", n)
	}

	switch {
	case m.Current == nil:
		t.Error("current entry is nil")
	case m.Current.Config != cfg:
		t.Error("current config does not match the provided config")
	}
}
func TestManifestUpdateConfig(t *testing.T) {
dbPath := "/tmp/testdb"
cfg := NewDefaultConfig(dbPath)
manifest, err := NewManifest(dbPath, cfg)
if err != nil {
t.Fatalf("failed to create manifest: %v", err)
}
// Update config
err = manifest.UpdateConfig(func(c *Config) {
c.MemTableSize = 64 * 1024 * 1024 // 64MB
c.MaxMemTables = 8
})
if err != nil {
t.Fatalf("failed to update config: %v", err)
}
// Verify entries count
if len(manifest.Entries) != 2 {
t.Errorf("expected 2 entries, got %d", len(manifest.Entries))
}
// Verify updated config
current := manifest.GetConfig()
if current.MemTableSize != 64*1024*1024 {
t.Errorf("expected memtable size %d, got %d", 64*1024*1024, current.MemTableSize)
}
if current.MaxMemTables != 8 {
t.Errorf("expected max memtables %d, got %d", 8, current.MaxMemTables)
}
}
func TestManifestFileTracking(t *testing.T) {
dbPath := "/tmp/testdb"
cfg := NewDefaultConfig(dbPath)
manifest, err := NewManifest(dbPath, cfg)
if err != nil {
t.Fatalf("failed to create manifest: %v", err)
}
// Add files
err = manifest.AddFile("sst/000001.sst", 1)
if err != nil {
t.Fatalf("failed to add file: %v", err)
}
err = manifest.AddFile("sst/000002.sst", 2)
if err != nil {
t.Fatalf("failed to add file: %v", err)
}
// Verify files
files := manifest.GetFiles()
if len(files) != 2 {
t.Errorf("expected 2 files, got %d", len(files))
}
if files["sst/000001.sst"] != 1 {
t.Errorf("expected sequence number 1, got %d", files["sst/000001.sst"])
}
if files["sst/000002.sst"] != 2 {
t.Errorf("expected sequence number 2, got %d", files["sst/000002.sst"])
}
// Remove file
err = manifest.RemoveFile("sst/000001.sst")
if err != nil {
t.Fatalf("failed to remove file: %v", err)
}
// Verify files after removal
files = manifest.GetFiles()
if len(files) != 1 {
t.Errorf("expected 1 file, got %d", len(files))
}
if _, exists := files["sst/000001.sst"]; exists {
t.Error("file should have been removed")
}
}
func TestManifestSaveLoad(t *testing.T) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "manifest_test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Create a manifest
cfg := NewDefaultConfig(tempDir)
manifest, err := NewManifest(tempDir, cfg)
if err != nil {
t.Fatalf("failed to create manifest: %v", err)
}
// Update config
err = manifest.UpdateConfig(func(c *Config) {
c.MemTableSize = 64 * 1024 * 1024 // 64MB
})
if err != nil {
t.Fatalf("failed to update config: %v", err)
}
// Add some files
err = manifest.AddFile("sst/000001.sst", 1)
if err != nil {
t.Fatalf("failed to add file: %v", err)
}
// Save the manifest
if err := manifest.Save(); err != nil {
t.Fatalf("failed to save manifest: %v", err)
}
// Load the manifest
loadedManifest, err := LoadManifest(tempDir)
if err != nil {
t.Fatalf("failed to load manifest: %v", err)
}
// Verify entries count
if len(loadedManifest.Entries) != len(manifest.Entries) {
t.Errorf("expected %d entries, got %d", len(manifest.Entries), len(loadedManifest.Entries))
}
// Verify config
loadedConfig := loadedManifest.GetConfig()
if loadedConfig.MemTableSize != 64*1024*1024 {
t.Errorf("expected memtable size %d, got %d", 64*1024*1024, loadedConfig.MemTableSize)
}
// Verify files
loadedFiles := loadedManifest.GetFiles()
if len(loadedFiles) != 1 {
t.Errorf("expected 1 file, got %d", len(loadedFiles))
}
if loadedFiles["sst/000001.sst"] != 1 {
t.Errorf("expected sequence number 1, got %d", loadedFiles["sst/000001.sst"])
}
}

244
pkg/wal/batch.go Normal file
View File

@ -0,0 +1,244 @@
package wal
import (
"encoding/binary"
"errors"
"fmt"
)
const (
// BatchHeaderSize is the size in bytes of the serialized batch header:
// a 4-byte operation count followed by an 8-byte base sequence number.
BatchHeaderSize = 12 // count(4) + seq(8)
)
// Errors reported by Batch.Write.
var (
// ErrEmptyBatch is returned when a batch with no operations is written.
ErrEmptyBatch = errors.New("batch is empty")
// ErrBatchTooLarge is wrapped when the serialized batch would exceed MaxRecordSize.
ErrBatchTooLarge = errors.New("batch too large")
)
// BatchOperation represents a single operation in a batch.
type BatchOperation struct {
Type uint8 // OpTypePut, OpTypeDelete, etc.
// Key is the key the operation applies to.
Key []byte
// Value is the payload; it is not serialized for OpTypeDelete operations.
Value []byte
}
// Batch represents a collection of operations that is written to the WAL as a
// single record.
type Batch struct {
// Operations holds the operations in the order they were added.
Operations []BatchOperation
// Seq is assigned by Write from the WAL's next sequence number (or read
// back by DecodeBatch); it is zero on a new or reset batch.
Seq uint64 // Base sequence number
}
// NewBatch returns an empty batch with room pre-allocated for a handful of
// operations.
func NewBatch() *Batch {
	const initialCapacity = 16
	ops := make([]BatchOperation, 0, initialCapacity)
	return &Batch{Operations: ops}
}
// Put appends an OpTypePut operation for key/value to the batch. The slices
// are stored as given, not copied.
func (b *Batch) Put(key, value []byte) {
	op := BatchOperation{Type: OpTypePut, Key: key, Value: value}
	b.Operations = append(b.Operations, op)
}
// Delete appends an OpTypeDelete operation for key to the batch. The key
// slice is stored as given, not copied.
func (b *Batch) Delete(key []byte) {
	op := BatchOperation{Type: OpTypeDelete, Key: key}
	b.Operations = append(b.Operations, op)
}
// Count returns the number of operations currently held in the batch.
func (b *Batch) Count() int {
return len(b.Operations)
}
// Reset empties the batch and zeroes its sequence number so it can be reused.
// The operation slice is truncated in place, keeping its capacity.
func (b *Batch) Reset() {
	b.Seq = 0
	b.Operations = b.Operations[:0]
}
// Size returns the number of bytes Write produces for this batch: the header
// plus, per operation, type(1) + keyLen(4) + key and, for everything except
// deletes, valueLen(4) + value.
func (b *Batch) Size() int {
	total := BatchHeaderSize
	for i := range b.Operations {
		op := &b.Operations[i]
		total += 1 + 4 + len(op.Key)
		if op.Type == OpTypeDelete {
			// Deletes carry no value section.
			continue
		}
		total += 4 + len(op.Value)
	}
	return total
}
// Write serializes the batch and appends it to w as one full record of type
// OpTypeBatch, then lets the WAL decide whether to sync.
//
// Layout: count(4) | seq(8) | per operation: type(1) keyLen(4) key
// [valueLen(4) value], all little-endian; the value section is omitted for
// deletes. The serialized batch is handed to writeRecord in the key position
// with a nil value.
//
// It returns ErrEmptyBatch for a batch without operations, an error wrapping
// ErrBatchTooLarge when the encoding exceeds MaxRecordSize, ErrWALClosed when
// the WAL is closed, or whatever writeRecord / maybeSync report. On the
// success path b.Seq is set to the sequence number assigned to the batch and
// the WAL's next sequence advances by the number of operations.
func (b *Batch) Write(w *WAL) error {
	opCount := len(b.Operations)
	if opCount == 0 {
		return ErrEmptyBatch
	}

	total := b.Size()
	if total > MaxRecordSize {
		return fmt.Errorf("%w: %d > %d", ErrBatchTooLarge, total, MaxRecordSize)
	}

	le := binary.LittleEndian
	buf := make([]byte, total)

	// Header: the count is known now; the sequence bytes [4:12] are filled
	// in below once the WAL lock is held.
	le.PutUint32(buf[0:4], uint32(opCount))
	pos := BatchHeaderSize

	for i := range b.Operations {
		op := &b.Operations[i]

		buf[pos] = op.Type
		pos++

		le.PutUint32(buf[pos:pos+4], uint32(len(op.Key)))
		pos += 4
		pos += copy(buf[pos:], op.Key)

		if op.Type == OpTypeDelete {
			continue
		}

		le.PutUint32(buf[pos:pos+4], uint32(len(op.Value)))
		pos += 4
		pos += copy(buf[pos:], op.Value)
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	if w.closed {
		return ErrWALClosed
	}

	// Claim a contiguous range of sequence numbers, one per operation.
	b.Seq = w.nextSequence
	le.PutUint64(buf[4:12], b.Seq)
	w.nextSequence += uint64(opCount)

	if err := w.writeRecord(uint8(RecordTypeFull), OpTypeBatch, b.Seq, buf, nil); err != nil {
		return err
	}
	return w.maybeSync()
}
// DecodeBatch reconstructs a Batch from a WAL entry of type OpTypeBatch.
// The serialized batch is read from entry.Key (not entry.Value).
//
// It returns an error when the entry is not a batch entry, an error wrapping
// ErrCorruptRecord when the payload is truncated or its lengths or operation
// count are inconsistent with the payload size, and an error wrapping
// ErrInvalidOpType for an unknown operation type. Keys and values in the
// result are copies and do not alias the entry's buffer.
func DecodeBatch(entry *Entry) (*Batch, error) {
	if entry.Type != OpTypeBatch {
		return nil, fmt.Errorf("not a batch entry: type %d", entry.Type)
	}

	// For batch entries, the batch data is in the Key field, not Value.
	data := entry.Key
	if len(data) < BatchHeaderSize {
		return nil, fmt.Errorf("%w: batch header too small", ErrCorruptRecord)
	}

	count := binary.LittleEndian.Uint32(data[0:4])
	seq := binary.LittleEndian.Uint64(data[4:12])

	// Every operation occupies at least type(1) + keyLen(4) bytes. Reject a
	// count the payload cannot possibly hold before using it as a slice
	// capacity, so a corrupt header cannot trigger a multi-gigabyte allocation.
	const minOpSize = 1 + 4
	if uint64(count) > uint64(len(data)-BatchHeaderSize)/minOpSize {
		return nil, fmt.Errorf("%w: operation count %d exceeds batch payload", ErrCorruptRecord, count)
	}

	batch := &Batch{
		Operations: make([]BatchOperation, 0, count),
		Seq:        seq,
	}

	offset := BatchHeaderSize
	for i := uint32(0); i < count; i++ {
		// Operation type.
		if offset >= len(data) {
			return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
		}
		opType := data[offset]
		offset++
		if opType != OpTypePut && opType != OpTypeDelete && opType != OpTypeMerge {
			return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, opType)
		}

		// Key length and key.
		if offset+4 > len(data) {
			return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
		}
		keyLen := binary.LittleEndian.Uint32(data[offset : offset+4])
		offset += 4
		// Compare in uint64 against the remaining bytes: adding int(keyLen)
		// to offset could overflow where int is 32 bits wide.
		if uint64(keyLen) > uint64(len(data)-offset) {
			return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen)
		}
		key := make([]byte, keyLen)
		copy(key, data[offset:offset+int(keyLen)])
		offset += int(keyLen)

		// Value length and value; deletes have no value section.
		var value []byte
		if opType != OpTypeDelete {
			if offset+4 > len(data) {
				return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
			}
			valueLen := binary.LittleEndian.Uint32(data[offset : offset+4])
			offset += 4
			if uint64(valueLen) > uint64(len(data)-offset) {
				return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen)
			}
			value = make([]byte, valueLen)
			copy(value, data[offset:offset+int(valueLen)])
			offset += int(valueLen)
		}

		batch.Operations = append(batch.Operations, BatchOperation{
			Type:  opType,
			Key:   key,
			Value: value,
		})
	}

	return batch, nil
}

187
pkg/wal/batch_test.go Normal file
View File

@ -0,0 +1,187 @@
package wal
import (
"bytes"
"fmt"
"os"
"testing"
)
// TestBatchOperations covers Count, Size and Reset on an in-memory batch.
func TestBatchOperations(t *testing.T) {
	b := NewBatch()

	// A new batch holds nothing.
	if n := b.Count(); n != 0 {
		t.Errorf("Expected empty batch, got count %d", n)
	}

	b.Put([]byte("key1"), []byte("value1"))
	b.Put([]byte("key2"), []byte("value2"))
	b.Delete([]byte("key3"))

	if n := b.Count(); n != 3 {
		t.Errorf("Expected batch with 3 operations, got %d", n)
	}

	// Expected encoding size: header, two puts, one delete.
	const (
		opPrefix  = 1 + 4 // type + keylen
		valPrefix = 4     // vallen
	)
	putSize := func(k, v string) int { return opPrefix + valPrefix + len(k) + len(v) }
	want := BatchHeaderSize +
		putSize("key1", "value1") +
		putSize("key2", "value2") +
		opPrefix + len("key3") // no value for delete
	if got := b.Size(); got != want {
		t.Errorf("Expected batch size %d, got %d", want, got)
	}

	// Reset must leave the batch empty again.
	b.Reset()
	if n := b.Count(); n != 0 {
		t.Errorf("Expected empty batch after reset, got count %d", n)
	}
}
// TestBatchEncoding writes a batch to a WAL, replays the directory and checks
// that the decoded batch matches what was written.
func TestBatchEncoding(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	cfg := createTestConfig()
	wal, err := NewWAL(cfg, dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Create and write a batch.
	batch := NewBatch()
	batch.Put([]byte("key1"), []byte("value1"))
	batch.Put([]byte("key2"), []byte("value2"))
	batch.Delete([]byte("key3"))

	if err := batch.Write(wal); err != nil {
		t.Fatalf("Failed to write batch: %v", err)
	}

	// Write must have assigned a sequence number.
	if batch.Seq == 0 {
		t.Errorf("Batch sequence number not set")
	}

	if err := wal.Close(); err != nil {
		t.Fatalf("Failed to close WAL: %v", err)
	}

	// Replay and decode.
	var decodedBatch *Batch
	err = ReplayWALDir(dir, func(entry *Entry) error {
		if entry.Type == OpTypeBatch {
			var err error
			decodedBatch, err = DecodeBatch(entry)
			if err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		t.Fatalf("Failed to replay WAL: %v", err)
	}
	if decodedBatch == nil {
		t.Fatal("No batch found in replay")
	}

	// Stop here on a count mismatch: the per-operation checks below index
	// ops[0..2] and would panic on a shorter slice instead of failing cleanly.
	if decodedBatch.Count() != 3 {
		t.Fatalf("Expected 3 operations, got %d", decodedBatch.Count())
	}
	if decodedBatch.Seq != batch.Seq {
		t.Errorf("Expected sequence %d, got %d", batch.Seq, decodedBatch.Seq)
	}

	// Verify operations.
	ops := decodedBatch.Operations
	if ops[0].Type != OpTypePut || !bytes.Equal(ops[0].Key, []byte("key1")) || !bytes.Equal(ops[0].Value, []byte("value1")) {
		t.Errorf("First operation mismatch")
	}
	if ops[1].Type != OpTypePut || !bytes.Equal(ops[1].Key, []byte("key2")) || !bytes.Equal(ops[1].Value, []byte("value2")) {
		t.Errorf("Second operation mismatch")
	}
	if ops[2].Type != OpTypeDelete || !bytes.Equal(ops[2].Key, []byte("key3")) {
		t.Errorf("Third operation mismatch")
	}
}
// TestEmptyBatch checks that writing a batch without operations is rejected
// with ErrEmptyBatch.
func TestEmptyBatch(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	w, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Nothing is added to the batch before writing it.
	if got := NewBatch().Write(w); got != ErrEmptyBatch {
		t.Errorf("Expected ErrEmptyBatch, got: %v", got)
	}

	if err := w.Close(); err != nil {
		t.Fatalf("Failed to close WAL: %v", err)
	}
}
// TestLargeBatch checks that a batch whose encoding exceeds MaxRecordSize is
// rejected with a "batch too large" error.
func TestLargeBatch(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	w, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Twenty 4KB values are expected to push the batch past the record limit.
	oversized := NewBatch()
	payload := make([]byte, 4096) // 4KB
	for i := 0; i < 20; i++ {
		oversized.Put([]byte(fmt.Sprintf("key%d", i)), payload)
	}

	// Guard the premise of the test.
	if sz := oversized.Size(); sz <= MaxRecordSize {
		t.Fatalf("Expected batch size > %d, got %d", MaxRecordSize, sz)
	}

	switch writeErr := oversized.Write(w); {
	case writeErr == nil:
		t.Error("Expected error when writing large batch")
	case !bytes.Contains([]byte(writeErr.Error()), []byte("batch too large")):
		t.Errorf("Expected ErrBatchTooLarge, got: %v", writeErr)
	}

	if err := w.Close(); err != nil {
		t.Fatalf("Failed to close WAL: %v", err)
	}
}

283
pkg/wal/reader.go Normal file
View File

@ -0,0 +1,283 @@
package wal
import (
"bufio"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"os"
"path/filepath"
"sort"
)
// Reader reads entries from a single WAL file, reassembling entries that were
// split across several physical records.
type Reader struct {
// file is the underlying WAL file; Close closes it.
file *os.File
// reader buffers reads from file.
reader *bufio.Reader
// buffer is allocated with MaxRecordSize bytes by OpenReader.
// NOTE(review): not referenced by the methods visible in this section.
buffer []byte
// fragments accumulates the payloads of First/Middle/Last records until
// the entry is complete.
fragments [][]byte
// currType is set from the first byte of a First fragment.
// NOTE(review): not read by the methods visible in this section.
currType uint8
}
// OpenReader opens the WAL file at path for reading and returns a Reader
// positioned at its start. The caller should Close the reader when done.
func OpenReader(path string) (*Reader, error) {
	const readBufferSize = 64 * 1024 // 64KB buffer

	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open WAL file: %w", err)
	}

	rd := &Reader{file: f, fragments: make([][]byte, 0)}
	rd.reader = bufio.NewReaderSize(f, readBufferSize)
	rd.buffer = make([]byte, MaxRecordSize)
	return rd, nil
}
// ReadEntry reads physical records until a complete logical entry is
// available and returns it.
//
// A Full record is parsed directly. First/Middle/Last records are collected
// and parsed once the Last fragment arrives. It returns io.EOF at a clean end
// of file, a descriptive error if the file ends while fragments are pending,
// an error wrapping ErrCorruptRecord for out-of-order or empty fragments, an
// error wrapping ErrInvalidRecordType for an unknown record type, and
// otherwise any error from reading or parsing.
func (r *Reader) ReadEntry() (*Entry, error) {
	for {
		rec, err := r.readRecord()
		if err != nil {
			if err == io.EOF {
				// EOF in the middle of a fragmented entry is not a clean end.
				if len(r.fragments) > 0 {
					return nil, fmt.Errorf("unexpected EOF with %d fragments", len(r.fragments))
				}
				return nil, io.EOF
			}
			return nil, err
		}

		switch rec.recordType {
		case RecordTypeFull:
			// Single record, parse directly.
			return r.parseEntryData(rec.data)

		case RecordTypeFirst:
			// The first fragment carries the operation type in its first
			// byte; a zero-length payload is corrupt and indexing it would
			// panic.
			if len(rec.data) == 0 {
				return nil, fmt.Errorf("%w: empty first fragment", ErrCorruptRecord)
			}
			r.fragments = append(r.fragments, rec.data)
			r.currType = rec.data[0]

		case RecordTypeMiddle:
			if len(r.fragments) == 0 {
				return nil, fmt.Errorf("%w: middle fragment without first fragment", ErrCorruptRecord)
			}
			r.fragments = append(r.fragments, rec.data)

		case RecordTypeLast:
			if len(r.fragments) == 0 {
				return nil, fmt.Errorf("%w: last fragment without previous fragments", ErrCorruptRecord)
			}
			r.fragments = append(r.fragments, rec.data)
			// All fragments are present: combine and parse them.
			return r.processFragments()

		default:
			return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, rec.recordType)
		}
	}
}
// record is a single physical record read from the WAL: its record type
// (full, first, middle or last) and its CRC-verified payload.
type record struct {
recordType uint8
data []byte
}
// readRecord reads one physical record: a HeaderSize-byte header whose first
// seven bytes are crc(4, little-endian), length(2, little-endian) and
// type(1), followed by length payload bytes.
//
// It returns the error from io.ReadFull when the header or payload cannot be
// read in full (io.EOF if no header bytes were available), an error wrapping
// ErrInvalidRecordType for a type outside RecordTypeFull..RecordTypeLast, and
// an error wrapping ErrCorruptRecord when the payload's IEEE CRC-32 does not
// match the header.
func (r *Reader) readRecord() (*record, error) {
	hdr := make([]byte, HeaderSize)
	if _, err := io.ReadFull(r.reader, hdr); err != nil {
		return nil, err
	}

	wantCRC := binary.LittleEndian.Uint32(hdr[0:4])
	payloadLen := binary.LittleEndian.Uint16(hdr[4:6])
	kind := hdr[6]

	if !(kind >= RecordTypeFull && kind <= RecordTypeLast) {
		return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, kind)
	}

	payload := make([]byte, payloadLen)
	if _, err := io.ReadFull(r.reader, payload); err != nil {
		return nil, err
	}

	// The checksum covers the payload only.
	if gotCRC := crc32.ChecksumIEEE(payload); gotCRC != wantCRC {
		return nil, fmt.Errorf("%w: expected CRC %d, got %d", ErrCorruptRecord, wantCRC, gotCRC)
	}

	return &record{recordType: kind, data: payload}, nil
}
// processFragments concatenates the collected fragments in arrival order,
// clears the fragment list for the next entry and parses the joined bytes.
func (r *Reader) processFragments() (*Entry, error) {
	size := 0
	for i := range r.fragments {
		size += len(r.fragments[i])
	}

	// One allocation of the exact final size.
	joined := make([]byte, 0, size)
	for i := range r.fragments {
		joined = append(joined, r.fragments[i]...)
	}

	// Keep the backing array for reuse.
	r.fragments = r.fragments[:0]

	return r.parseEntryData(joined)
}
// parseEntryData decodes the logical entry encoding
//
//	type(1) + seq(8) + keylen(4) + key [+ vallen(4) + value]
//
// into an Entry. All integers are little-endian. The value part is absent for
// OpTypeDelete entries and required for every other operation type. Key and
// value are copied out of data, so the returned Entry does not alias it.
//
// It returns an error wrapping ErrCorruptRecord when data is shorter than the
// fixed prefix or a length field points past the end of data, and an error
// wrapping ErrInvalidOpType when the type byte is not a known operation.
func (r *Reader) parseEntryData(data []byte) (*Entry, error) {
	const fixedPrefix = 1 + 8 + 4 // type + seq + keylen
	if len(data) < fixedPrefix {
		return nil, fmt.Errorf("%w: entry too small, %d bytes", ErrCorruptRecord, len(data))
	}

	// Entry type
	entryType := data[0]
	switch entryType {
	case OpTypePut, OpTypeDelete, OpTypeMerge, OpTypeBatch:
		// known operation
	default:
		return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, entryType)
	}

	// Sequence number and key length
	seqNum := binary.LittleEndian.Uint64(data[1:9])
	keyLen := binary.LittleEndian.Uint32(data[9:13])
	rest := data[fixedPrefix:]

	// Compare in uint64: converting an untrusted uint32 length to int first
	// can wrap where int is 32 bits and let a corrupt length slip through.
	if uint64(keyLen) > uint64(len(rest)) {
		return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen)
	}
	key := make([]byte, keyLen)
	copy(key, rest[:keyLen])
	rest = rest[keyLen:]

	// Value, for every operation except delete
	var value []byte
	if entryType != OpTypeDelete {
		if len(rest) < 4 {
			return nil, fmt.Errorf("%w: missing value length", ErrCorruptRecord)
		}
		valueLen := binary.LittleEndian.Uint32(rest[:4])
		rest = rest[4:]

		if uint64(valueLen) > uint64(len(rest)) {
			return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen)
		}
		value = make([]byte, valueLen)
		copy(value, rest[:valueLen])
	}

	return &Entry{
		SequenceNumber: seqNum,
		Type:           entryType,
		Key:            key,
		Value:          value,
	}, nil
}
// Close closes the file backing the reader and returns the error, if any,
// reported by that close.
func (r *Reader) Close() error {
	closeErr := r.file.Close()
	return closeErr
}
// EntryHandler is the callback invoked once for every entry decoded during
// WAL replay. Returning a non-nil error stops ReplayWALFile, which wraps the
// error and returns it to its caller.
type EntryHandler func(*Entry) error
// FindWALFiles returns the paths matching "*.wal" directly inside dir, sorted
// in ascending lexical order. Because dir becomes part of a glob pattern, a
// malformed pattern is reported as a wrapped error.
func FindWALFiles(dir string) ([]string, error) {
	walPaths, globErr := filepath.Glob(filepath.Join(dir, "*.wal"))
	if globErr != nil {
		return nil, fmt.Errorf("failed to glob WAL files: %w", globErr)
	}

	// Lexical order; file names are expected to be timestamp-based.
	sort.Strings(walPaths)
	return walPaths, nil
}
// ReplayWALFile opens the WAL file at path and feeds every entry it contains
// to handler, in file order. Replay ends successfully when the reader reports
// io.EOF. A read error or a handler error aborts the replay and is returned
// wrapped; an error from opening the file is returned as-is.
func ReplayWALFile(path string, handler EntryHandler) error {
	rd, openErr := OpenReader(path)
	if openErr != nil {
		return openErr
	}
	defer rd.Close()

	for {
		entry, readErr := rd.ReadEntry()
		switch {
		case readErr == io.EOF:
			// End of the log: nothing more to replay.
			return nil
		case readErr != nil:
			return fmt.Errorf("error reading entry from %s: %w", path, readErr)
		}

		if handleErr := handler(entry); handleErr != nil {
			return fmt.Errorf("error handling entry: %w", handleErr)
		}
	}
}
// ReplayWALDir replays every WAL file found in dir through handler, visiting
// the files in the order returned by FindWALFiles. It stops at, and returns,
// the first error from listing or replaying.
func ReplayWALDir(dir string, handler EntryHandler) error {
	walPaths, listErr := FindWALFiles(dir)
	if listErr != nil {
		return listErr
	}

	for i := range walPaths {
		if replayErr := ReplayWALFile(walPaths[i], handler); replayErr != nil {
			return replayErr
		}
	}
	return nil
}

392
pkg/wal/wal.go Normal file
View File

@ -0,0 +1,392 @@
package wal
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"os"
"path/filepath"
"sync"
"time"
"git.canoozie.net/jer/go-storage/pkg/config"
)
const (
	// Record types, stored in the last byte of each record header.
	// An entry that fits in one record is written as RecordTypeFull; a larger
	// entry is written as one RecordTypeFirst record, zero or more
	// RecordTypeMiddle records, and a closing RecordTypeLast record.
	RecordTypeFull = 1
	RecordTypeFirst = 2
	RecordTypeMiddle = 3
	RecordTypeLast = 4
	// Operation types, stored in the first byte of each logical entry.
	// OpTypeDelete entries are encoded without a value. OpTypeBatch is
	// accepted when reading entries back but is rejected by WAL.Append.
	OpTypePut = 1
	OpTypeDelete = 2
	OpTypeMerge = 3
	OpTypeBatch = 4
	// Header layout (integers are little-endian)
	// - CRC (4 bytes): CRC-32 (IEEE) of the payload
	// - Length (2 bytes): payload length
	// - Type (1 byte): one of the record types above
	HeaderSize = 7
	// Maximum size of a record payload; small enough to fit the 2-byte
	// length field of the header.
	MaxRecordSize = 32 * 1024 // 32KB
	// Default WAL file size
	// NOTE(review): not referenced by the code visible in this file - confirm
	// where the limit is enforced.
	DefaultWALFileSize = 64 * 1024 * 1024 // 64MB
)
// Sentinel errors of the wal package. Where they are wrapped with extra
// detail (via %w), match them with errors.Is.
var (
	// ErrCorruptRecord reports a CRC mismatch, a fragment sequence without
	// its first fragment, or an entry whose encoded lengths are inconsistent.
	ErrCorruptRecord = errors.New("corrupt record")
	// ErrInvalidRecordType reports a record header whose type byte is not one
	// of the RecordType* constants.
	ErrInvalidRecordType = errors.New("invalid record type")
	// ErrInvalidOpType reports an operation type that is not accepted, either
	// when appending or when decoding an entry.
	ErrInvalidOpType = errors.New("invalid operation type")
	// ErrWALClosed is returned by Append and Sync once the WAL has been closed.
	ErrWALClosed = errors.New("WAL is closed")
	// ErrWALFull - NOTE(review): not returned by any code visible in this
	// file; presumably reserved for a file-size limit. Confirm.
	ErrWALFull = errors.New("WAL file is full")
)
// Entry represents a logical entry in the WAL: one operation together with
// the sequence number that was assigned to it when it was written.
type Entry struct {
	SequenceNumber uint64 // sequence number stored with the entry
	Type uint8 // operation type: OpTypePut, OpTypeDelete, OpTypeMerge or OpTypeBatch
	Key []byte // key bytes
	Value []byte // value bytes; left nil for OpTypeDelete entries when decoded
}
// WAL represents a write-ahead log backed by a single file. Records are
// written through a buffered writer and reach disk when the WAL is synced.
// The exported methods take mu, so a WAL may be shared between goroutines.
type WAL struct {
	cfg *config.Config // configuration; consulted for the sync mode and the batch sync threshold
	dir string // directory in which the WAL file was created
	file *os.File // the open WAL file
	writer *bufio.Writer // buffered writer in front of file
	nextSequence uint64 // sequence number handed to the next Append; starts at 1
	bytesWritten int64 // total header+payload bytes written through this WAL
	lastSync time.Time // time of creation or of the last successful sync
	batchByteSize int64 // bytes written since the last successful sync
	closed bool // set once Close has completed; Append and Sync then fail
	mu sync.Mutex // held by Append, Sync and Close for their whole duration
}
// NewWAL creates dir if necessary and opens a brand-new WAL file inside it.
// The file is named after the current time in nanoseconds, zero-padded to 20
// digits, and is created exclusively, so an existing file is never reused.
// Sequence numbers of the returned WAL start at 1.
//
// It returns an error when cfg is nil, when the directory cannot be created,
// or when the file cannot be created.
func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
	if cfg == nil {
		return nil, errors.New("config cannot be nil")
	}

	if mkErr := os.MkdirAll(dir, 0755); mkErr != nil {
		return nil, fmt.Errorf("failed to create WAL directory: %w", mkErr)
	}

	// Timestamp-based name keeps lexical order equal to creation order.
	name := fmt.Sprintf("%020d.wal", time.Now().UnixNano())
	f, openErr := os.OpenFile(filepath.Join(dir, name), os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
	if openErr != nil {
		return nil, fmt.Errorf("failed to create WAL file: %w", openErr)
	}

	return &WAL{
		cfg:          cfg,
		dir:          dir,
		file:         f,
		writer:       bufio.NewWriterSize(f, 64*1024), // 64KB buffer
		nextSequence: 1,
		lastSync:     time.Now(),
	}, nil
}
// Append writes one put, delete or merge entry to the WAL and returns the
// sequence number assigned to it. Entries whose encoded form exceeds
// MaxRecordSize are split across several records. After the write, the WAL is
// synced if the configured sync mode calls for it.
//
// It returns ErrWALClosed after Close, ErrInvalidOpType for any other
// operation type, or the error from writing or syncing. The sequence number
// is consumed as soon as the arguments have been validated, even if the write
// then fails.
func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.closed {
		return 0, ErrWALClosed
	}

	switch entryType {
	case OpTypePut, OpTypeDelete, OpTypeMerge:
		// accepted
	default:
		return 0, ErrInvalidOpType
	}

	// Claim the sequence number for this entry.
	seqNum := w.nextSequence
	w.nextSequence++

	// Encoded form: type(1) + seq(8) + keylen(4) + key + vallen(4) + val,
	// where the value part is omitted for deletes.
	encodedSize := 1 + 8 + 4 + len(key)
	if entryType != OpTypeDelete {
		encodedSize += 4 + len(value)
	}

	var writeErr error
	if encodedSize > MaxRecordSize {
		// Too big for one record: fragment it.
		writeErr = w.writeFragmentedRecord(entryType, seqNum, key, value)
	} else {
		writeErr = w.writeRecord(uint8(RecordTypeFull), entryType, seqNum, key, value)
	}
	if writeErr != nil {
		return 0, writeErr
	}

	if syncErr := w.maybeSync(); syncErr != nil {
		return 0, syncErr
	}
	return seqNum, nil
}
// writeRecord encodes one logical entry and writes it as a single physical
// record of the given record type. The caller must hold w.mu.
//
// Payload layout (integers little-endian):
//
//	type(1) + seq(8) + keylen(4) + key [+ vallen(4) + value]
//
// The value part is omitted when entryType is OpTypeDelete. The header, CRC,
// buffered write and byte accounting are delegated to writeRawRecord so both
// write paths share one implementation.
//
// It returns an error when the encoded payload exceeds MaxRecordSize or when
// writing to the buffered writer fails.
func (w *WAL) writeRecord(recordType uint8, entryType uint8, seqNum uint64, key, value []byte) error {
	const fixedPrefix = 1 + 8 + 4 // type + seq + keylen
	hasValue := entryType != OpTypeDelete

	payloadSize := fixedPrefix + len(key)
	if hasValue {
		payloadSize += 4 + len(value) // vallen + value
	}
	// Checked here as well so an oversized payload is never allocated.
	if payloadSize > MaxRecordSize {
		return fmt.Errorf("record too large: %d > %d", payloadSize, MaxRecordSize)
	}

	payload := make([]byte, payloadSize)
	payload[0] = entryType
	binary.LittleEndian.PutUint64(payload[1:9], seqNum)
	binary.LittleEndian.PutUint32(payload[9:13], uint32(len(key)))
	offset := fixedPrefix + copy(payload[fixedPrefix:], key)

	if hasValue {
		binary.LittleEndian.PutUint32(payload[offset:offset+4], uint32(len(value)))
		copy(payload[offset+4:], value)
	}

	return w.writeRawRecord(recordType, payload)
}
// writeRawRecord frames data as one physical record - header followed by
// payload - and hands it to the buffered writer. The header carries the
// CRC-32 (IEEE) of data, its length and recordType. The caller must hold w.mu.
//
// It returns an error when data is longer than MaxRecordSize or when a write
// to the buffered writer fails. On success the written and since-last-sync
// byte counters are both advanced by the size of header plus payload.
func (w *WAL) writeRawRecord(recordType uint8, data []byte) error {
	size := len(data)
	if size > MaxRecordSize {
		return fmt.Errorf("record too large: %d > %d", size, MaxRecordSize)
	}

	// Header: crc(4) + length(2) + type(1)
	var hdr [HeaderSize]byte
	binary.LittleEndian.PutUint32(hdr[0:4], crc32.ChecksumIEEE(data))
	binary.LittleEndian.PutUint16(hdr[4:6], uint16(size))
	hdr[6] = recordType

	if _, err := w.writer.Write(hdr[:]); err != nil {
		return fmt.Errorf("failed to write record header: %w", err)
	}
	if _, err := w.writer.Write(data); err != nil {
		return fmt.Errorf("failed to write record payload: %w", err)
	}

	written := int64(HeaderSize + size)
	w.bytesWritten += written
	w.batchByteSize += written
	return nil
}
// writeFragmentedRecord writes one logical entry as a sequence of physical
// records. The caller must hold w.mu.
//
// The first record (RecordTypeFirst) carries type, sequence number, key
// length and as much of the key as fits. Everything else - the rest of the
// key, then for non-delete entries the value length and the value - is cut
// into MaxRecordSize chunks written as RecordTypeMiddle records, with the
// final chunk written as RecordTypeLast. Concatenating the payloads in order
// yields the same bytes a single unfragmented record would contain.
func (w *WAL) writeFragmentedRecord(entryType uint8, seqNum uint64, key, value []byte) error {
	const metaSize = 1 + 8 + 4 // type + seq + keylen

	// Number of key bytes that share the first record with the metadata.
	keyHead := len(key)
	if room := MaxRecordSize - metaSize; keyHead > room {
		keyHead = room
	}

	first := make([]byte, metaSize+keyHead)
	first[0] = entryType
	binary.LittleEndian.PutUint64(first[1:9], seqNum)
	binary.LittleEndian.PutUint32(first[9:13], uint32(len(key)))
	copy(first[metaSize:], key[:keyHead])

	if err := w.writeRawRecord(uint8(RecordTypeFirst), first); err != nil {
		return err
	}

	// Everything that did not go into the first record.
	var tail []byte
	tail = append(tail, key[keyHead:]...)
	if entryType != OpTypeDelete {
		var lenBuf [4]byte
		binary.LittleEndian.PutUint32(lenBuf[:], uint32(len(value)))
		tail = append(tail, lenBuf[:]...)
		tail = append(tail, value...)
	}

	// Full-sized middle records while more than one record's worth is left.
	for len(tail) > MaxRecordSize {
		if err := w.writeRawRecord(uint8(RecordTypeMiddle), tail[:MaxRecordSize]); err != nil {
			return err
		}
		tail = tail[MaxRecordSize:]
	}

	// Whatever is left closes the sequence.
	if len(tail) == 0 {
		return nil
	}
	return w.writeRawRecord(uint8(RecordTypeLast), tail)
}
// maybeSync applies the configured sync policy after a write: always sync in
// immediate mode, sync in batch mode once the bytes written since the last
// sync reach the configured threshold, and never sync otherwise. The caller
// must hold w.mu.
func (w *WAL) maybeSync() error {
	var due bool
	switch w.cfg.WALSyncMode {
	case config.SyncImmediate:
		due = true
	case config.SyncBatch:
		due = w.batchByteSize >= w.cfg.WALSyncBytes
	case config.SyncNone:
		// never sync on the write path
	}

	if !due {
		return nil
	}
	// The mutex is already held, so use the lock-free variant.
	return w.syncLocked()
}
// syncLocked flushes the buffered writer into the file and then syncs the
// file. On success it records the sync time and resets the since-last-sync
// byte counter. It returns ErrWALClosed if the WAL is already closed. The
// caller must hold w.mu.
func (w *WAL) syncLocked() error {
	if w.closed {
		return ErrWALClosed
	}

	flushErr := w.writer.Flush()
	if flushErr != nil {
		return fmt.Errorf("failed to flush WAL buffer: %w", flushErr)
	}

	syncErr := w.file.Sync()
	if syncErr != nil {
		return fmt.Errorf("failed to sync WAL file: %w", syncErr)
	}

	w.lastSync, w.batchByteSize = time.Now(), 0
	return nil
}
// Sync takes the WAL mutex, flushes buffered records into the file and syncs
// the file. It returns ErrWALClosed if the WAL has been closed.
func (w *WAL) Sync() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	syncErr := w.syncLocked()
	return syncErr
}
// Close flushes and syncs any buffered records, closes the underlying file
// and marks the WAL closed. Calling Close on an already closed WAL is a no-op
// that returns nil.
//
// The file is closed and the WAL is marked closed even when the final
// flush/sync fails, so a failed sync cannot leak the file descriptor or leave
// the WAL in a state where Close can never succeed. When both steps fail, the
// flush/sync error is the one returned, since it is the one that tells the
// caller buffered data may not have reached disk.
func (w *WAL) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.closed {
		return nil
	}

	// Attempt both steps unconditionally; decide what to report afterwards.
	syncErr := w.syncLocked()
	closeErr := w.file.Close()
	w.closed = true

	if syncErr != nil {
		return syncErr
	}
	if closeErr != nil {
		return fmt.Errorf("failed to close WAL file: %w", closeErr)
	}
	return nil
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}

590
pkg/wal/wal_test.go Normal file
View File

@ -0,0 +1,590 @@
package wal
import (
"bytes"
"fmt"
"math/rand"
"os"
"path/filepath"
"testing"
"git.canoozie.net/jer/go-storage/pkg/config"
)
// createTestConfig returns the default configuration, built for a fixed
// scratch path, that the WAL tests use as their starting point.
func createTestConfig() *config.Config {
	cfg := config.NewDefaultConfig("/tmp/gostorage_test")
	return cfg
}
// createTempDir creates a fresh temporary directory for a test and returns
// its path. The test is aborted if the directory cannot be created. Removing
// the directory is left to the caller.
func createTempDir(t *testing.T) string {
	// Mark as a helper so a failure is reported at the calling test's line.
	t.Helper()

	dir, err := os.MkdirTemp("", "wal_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	return dir
}
// TestWALWrite appends three put entries, checks that they receive sequence
// numbers 1..3, and verifies that replaying the directory yields the same
// key/value pairs.
func TestWALWrite(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	wal, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	keys := []string{"key1", "key2", "key3"}
	values := []string{"value1", "value2", "value3"}

	// Append every pair; sequence numbers are expected to start at 1.
	for i := range keys {
		seq, appendErr := wal.Append(OpTypePut, []byte(keys[i]), []byte(values[i]))
		if appendErr != nil {
			t.Fatalf("Failed to append entry: %v", appendErr)
		}
		if seq != uint64(i+1) {
			t.Errorf("Expected sequence %d, got %d", i+1, seq)
		}
	}

	if closeErr := wal.Close(); closeErr != nil {
		t.Fatalf("Failed to close WAL: %v", closeErr)
	}

	// Rebuild the key/value state from the log.
	replayed := make(map[string]string)
	apply := func(entry *Entry) error {
		switch entry.Type {
		case OpTypePut:
			replayed[string(entry.Key)] = string(entry.Value)
		case OpTypeDelete:
			delete(replayed, string(entry.Key))
		}
		return nil
	}
	if replayErr := ReplayWALDir(dir, apply); replayErr != nil {
		t.Fatalf("Failed to replay WAL: %v", replayErr)
	}

	// Every written pair must have come back unchanged.
	for i, key := range keys {
		got, found := replayed[key]
		switch {
		case !found:
			t.Errorf("Entry for key %q not found", key)
		case got != values[i]:
			t.Errorf("Expected value %q for key %q, got %q", values[i], key, got)
		}
	}
}
// TestWALDelete writes a put followed by a delete for the same key and checks
// that replay leaves the key in the deleted state.
func TestWALDelete(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	wal, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	key, value := []byte("key1"), []byte("value1")

	if _, putErr := wal.Append(OpTypePut, key, value); putErr != nil {
		t.Fatalf("Failed to append put entry: %v", putErr)
	}
	if _, delErr := wal.Append(OpTypeDelete, key, nil); delErr != nil {
		t.Fatalf("Failed to append delete entry: %v", delErr)
	}

	if closeErr := wal.Close(); closeErr != nil {
		t.Fatalf("Failed to close WAL: %v", closeErr)
	}

	// Track whether the most recent operation on the key was a delete.
	deleted := false
	replayErr := ReplayWALDir(dir, func(entry *Entry) error {
		if !bytes.Equal(entry.Key, key) {
			return nil
		}
		switch entry.Type {
		case OpTypePut:
			deleted = false // key was (re-)added
		case OpTypeDelete:
			deleted = true
		}
		return nil
	})
	if replayErr != nil {
		t.Fatalf("Failed to replay WAL: %v", replayErr)
	}

	if !deleted {
		t.Errorf("Expected key to be deleted")
	}
}
// TestWALLargeEntry writes an 8KB key with a 16KB value - large, but still
// small enough for a single record - and checks that replay returns both
// byte-for-byte.
func TestWALLargeEntry(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	wal, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Recognizable patterns so that a misplaced byte is detectable.
	key := make([]byte, 8*1024)    // 8KB
	value := make([]byte, 16*1024) // 16KB
	for i := 0; i < len(key); i++ {
		key[i] = byte(i % 256)
	}
	for i := 0; i < len(value); i++ {
		value[i] = byte((i * 2) % 256)
	}

	if _, appendErr := wal.Append(OpTypePut, key, value); appendErr != nil {
		t.Fatalf("Failed to append large entry: %v", appendErr)
	}

	if closeErr := wal.Close(); closeErr != nil {
		t.Fatalf("Failed to close WAL: %v", closeErr)
	}

	found := false
	replayErr := ReplayWALDir(dir, func(entry *Entry) error {
		// Only a put with matching sizes can be the entry written above.
		sameShape := entry.Type == OpTypePut &&
			len(entry.Key) == len(key) &&
			len(entry.Value) == len(value)
		if !sameShape {
			return nil
		}

		// Report the first differing key byte, if any.
		for i, want := range key {
			if got := entry.Key[i]; got != want {
				t.Errorf("Key mismatch at position %d: expected %d, got %d", i, want, got)
				return nil
			}
		}
		// Report the first differing value byte, if any.
		for i, want := range value {
			if got := entry.Value[i]; got != want {
				t.Errorf("Value mismatch at position %d: expected %d, got %d", i, want, got)
				return nil
			}
		}

		found = true
		return nil
	})
	if replayErr != nil {
		t.Fatalf("Failed to replay WAL: %v", replayErr)
	}

	if !found {
		t.Error("Large entry not found in replay")
	}
}
// TestWALBatch writes a batch of three puts and one delete and checks that
// replay sees exactly one batch entry which, once decoded and applied, leaves
// batch1 and batch3 present and batch2 deleted.
func TestWALBatch(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	wal, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Three puts followed by a delete of the middle key.
	batch := NewBatch()
	keys := []string{"batch1", "batch2", "batch3"}
	values := []string{"value1", "value2", "value3"}
	for i := range keys {
		batch.Put([]byte(keys[i]), []byte(values[i]))
	}
	batch.Delete([]byte("batch2"))

	if writeErr := batch.Write(wal); writeErr != nil {
		t.Fatalf("Failed to write batch: %v", writeErr)
	}

	if closeErr := wal.Close(); closeErr != nil {
		t.Fatalf("Failed to close WAL: %v", closeErr)
	}

	// Replay, decoding and applying every batch entry.
	state := make(map[string]string)
	batchCount := 0
	replayErr := ReplayWALDir(dir, func(entry *Entry) error {
		if entry.Type != OpTypeBatch {
			return nil
		}
		batchCount++

		decoded, decodeErr := DecodeBatch(entry)
		if decodeErr != nil {
			t.Errorf("Failed to decode batch: %v", decodeErr)
			return nil
		}

		for _, op := range decoded.Operations {
			switch op.Type {
			case OpTypePut:
				state[string(op.Key)] = string(op.Value)
			case OpTypeDelete:
				delete(state, string(op.Key))
			}
		}
		return nil
	})
	if replayErr != nil {
		t.Fatalf("Failed to replay WAL: %v", replayErr)
	}

	if batchCount != 1 {
		t.Errorf("Expected 1 batch, got %d", batchCount)
	}

	// batch2 is intentionally absent: it was deleted within the batch.
	expectedEntries := map[string]string{
		"batch1": "value1",
		"batch3": "value3",
	}
	for key, want := range expectedEntries {
		got, found := state[key]
		switch {
		case !found:
			t.Errorf("Entry for key %q not found", key)
		case got != want:
			t.Errorf("Expected value %q for key %q, got %q", want, key, got)
		}
	}

	if _, stillThere := state["batch2"]; stillThere {
		t.Errorf("Key batch2 should be deleted")
	}
}
// TestWALRecovery creates two WAL files in the same directory, one entry
// each, and checks that replaying the directory recovers both entries.
func TestWALRecovery(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	cfg := createTestConfig()

	// writeOne creates a new WAL file in dir, appends one put and closes it.
	writeOne := func(key, value string) {
		w, err := NewWAL(cfg, dir)
		if err != nil {
			t.Fatalf("Failed to create WAL: %v", err)
		}
		if _, err := w.Append(OpTypePut, []byte(key), []byte(value)); err != nil {
			t.Fatalf("Failed to append entry: %v", err)
		}
		if err := w.Close(); err != nil {
			t.Fatalf("Failed to close WAL: %v", err)
		}
	}

	writeOne("key1", "value1") // first WAL file
	writeOne("key2", "value2") // second WAL file

	// Replay all WAL files in order and rebuild the state.
	recovered := make(map[string]string)
	replayErr := ReplayWALDir(dir, func(entry *Entry) error {
		switch entry.Type {
		case OpTypePut:
			recovered[string(entry.Key)] = string(entry.Value)
		case OpTypeDelete:
			delete(recovered, string(entry.Key))
		}
		return nil
	})
	if replayErr != nil {
		t.Fatalf("Failed to replay WAL: %v", replayErr)
	}

	expected := map[string]string{
		"key1": "value1",
		"key2": "value2",
	}
	for key, want := range expected {
		got, found := recovered[key]
		switch {
		case !found:
			t.Errorf("Entry for key %q not found", key)
		case got != want:
			t.Errorf("Expected value %q for key %q, got %q", want, key, got)
		}
	}
}
// TestWALSyncModes runs the same write/close/replay cycle under each sync
// mode and checks that all ten entries are recovered regardless of the mode.
func TestWALSyncModes(t *testing.T) {
	testCases := []struct {
		name     string
		syncMode config.SyncMode
	}{
		{name: "SyncNone", syncMode: config.SyncNone},
		{name: "SyncBatch", syncMode: config.SyncBatch},
		{name: "SyncImmediate", syncMode: config.SyncImmediate},
	}

	for i := range testCases {
		tc := testCases[i]
		t.Run(tc.name, func(t *testing.T) {
			dir := createTempDir(t)
			defer os.RemoveAll(dir)

			// Default config with only the sync mode overridden.
			cfg := createTestConfig()
			cfg.WALSyncMode = tc.syncMode

			wal, err := NewWAL(cfg, dir)
			if err != nil {
				t.Fatalf("Failed to create WAL: %v", err)
			}

			for n := 0; n < 10; n++ {
				k := []byte(fmt.Sprintf("key%d", n))
				v := []byte(fmt.Sprintf("value%d", n))
				if _, appendErr := wal.Append(OpTypePut, k, v); appendErr != nil {
					t.Fatalf("Failed to append entry: %v", appendErr)
				}
			}

			if closeErr := wal.Close(); closeErr != nil {
				t.Fatalf("Failed to close WAL: %v", closeErr)
			}

			// Count the puts seen on replay.
			puts := 0
			replayErr := ReplayWALDir(dir, func(entry *Entry) error {
				if entry.Type == OpTypePut {
					puts++
				}
				return nil
			})
			if replayErr != nil {
				t.Fatalf("Failed to replay WAL: %v", replayErr)
			}

			if puts != 10 {
				t.Errorf("Expected 10 entries, got %d", puts)
			}
		})
	}
}
// TestWALFragmentation writes an entry whose key does not fit in the first
// fragment and whose value spans several records, then checks that replay
// reassembles key and value exactly.
//
// Length mismatches abort the test before any indexed comparison, so a broken
// reassembly is reported as a failure instead of an index-out-of-range panic.
// The spot checks give precise positions for a mismatch; the final full
// comparison guarantees that no differing byte goes unnoticed.
func TestWALFragmentation(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	cfg := createTestConfig()
	wal, err := NewWAL(cfg, dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// Create an entry that's guaranteed to be fragmented.
	// The entry prefix is 1 + 8 + 4 = 13 bytes, so a key longer than
	// MaxRecordSize - 13 spills over into the next fragment.
	keySize := MaxRecordSize - 10
	valueSize := MaxRecordSize * 2

	key := make([]byte, keySize)     // just under MaxRecordSize to force key fragmentation
	value := make([]byte, valueSize) // large value to force value fragmentation

	// Fill with recognizable patterns
	for i := range key {
		key[i] = byte(i % 256)
	}
	for i := range value {
		value[i] = byte((i * 3) % 256)
	}

	// Append the large entry - this should trigger fragmentation
	_, err = wal.Append(OpTypePut, key, value)
	if err != nil {
		t.Fatalf("Failed to append fragmented entry: %v", err)
	}

	// Close the WAL
	if err := wal.Close(); err != nil {
		t.Fatalf("Failed to close WAL: %v", err)
	}

	// Verify by replaying
	var reconstructedKey []byte
	var reconstructedValue []byte
	var foundPut bool

	err = ReplayWALDir(dir, func(entry *Entry) error {
		if entry.Type == OpTypePut {
			foundPut = true
			reconstructedKey = entry.Key
			reconstructedValue = entry.Value
		}
		return nil
	})
	if err != nil {
		t.Fatalf("Failed to replay WAL: %v", err)
	}

	// Check that we found the entry
	if !foundPut {
		t.Fatal("Did not find PUT entry in replay")
	}

	// Verify lengths; report both before giving up.
	if len(reconstructedKey) != keySize {
		t.Errorf("Key length mismatch: expected %d, got %d", keySize, len(reconstructedKey))
	}
	if len(reconstructedValue) != valueSize {
		t.Errorf("Value length mismatch: expected %d, got %d", valueSize, len(reconstructedValue))
	}
	if t.Failed() {
		// The indexed checks below assume matching lengths.
		t.FailNow()
	}

	// Check key content (first 10 bytes)
	for i := 0; i < 10 && i < len(key); i++ {
		if key[i] != reconstructedKey[i] {
			t.Errorf("Key mismatch at position %d: expected %d, got %d", i, key[i], reconstructedKey[i])
		}
	}

	// Check key content (last 10 bytes)
	for i := 0; i < 10 && i < len(key); i++ {
		idx := len(key) - 1 - i
		if key[idx] != reconstructedKey[idx] {
			t.Errorf("Key mismatch at position %d: expected %d, got %d", idx, key[idx], reconstructedKey[idx])
		}
	}

	// Check value content (first 10 bytes)
	for i := 0; i < 10 && i < len(value); i++ {
		if value[i] != reconstructedValue[i] {
			t.Errorf("Value mismatch at position %d: expected %d, got %d", i, value[i], reconstructedValue[i])
		}
	}

	// Check value content (last 10 bytes)
	for i := 0; i < 10 && i < len(value); i++ {
		idx := len(value) - 1 - i
		if value[idx] != reconstructedValue[idx] {
			t.Errorf("Value mismatch at position %d: expected %d, got %d", idx, value[idx], reconstructedValue[idx])
		}
	}

	// Verify random samples from the key and value
	for i := 0; i < 10; i++ {
		// Check random positions in the key
		keyPos := rand.Intn(keySize)
		if key[keyPos] != reconstructedKey[keyPos] {
			t.Errorf("Key mismatch at random position %d: expected %d, got %d", keyPos, key[keyPos], reconstructedKey[keyPos])
		}

		// Check random positions in the value
		valuePos := rand.Intn(valueSize)
		if value[valuePos] != reconstructedValue[valuePos] {
			t.Errorf("Value mismatch at random position %d: expected %d, got %d", valuePos, value[valuePos], reconstructedValue[valuePos])
		}
	}

	// Deterministic full comparison: sampling alone can miss a corrupt byte,
	// for example at a fragment boundary.
	if !bytes.Equal(key, reconstructedKey) {
		t.Errorf("Reconstructed key does not match the original key")
	}
	if !bytes.Equal(value, reconstructedValue) {
		t.Errorf("Reconstructed value does not match the original value")
	}
}
// TestWALErrorHandling checks the failure paths: Append and Sync on a closed
// WAL must return ErrWALClosed, and replaying a file that does not exist must
// fail.
func TestWALErrorHandling(t *testing.T) {
	dir := createTempDir(t)
	defer os.RemoveAll(dir)

	wal, err := NewWAL(createTestConfig(), dir)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}

	// One successful write, then close.
	if _, appendErr := wal.Append(OpTypePut, []byte("key1"), []byte("value1")); appendErr != nil {
		t.Fatalf("Failed to append entry: %v", appendErr)
	}
	if closeErr := wal.Close(); closeErr != nil {
		t.Fatalf("Failed to close WAL: %v", closeErr)
	}

	// Writing after close must be rejected.
	if _, lateErr := wal.Append(OpTypePut, []byte("key2"), []byte("value2")); lateErr != ErrWALClosed {
		t.Errorf("Expected ErrWALClosed, got: %v", lateErr)
	}

	// Syncing after close must be rejected too.
	if syncErr := wal.Sync(); syncErr != ErrWALClosed {
		t.Errorf("Expected ErrWALClosed, got: %v", syncErr)
	}

	// Replaying a missing file must report an error.
	missing := filepath.Join(dir, "nonexistent.wal")
	noop := func(entry *Entry) error { return nil }
	if replayErr := ReplayWALFile(missing, noop); replayErr == nil {
		t.Error("Expected error when replaying non-existent file")
	}
}