feat: implement config and WAL packages
- Implement config package with serializable configuration and manifest - Implement WAL with durability guarantees and fragmentation support - Fix mutex deadlock bug in WAL sync operations - Add comprehensive tests for both packages
This commit is contained in:
parent
ee23a47a74
commit
b03176f136
32
CLAUDE.md
Normal file
32
CLAUDE.md
Normal file
@ -0,0 +1,32 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build Commands
|
||||
- Build: `go build ./...`
|
||||
- Run tests: `go test ./...`
|
||||
- Run single test: `go test ./pkg/path/to/package -run TestName`
|
||||
- Benchmark: `go test ./pkg/path/to/package -bench .`
|
||||
- Race detector: `go test -race ./...`
|
||||
|
||||
## Linting/Formatting
|
||||
- Format code: `go fmt ./...`
|
||||
- Static analysis: `go vet ./...`
|
||||
- Install golangci-lint: `go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest`
|
||||
- Run linter: `golangci-lint run`
|
||||
|
||||
## Code Style Guidelines
|
||||
- Follow Go standard project layout in pkg/ and internal/ directories
|
||||
- Use descriptive error types with context wrapping
|
||||
- Implement single-writer architecture for write paths
|
||||
- Allow concurrent reads via snapshots
|
||||
- Use interfaces for component boundaries
|
||||
- Follow idiomatic Go practices
|
||||
- Add appropriate validation, especially for checksums
|
||||
- All exported functions must have documentation comments
|
||||
- For transaction management, use WAL for durability/atomicity
|
||||
|
||||
## Version Control
|
||||
- Use git for version control
|
||||
- All commit messages must use semantic commit messages
|
||||
- All commit messages must not reference code being generated or co-authored by Claude
|
199
pkg/config/config.go
Normal file
199
pkg/config/config.go
Normal file
@ -0,0 +1,199 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultManifestFileName = "MANIFEST"
|
||||
CurrentManifestVersion = 1
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidConfig = errors.New("invalid configuration")
|
||||
ErrManifestNotFound = errors.New("manifest not found")
|
||||
ErrInvalidManifest = errors.New("invalid manifest")
|
||||
)
|
||||
|
||||
type SyncMode int
|
||||
|
||||
const (
|
||||
SyncNone SyncMode = iota
|
||||
SyncBatch
|
||||
SyncImmediate
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Version int `json:"version"`
|
||||
|
||||
// WAL configuration
|
||||
WALDir string `json:"wal_dir"`
|
||||
WALSyncMode SyncMode `json:"wal_sync_mode"`
|
||||
WALSyncBytes int64 `json:"wal_sync_bytes"`
|
||||
|
||||
// MemTable configuration
|
||||
MemTableSize int64 `json:"memtable_size"`
|
||||
MaxMemTables int `json:"max_memtables"`
|
||||
MaxMemTableAge int64 `json:"max_memtable_age"`
|
||||
MemTablePoolCap int `json:"memtable_pool_cap"`
|
||||
|
||||
// SSTable configuration
|
||||
SSTDir string `json:"sst_dir"`
|
||||
SSTableBlockSize int `json:"sstable_block_size"`
|
||||
SSTableIndexSize int `json:"sstable_index_size"`
|
||||
SSTableMaxSize int64 `json:"sstable_max_size"`
|
||||
SSTableRestartSize int `json:"sstable_restart_size"`
|
||||
|
||||
// Compaction configuration
|
||||
CompactionLevels int `json:"compaction_levels"`
|
||||
CompactionRatio float64 `json:"compaction_ratio"`
|
||||
CompactionThreads int `json:"compaction_threads"`
|
||||
CompactionInterval int64 `json:"compaction_interval"`
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDefaultConfig creates a Config with recommended default values
|
||||
func NewDefaultConfig(dbPath string) *Config {
|
||||
walDir := filepath.Join(dbPath, "wal")
|
||||
sstDir := filepath.Join(dbPath, "sst")
|
||||
|
||||
return &Config{
|
||||
Version: CurrentManifestVersion,
|
||||
|
||||
// WAL defaults
|
||||
WALDir: walDir,
|
||||
WALSyncMode: SyncBatch,
|
||||
WALSyncBytes: 1024 * 1024, // 1MB
|
||||
|
||||
// MemTable defaults
|
||||
MemTableSize: 32 * 1024 * 1024, // 32MB
|
||||
MaxMemTables: 4,
|
||||
MaxMemTableAge: 600, // 10 minutes
|
||||
MemTablePoolCap: 4,
|
||||
|
||||
// SSTable defaults
|
||||
SSTDir: sstDir,
|
||||
SSTableBlockSize: 16 * 1024, // 16KB
|
||||
SSTableIndexSize: 64 * 1024, // 64KB
|
||||
SSTableMaxSize: 64 * 1024 * 1024, // 64MB
|
||||
SSTableRestartSize: 16, // Restart points every 16 keys
|
||||
|
||||
// Compaction defaults
|
||||
CompactionLevels: 7,
|
||||
CompactionRatio: 10,
|
||||
CompactionThreads: 2,
|
||||
CompactionInterval: 30, // 30 seconds
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *Config) Validate() error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.Version <= 0 {
|
||||
return fmt.Errorf("%w: invalid version %d", ErrInvalidConfig, c.Version)
|
||||
}
|
||||
|
||||
if c.WALDir == "" {
|
||||
return fmt.Errorf("%w: WAL directory not specified", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.SSTDir == "" {
|
||||
return fmt.Errorf("%w: SSTable directory not specified", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.MemTableSize <= 0 {
|
||||
return fmt.Errorf("%w: MemTable size must be positive", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.MaxMemTables <= 0 {
|
||||
return fmt.Errorf("%w: Max MemTables must be positive", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.SSTableBlockSize <= 0 {
|
||||
return fmt.Errorf("%w: SSTable block size must be positive", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.SSTableIndexSize <= 0 {
|
||||
return fmt.Errorf("%w: SSTable index size must be positive", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.CompactionLevels <= 0 {
|
||||
return fmt.Errorf("%w: Compaction levels must be positive", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if c.CompactionRatio <= 1.0 {
|
||||
return fmt.Errorf("%w: Compaction ratio must be greater than 1.0", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfigFromManifest loads just the configuration portion from the manifest file
|
||||
func LoadConfigFromManifest(dbPath string) (*Config, error) {
|
||||
manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, ErrManifestNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SaveManifest saves the configuration to the manifest file
|
||||
func (c *Config) SaveManifest(dbPath string) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if err := c.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
|
||||
tempPath := manifestPath + ".tmp"
|
||||
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tempPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, manifestPath); err != nil {
|
||||
return fmt.Errorf("failed to rename manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update applies the given function to modify the configuration
|
||||
func (c *Config) Update(fn func(*Config)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
fn(c)
|
||||
}
|
167
pkg/config/config_test.go
Normal file
167
pkg/config/config_test.go
Normal file
@ -0,0 +1,167 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewDefaultConfig(t *testing.T) {
|
||||
dbPath := "/tmp/testdb"
|
||||
cfg := NewDefaultConfig(dbPath)
|
||||
|
||||
if cfg.Version != CurrentManifestVersion {
|
||||
t.Errorf("expected version %d, got %d", CurrentManifestVersion, cfg.Version)
|
||||
}
|
||||
|
||||
if cfg.WALDir != filepath.Join(dbPath, "wal") {
|
||||
t.Errorf("expected WAL dir %s, got %s", filepath.Join(dbPath, "wal"), cfg.WALDir)
|
||||
}
|
||||
|
||||
if cfg.SSTDir != filepath.Join(dbPath, "sst") {
|
||||
t.Errorf("expected SST dir %s, got %s", filepath.Join(dbPath, "sst"), cfg.SSTDir)
|
||||
}
|
||||
|
||||
// Test default values
|
||||
if cfg.WALSyncMode != SyncBatch {
|
||||
t.Errorf("expected WAL sync mode %d, got %d", SyncBatch, cfg.WALSyncMode)
|
||||
}
|
||||
|
||||
if cfg.MemTableSize != 32*1024*1024 {
|
||||
t.Errorf("expected memtable size %d, got %d", 32*1024*1024, cfg.MemTableSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
cfg := NewDefaultConfig("/tmp/testdb")
|
||||
|
||||
// Valid config
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("expected valid config, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid configs
|
||||
testCases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "invalid version",
|
||||
mutate: func(c *Config) {
|
||||
c.Version = 0
|
||||
},
|
||||
expected: "invalid configuration: invalid version 0",
|
||||
},
|
||||
{
|
||||
name: "empty WAL dir",
|
||||
mutate: func(c *Config) {
|
||||
c.WALDir = ""
|
||||
},
|
||||
expected: "invalid configuration: WAL directory not specified",
|
||||
},
|
||||
{
|
||||
name: "empty SST dir",
|
||||
mutate: func(c *Config) {
|
||||
c.SSTDir = ""
|
||||
},
|
||||
expected: "invalid configuration: SSTable directory not specified",
|
||||
},
|
||||
{
|
||||
name: "zero memtable size",
|
||||
mutate: func(c *Config) {
|
||||
c.MemTableSize = 0
|
||||
},
|
||||
expected: "invalid configuration: MemTable size must be positive",
|
||||
},
|
||||
{
|
||||
name: "negative max memtables",
|
||||
mutate: func(c *Config) {
|
||||
c.MaxMemTables = -1
|
||||
},
|
||||
expected: "invalid configuration: Max MemTables must be positive",
|
||||
},
|
||||
{
|
||||
name: "zero block size",
|
||||
mutate: func(c *Config) {
|
||||
c.SSTableBlockSize = 0
|
||||
},
|
||||
expected: "invalid configuration: SSTable block size must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := NewDefaultConfig("/tmp/testdb")
|
||||
tc.mutate(cfg)
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if err.Error() != tc.expected {
|
||||
t.Errorf("expected error %q, got %q", tc.expected, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigManifestSaveLoad(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir, err := os.MkdirTemp("", "config_test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a config and save it
|
||||
cfg := NewDefaultConfig(tempDir)
|
||||
cfg.MemTableSize = 16 * 1024 * 1024 // 16MB
|
||||
cfg.CompactionThreads = 4
|
||||
|
||||
if err := cfg.SaveManifest(tempDir); err != nil {
|
||||
t.Fatalf("failed to save manifest: %v", err)
|
||||
}
|
||||
|
||||
// Load the config
|
||||
loadedCfg, err := LoadConfigFromManifest(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load manifest: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded config
|
||||
if loadedCfg.MemTableSize != cfg.MemTableSize {
|
||||
t.Errorf("expected memtable size %d, got %d", cfg.MemTableSize, loadedCfg.MemTableSize)
|
||||
}
|
||||
|
||||
if loadedCfg.CompactionThreads != cfg.CompactionThreads {
|
||||
t.Errorf("expected compaction threads %d, got %d", cfg.CompactionThreads, loadedCfg.CompactionThreads)
|
||||
}
|
||||
|
||||
// Test loading non-existent manifest
|
||||
nonExistentDir := filepath.Join(tempDir, "nonexistent")
|
||||
_, err = LoadConfigFromManifest(nonExistentDir)
|
||||
if err != ErrManifestNotFound {
|
||||
t.Errorf("expected ErrManifestNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigUpdate(t *testing.T) {
|
||||
cfg := NewDefaultConfig("/tmp/testdb")
|
||||
|
||||
// Update config
|
||||
cfg.Update(func(c *Config) {
|
||||
c.MemTableSize = 64 * 1024 * 1024 // 64MB
|
||||
c.MaxMemTables = 8
|
||||
})
|
||||
|
||||
// Verify update
|
||||
if cfg.MemTableSize != 64*1024*1024 {
|
||||
t.Errorf("expected memtable size %d, got %d", 64*1024*1024, cfg.MemTableSize)
|
||||
}
|
||||
|
||||
if cfg.MaxMemTables != 8 {
|
||||
t.Errorf("expected max memtables %d, got %d", 8, cfg.MaxMemTables)
|
||||
}
|
||||
}
|
214
pkg/config/manifest.go
Normal file
214
pkg/config/manifest.go
Normal file
@ -0,0 +1,214 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ManifestEntry struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Version int `json:"version"`
|
||||
Config *Config `json:"config"`
|
||||
FileSystem map[string]int64 `json:"filesystem,omitempty"` // Map of file paths to sequence numbers
|
||||
}
|
||||
|
||||
type Manifest struct {
|
||||
DBPath string
|
||||
Entries []ManifestEntry
|
||||
Current *ManifestEntry
|
||||
LastUpdate time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManifest creates a new manifest for the given database path
|
||||
func NewManifest(dbPath string, config *Config) (*Manifest, error) {
|
||||
if config == nil {
|
||||
config = NewDefaultConfig(dbPath)
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry := ManifestEntry{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Version: CurrentManifestVersion,
|
||||
Config: config,
|
||||
}
|
||||
|
||||
m := &Manifest{
|
||||
DBPath: dbPath,
|
||||
Entries: []ManifestEntry{entry},
|
||||
Current: &entry,
|
||||
LastUpdate: time.Now(),
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadManifest loads an existing manifest from the database directory
|
||||
func LoadManifest(dbPath string) (*Manifest, error) {
|
||||
manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
|
||||
file, err := os.Open(manifestPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, ErrManifestNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to open manifest: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
var entries []ManifestEntry
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
return nil, fmt.Errorf("%w: no entries in manifest", ErrInvalidManifest)
|
||||
}
|
||||
|
||||
current := &entries[len(entries)-1]
|
||||
if err := current.Config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manifest{
|
||||
DBPath: dbPath,
|
||||
Entries: entries,
|
||||
Current: current,
|
||||
LastUpdate: time.Now(),
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Save persists the manifest to disk
|
||||
func (m *Manifest) Save() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if err := m.Current.Config.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(m.DBPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(m.DBPath, DefaultManifestFileName)
|
||||
tempPath := manifestPath + ".tmp"
|
||||
|
||||
data, err := json.MarshalIndent(m.Entries, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tempPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tempPath, manifestPath); err != nil {
|
||||
return fmt.Errorf("failed to rename manifest: %w", err)
|
||||
}
|
||||
|
||||
m.LastUpdate = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConfig creates a new configuration entry
|
||||
func (m *Manifest) UpdateConfig(fn func(*Config)) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Create a copy of the current config
|
||||
currentJSON, err := json.Marshal(m.Current.Config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal current config: %w", err)
|
||||
}
|
||||
|
||||
var newConfig Config
|
||||
if err := json.Unmarshal(currentJSON, &newConfig); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Apply the update function
|
||||
fn(&newConfig)
|
||||
|
||||
// Validate the new config
|
||||
if err := newConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a new entry
|
||||
entry := ManifestEntry{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Version: CurrentManifestVersion,
|
||||
Config: &newConfig,
|
||||
}
|
||||
|
||||
m.Entries = append(m.Entries, entry)
|
||||
m.Current = &m.Entries[len(m.Entries)-1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddFile registers a file in the manifest
|
||||
func (m *Manifest) AddFile(path string, seqNum int64) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.Current.FileSystem == nil {
|
||||
m.Current.FileSystem = make(map[string]int64)
|
||||
}
|
||||
|
||||
m.Current.FileSystem[path] = seqNum
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveFile removes a file from the manifest
|
||||
func (m *Manifest) RemoveFile(path string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.Current.FileSystem == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(m.Current.FileSystem, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig returns the current configuration
|
||||
func (m *Manifest) GetConfig() *Config {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.Current.Config
|
||||
}
|
||||
|
||||
// GetFiles returns all files registered in the manifest
|
||||
func (m *Manifest) GetFiles() map[string]int64 {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.Current.FileSystem == nil {
|
||||
return make(map[string]int64)
|
||||
}
|
||||
|
||||
// Return a copy to prevent concurrent map access
|
||||
files := make(map[string]int64, len(m.Current.FileSystem))
|
||||
for k, v := range m.Current.FileSystem {
|
||||
files[k] = v
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
176
pkg/config/manifest_test.go
Normal file
176
pkg/config/manifest_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManifest(t *testing.T) {
|
||||
dbPath := "/tmp/testdb"
|
||||
cfg := NewDefaultConfig(dbPath)
|
||||
|
||||
manifest, err := NewManifest(dbPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manifest: %v", err)
|
||||
}
|
||||
|
||||
if manifest.DBPath != dbPath {
|
||||
t.Errorf("expected DBPath %s, got %s", dbPath, manifest.DBPath)
|
||||
}
|
||||
|
||||
if len(manifest.Entries) != 1 {
|
||||
t.Errorf("expected 1 entry, got %d", len(manifest.Entries))
|
||||
}
|
||||
|
||||
if manifest.Current == nil {
|
||||
t.Error("current entry is nil")
|
||||
} else if manifest.Current.Config != cfg {
|
||||
t.Error("current config does not match the provided config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestUpdateConfig(t *testing.T) {
|
||||
dbPath := "/tmp/testdb"
|
||||
cfg := NewDefaultConfig(dbPath)
|
||||
|
||||
manifest, err := NewManifest(dbPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manifest: %v", err)
|
||||
}
|
||||
|
||||
// Update config
|
||||
err = manifest.UpdateConfig(func(c *Config) {
|
||||
c.MemTableSize = 64 * 1024 * 1024 // 64MB
|
||||
c.MaxMemTables = 8
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update config: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries count
|
||||
if len(manifest.Entries) != 2 {
|
||||
t.Errorf("expected 2 entries, got %d", len(manifest.Entries))
|
||||
}
|
||||
|
||||
// Verify updated config
|
||||
current := manifest.GetConfig()
|
||||
if current.MemTableSize != 64*1024*1024 {
|
||||
t.Errorf("expected memtable size %d, got %d", 64*1024*1024, current.MemTableSize)
|
||||
}
|
||||
if current.MaxMemTables != 8 {
|
||||
t.Errorf("expected max memtables %d, got %d", 8, current.MaxMemTables)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestFileTracking(t *testing.T) {
|
||||
dbPath := "/tmp/testdb"
|
||||
cfg := NewDefaultConfig(dbPath)
|
||||
|
||||
manifest, err := NewManifest(dbPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manifest: %v", err)
|
||||
}
|
||||
|
||||
// Add files
|
||||
err = manifest.AddFile("sst/000001.sst", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add file: %v", err)
|
||||
}
|
||||
|
||||
err = manifest.AddFile("sst/000002.sst", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add file: %v", err)
|
||||
}
|
||||
|
||||
// Verify files
|
||||
files := manifest.GetFiles()
|
||||
if len(files) != 2 {
|
||||
t.Errorf("expected 2 files, got %d", len(files))
|
||||
}
|
||||
|
||||
if files["sst/000001.sst"] != 1 {
|
||||
t.Errorf("expected sequence number 1, got %d", files["sst/000001.sst"])
|
||||
}
|
||||
|
||||
if files["sst/000002.sst"] != 2 {
|
||||
t.Errorf("expected sequence number 2, got %d", files["sst/000002.sst"])
|
||||
}
|
||||
|
||||
// Remove file
|
||||
err = manifest.RemoveFile("sst/000001.sst")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove file: %v", err)
|
||||
}
|
||||
|
||||
// Verify files after removal
|
||||
files = manifest.GetFiles()
|
||||
if len(files) != 1 {
|
||||
t.Errorf("expected 1 file, got %d", len(files))
|
||||
}
|
||||
|
||||
if _, exists := files["sst/000001.sst"]; exists {
|
||||
t.Error("file should have been removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestSaveLoad(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir, err := os.MkdirTemp("", "manifest_test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a manifest
|
||||
cfg := NewDefaultConfig(tempDir)
|
||||
manifest, err := NewManifest(tempDir, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manifest: %v", err)
|
||||
}
|
||||
|
||||
// Update config
|
||||
err = manifest.UpdateConfig(func(c *Config) {
|
||||
c.MemTableSize = 64 * 1024 * 1024 // 64MB
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update config: %v", err)
|
||||
}
|
||||
|
||||
// Add some files
|
||||
err = manifest.AddFile("sst/000001.sst", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add file: %v", err)
|
||||
}
|
||||
|
||||
// Save the manifest
|
||||
if err := manifest.Save(); err != nil {
|
||||
t.Fatalf("failed to save manifest: %v", err)
|
||||
}
|
||||
|
||||
// Load the manifest
|
||||
loadedManifest, err := LoadManifest(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load manifest: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries count
|
||||
if len(loadedManifest.Entries) != len(manifest.Entries) {
|
||||
t.Errorf("expected %d entries, got %d", len(manifest.Entries), len(loadedManifest.Entries))
|
||||
}
|
||||
|
||||
// Verify config
|
||||
loadedConfig := loadedManifest.GetConfig()
|
||||
if loadedConfig.MemTableSize != 64*1024*1024 {
|
||||
t.Errorf("expected memtable size %d, got %d", 64*1024*1024, loadedConfig.MemTableSize)
|
||||
}
|
||||
|
||||
// Verify files
|
||||
loadedFiles := loadedManifest.GetFiles()
|
||||
if len(loadedFiles) != 1 {
|
||||
t.Errorf("expected 1 file, got %d", len(loadedFiles))
|
||||
}
|
||||
|
||||
if loadedFiles["sst/000001.sst"] != 1 {
|
||||
t.Errorf("expected sequence number 1, got %d", loadedFiles["sst/000001.sst"])
|
||||
}
|
||||
}
|
244
pkg/wal/batch.go
Normal file
244
pkg/wal/batch.go
Normal file
@ -0,0 +1,244 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
BatchHeaderSize = 12 // count(4) + seq(8)
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyBatch = errors.New("batch is empty")
|
||||
ErrBatchTooLarge = errors.New("batch too large")
|
||||
)
|
||||
|
||||
// BatchOperation represents a single operation in a batch
|
||||
type BatchOperation struct {
|
||||
Type uint8 // OpTypePut, OpTypeDelete, etc.
|
||||
Key []byte
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// Batch represents a collection of operations to be performed atomically
|
||||
type Batch struct {
|
||||
Operations []BatchOperation
|
||||
Seq uint64 // Base sequence number
|
||||
}
|
||||
|
||||
// NewBatch creates a new empty batch
|
||||
func NewBatch() *Batch {
|
||||
return &Batch{
|
||||
Operations: make([]BatchOperation, 0, 16),
|
||||
}
|
||||
}
|
||||
|
||||
// Put adds a Put operation to the batch
|
||||
func (b *Batch) Put(key, value []byte) {
|
||||
b.Operations = append(b.Operations, BatchOperation{
|
||||
Type: OpTypePut,
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete adds a Delete operation to the batch
|
||||
func (b *Batch) Delete(key []byte) {
|
||||
b.Operations = append(b.Operations, BatchOperation{
|
||||
Type: OpTypeDelete,
|
||||
Key: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Count returns the number of operations in the batch
|
||||
func (b *Batch) Count() int {
|
||||
return len(b.Operations)
|
||||
}
|
||||
|
||||
// Reset clears all operations from the batch
|
||||
func (b *Batch) Reset() {
|
||||
b.Operations = b.Operations[:0]
|
||||
b.Seq = 0
|
||||
}
|
||||
|
||||
// Size estimates the size of the batch in the WAL
|
||||
func (b *Batch) Size() int {
|
||||
size := BatchHeaderSize // count + seq
|
||||
|
||||
for _, op := range b.Operations {
|
||||
// Type(1) + KeyLen(4) + Key
|
||||
size += 1 + 4 + len(op.Key)
|
||||
|
||||
// ValueLen(4) + Value for Put operations
|
||||
if op.Type != OpTypeDelete {
|
||||
size += 4 + len(op.Value)
|
||||
}
|
||||
}
|
||||
|
||||
return size
|
||||
}
|
||||
|
||||
// Write writes the batch to the WAL
|
||||
func (b *Batch) Write(w *WAL) error {
|
||||
if len(b.Operations) == 0 {
|
||||
return ErrEmptyBatch
|
||||
}
|
||||
|
||||
// Estimate batch size
|
||||
size := b.Size()
|
||||
if size > MaxRecordSize {
|
||||
return fmt.Errorf("%w: %d > %d", ErrBatchTooLarge, size, MaxRecordSize)
|
||||
}
|
||||
|
||||
// Serialize batch
|
||||
data := make([]byte, size)
|
||||
offset := 0
|
||||
|
||||
// Write count
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(b.Operations)))
|
||||
offset += 4
|
||||
|
||||
// Write sequence base (will be set by WAL.AppendBatch)
|
||||
offset += 8
|
||||
|
||||
// Write operations
|
||||
for _, op := range b.Operations {
|
||||
// Write type
|
||||
data[offset] = op.Type
|
||||
offset++
|
||||
|
||||
// Write key length
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(op.Key)))
|
||||
offset += 4
|
||||
|
||||
// Write key
|
||||
copy(data[offset:], op.Key)
|
||||
offset += len(op.Key)
|
||||
|
||||
// Write value for non-delete operations
|
||||
if op.Type != OpTypeDelete {
|
||||
// Write value length
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(op.Value)))
|
||||
offset += 4
|
||||
|
||||
// Write value
|
||||
copy(data[offset:], op.Value)
|
||||
offset += len(op.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Append to WAL
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return ErrWALClosed
|
||||
}
|
||||
|
||||
// Set the sequence number
|
||||
b.Seq = w.nextSequence
|
||||
binary.LittleEndian.PutUint64(data[4:12], b.Seq)
|
||||
|
||||
// Increment sequence for future operations
|
||||
w.nextSequence += uint64(len(b.Operations))
|
||||
|
||||
// Write as a batch entry
|
||||
if err := w.writeRecord(uint8(RecordTypeFull), OpTypeBatch, b.Seq, data, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sync if needed
|
||||
return w.maybeSync()
|
||||
}
|
||||
|
||||
// DecodeBatch decodes a batch entry from a WAL record
|
||||
func DecodeBatch(entry *Entry) (*Batch, error) {
|
||||
if entry.Type != OpTypeBatch {
|
||||
return nil, fmt.Errorf("not a batch entry: type %d", entry.Type)
|
||||
}
|
||||
|
||||
// For batch entries, the batch data is in the Key field, not Value
|
||||
data := entry.Key
|
||||
if len(data) < BatchHeaderSize {
|
||||
return nil, fmt.Errorf("%w: batch header too small", ErrCorruptRecord)
|
||||
}
|
||||
|
||||
// Read count and sequence
|
||||
count := binary.LittleEndian.Uint32(data[0:4])
|
||||
seq := binary.LittleEndian.Uint64(data[4:12])
|
||||
|
||||
batch := &Batch{
|
||||
Operations: make([]BatchOperation, 0, count),
|
||||
Seq: seq,
|
||||
}
|
||||
|
||||
offset := BatchHeaderSize
|
||||
|
||||
// Read operations
|
||||
for i := uint32(0); i < count; i++ {
|
||||
// Check if we have enough data for type
|
||||
if offset >= len(data) {
|
||||
return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
|
||||
}
|
||||
|
||||
// Read type
|
||||
opType := data[offset]
|
||||
offset++
|
||||
|
||||
// Validate operation type
|
||||
if opType != OpTypePut && opType != OpTypeDelete && opType != OpTypeMerge {
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, opType)
|
||||
}
|
||||
|
||||
// Check if we have enough data for key length
|
||||
if offset+4 > len(data) {
|
||||
return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
|
||||
}
|
||||
|
||||
// Read key length
|
||||
keyLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate key length
|
||||
if offset+int(keyLen) > len(data) {
|
||||
return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen)
|
||||
}
|
||||
|
||||
// Read key
|
||||
key := make([]byte, keyLen)
|
||||
copy(key, data[offset:offset+int(keyLen)])
|
||||
offset += int(keyLen)
|
||||
|
||||
var value []byte
|
||||
if opType != OpTypeDelete {
|
||||
// Check if we have enough data for value length
|
||||
if offset+4 > len(data) {
|
||||
return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
|
||||
}
|
||||
|
||||
// Read value length
|
||||
valueLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate value length
|
||||
if offset+int(valueLen) > len(data) {
|
||||
return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen)
|
||||
}
|
||||
|
||||
// Read value
|
||||
value = make([]byte, valueLen)
|
||||
copy(value, data[offset:offset+int(valueLen)])
|
||||
offset += int(valueLen)
|
||||
}
|
||||
|
||||
batch.Operations = append(batch.Operations, BatchOperation{
|
||||
Type: opType,
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return batch, nil
|
||||
}
|
187
pkg/wal/batch_test.go
Normal file
187
pkg/wal/batch_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBatchOperations(t *testing.T) {
|
||||
batch := NewBatch()
|
||||
|
||||
// Test initially empty
|
||||
if batch.Count() != 0 {
|
||||
t.Errorf("Expected empty batch, got count %d", batch.Count())
|
||||
}
|
||||
|
||||
// Add operations
|
||||
batch.Put([]byte("key1"), []byte("value1"))
|
||||
batch.Put([]byte("key2"), []byte("value2"))
|
||||
batch.Delete([]byte("key3"))
|
||||
|
||||
// Check count
|
||||
if batch.Count() != 3 {
|
||||
t.Errorf("Expected batch with 3 operations, got %d", batch.Count())
|
||||
}
|
||||
|
||||
// Check size calculation
|
||||
expectedSize := BatchHeaderSize // count + seq
|
||||
expectedSize += 1 + 4 + 4 + len("key1") + len("value1") // type + keylen + vallen + key + value
|
||||
expectedSize += 1 + 4 + 4 + len("key2") + len("value2") // type + keylen + vallen + key + value
|
||||
expectedSize += 1 + 4 + len("key3") // type + keylen + key (no value for delete)
|
||||
|
||||
if batch.Size() != expectedSize {
|
||||
t.Errorf("Expected batch size %d, got %d", expectedSize, batch.Size())
|
||||
}
|
||||
|
||||
// Test reset
|
||||
batch.Reset()
|
||||
if batch.Count() != 0 {
|
||||
t.Errorf("Expected empty batch after reset, got count %d", batch.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchEncoding(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create and write a batch
|
||||
batch := NewBatch()
|
||||
batch.Put([]byte("key1"), []byte("value1"))
|
||||
batch.Put([]byte("key2"), []byte("value2"))
|
||||
batch.Delete([]byte("key3"))
|
||||
|
||||
if err := batch.Write(wal); err != nil {
|
||||
t.Fatalf("Failed to write batch: %v", err)
|
||||
}
|
||||
|
||||
// Check sequence
|
||||
if batch.Seq == 0 {
|
||||
t.Errorf("Batch sequence number not set")
|
||||
}
|
||||
|
||||
// Close WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Replay and decode
|
||||
var decodedBatch *Batch
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypeBatch {
|
||||
var err error
|
||||
decodedBatch, err = DecodeBatch(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
if decodedBatch == nil {
|
||||
t.Fatal("No batch found in replay")
|
||||
}
|
||||
|
||||
// Verify decoded batch
|
||||
if decodedBatch.Count() != 3 {
|
||||
t.Errorf("Expected 3 operations, got %d", decodedBatch.Count())
|
||||
}
|
||||
|
||||
if decodedBatch.Seq != batch.Seq {
|
||||
t.Errorf("Expected sequence %d, got %d", batch.Seq, decodedBatch.Seq)
|
||||
}
|
||||
|
||||
// Verify operations
|
||||
ops := decodedBatch.Operations
|
||||
|
||||
if ops[0].Type != OpTypePut || !bytes.Equal(ops[0].Key, []byte("key1")) || !bytes.Equal(ops[0].Value, []byte("value1")) {
|
||||
t.Errorf("First operation mismatch")
|
||||
}
|
||||
|
||||
if ops[1].Type != OpTypePut || !bytes.Equal(ops[1].Key, []byte("key2")) || !bytes.Equal(ops[1].Value, []byte("value2")) {
|
||||
t.Errorf("Second operation mismatch")
|
||||
}
|
||||
|
||||
if ops[2].Type != OpTypeDelete || !bytes.Equal(ops[2].Key, []byte("key3")) {
|
||||
t.Errorf("Third operation mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyBatch(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create empty batch
|
||||
batch := NewBatch()
|
||||
|
||||
// Try to write empty batch
|
||||
err = batch.Write(wal)
|
||||
if err != ErrEmptyBatch {
|
||||
t.Errorf("Expected ErrEmptyBatch, got: %v", err)
|
||||
}
|
||||
|
||||
// Close WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargeBatch(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create a batch that will exceed the maximum record size
|
||||
batch := NewBatch()
|
||||
|
||||
// Add many large key-value pairs
|
||||
largeValue := make([]byte, 4096) // 4KB
|
||||
for i := 0; i < 20; i++ {
|
||||
key := []byte(fmt.Sprintf("key%d", i))
|
||||
batch.Put(key, largeValue)
|
||||
}
|
||||
|
||||
// Verify the batch is too large
|
||||
if batch.Size() <= MaxRecordSize {
|
||||
t.Fatalf("Expected batch size > %d, got %d", MaxRecordSize, batch.Size())
|
||||
}
|
||||
|
||||
// Try to write the large batch
|
||||
err = batch.Write(wal)
|
||||
if err == nil {
|
||||
t.Error("Expected error when writing large batch")
|
||||
}
|
||||
|
||||
// Check that the error is ErrBatchTooLarge
|
||||
if err != nil && !bytes.Contains([]byte(err.Error()), []byte("batch too large")) {
|
||||
t.Errorf("Expected ErrBatchTooLarge, got: %v", err)
|
||||
}
|
||||
|
||||
// Close WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
}
|
283
pkg/wal/reader.go
Normal file
283
pkg/wal/reader.go
Normal file
@ -0,0 +1,283 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Reader reads entries from WAL files
|
||||
type Reader struct {
|
||||
file *os.File
|
||||
reader *bufio.Reader
|
||||
buffer []byte
|
||||
fragments [][]byte
|
||||
currType uint8
|
||||
}
|
||||
|
||||
// OpenReader creates a new Reader for the given WAL file
|
||||
func OpenReader(path string) (*Reader, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open WAL file: %w", err)
|
||||
}
|
||||
|
||||
return &Reader{
|
||||
file: file,
|
||||
reader: bufio.NewReaderSize(file, 64*1024), // 64KB buffer
|
||||
buffer: make([]byte, MaxRecordSize),
|
||||
fragments: make([][]byte, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadEntry reads the next entry from the WAL
|
||||
func (r *Reader) ReadEntry() (*Entry, error) {
|
||||
// Loop until we have a complete entry
|
||||
for {
|
||||
// Read a record
|
||||
record, err := r.readRecord()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// If we have fragments, this is unexpected EOF
|
||||
if len(r.fragments) > 0 {
|
||||
return nil, fmt.Errorf("unexpected EOF with %d fragments", len(r.fragments))
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process based on record type
|
||||
switch record.recordType {
|
||||
case RecordTypeFull:
|
||||
// Single record, parse directly
|
||||
return r.parseEntryData(record.data)
|
||||
|
||||
case RecordTypeFirst:
|
||||
// Start of a fragmented entry
|
||||
r.fragments = append(r.fragments, record.data)
|
||||
r.currType = record.data[0] // Save the operation type
|
||||
|
||||
case RecordTypeMiddle:
|
||||
// Middle fragment
|
||||
if len(r.fragments) == 0 {
|
||||
return nil, fmt.Errorf("%w: middle fragment without first fragment", ErrCorruptRecord)
|
||||
}
|
||||
r.fragments = append(r.fragments, record.data)
|
||||
|
||||
case RecordTypeLast:
|
||||
// Last fragment
|
||||
if len(r.fragments) == 0 {
|
||||
return nil, fmt.Errorf("%w: last fragment without previous fragments", ErrCorruptRecord)
|
||||
}
|
||||
r.fragments = append(r.fragments, record.data)
|
||||
|
||||
// Combine fragments into a single entry
|
||||
entry, err := r.processFragments()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entry, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, record.recordType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record represents a physical record in the WAL
|
||||
type record struct {
|
||||
recordType uint8
|
||||
data []byte
|
||||
}
|
||||
|
||||
// readRecord reads a single physical record from the WAL
|
||||
func (r *Reader) readRecord() (*record, error) {
|
||||
// Read header
|
||||
header := make([]byte, HeaderSize)
|
||||
if _, err := io.ReadFull(r.reader, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse header
|
||||
crc := binary.LittleEndian.Uint32(header[0:4])
|
||||
length := binary.LittleEndian.Uint16(header[4:6])
|
||||
recordType := header[6]
|
||||
|
||||
// Validate record type
|
||||
if recordType < RecordTypeFull || recordType > RecordTypeLast {
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, recordType)
|
||||
}
|
||||
|
||||
// Read payload
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(r.reader, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify CRC
|
||||
computedCRC := crc32.ChecksumIEEE(data)
|
||||
if computedCRC != crc {
|
||||
return nil, fmt.Errorf("%w: expected CRC %d, got %d", ErrCorruptRecord, crc, computedCRC)
|
||||
}
|
||||
|
||||
return &record{
|
||||
recordType: recordType,
|
||||
data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processFragments combines fragments into a single entry
|
||||
func (r *Reader) processFragments() (*Entry, error) {
|
||||
// Determine total size
|
||||
totalSize := 0
|
||||
for _, frag := range r.fragments {
|
||||
totalSize += len(frag)
|
||||
}
|
||||
|
||||
// Combine fragments
|
||||
combined := make([]byte, totalSize)
|
||||
offset := 0
|
||||
for _, frag := range r.fragments {
|
||||
copy(combined[offset:], frag)
|
||||
offset += len(frag)
|
||||
}
|
||||
|
||||
// Reset fragments
|
||||
r.fragments = r.fragments[:0]
|
||||
|
||||
// Parse the combined data into an entry
|
||||
return r.parseEntryData(combined)
|
||||
}
|
||||
|
||||
// parseEntryData parses the binary data into an Entry structure
|
||||
func (r *Reader) parseEntryData(data []byte) (*Entry, error) {
|
||||
if len(data) < 13 { // Minimum size: type(1) + seq(8) + keylen(4)
|
||||
return nil, fmt.Errorf("%w: entry too small, %d bytes", ErrCorruptRecord, len(data))
|
||||
}
|
||||
|
||||
offset := 0
|
||||
|
||||
// Read entry type
|
||||
entryType := data[offset]
|
||||
offset++
|
||||
|
||||
// Validate entry type
|
||||
if entryType != OpTypePut && entryType != OpTypeDelete && entryType != OpTypeMerge && entryType != OpTypeBatch {
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, entryType)
|
||||
}
|
||||
|
||||
// Read sequence number
|
||||
seqNum := binary.LittleEndian.Uint64(data[offset : offset+8])
|
||||
offset += 8
|
||||
|
||||
// Read key length
|
||||
keyLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate key length
|
||||
if offset+int(keyLen) > len(data) {
|
||||
return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen)
|
||||
}
|
||||
|
||||
// Read key
|
||||
key := make([]byte, keyLen)
|
||||
copy(key, data[offset:offset+int(keyLen)])
|
||||
offset += int(keyLen)
|
||||
|
||||
// Read value if applicable
|
||||
var value []byte
|
||||
if entryType != OpTypeDelete {
|
||||
// Check if there's enough data for value length
|
||||
if offset+4 > len(data) {
|
||||
return nil, fmt.Errorf("%w: missing value length", ErrCorruptRecord)
|
||||
}
|
||||
|
||||
// Read value length
|
||||
valueLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate value length
|
||||
if offset+int(valueLen) > len(data) {
|
||||
return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen)
|
||||
}
|
||||
|
||||
// Read value
|
||||
value = make([]byte, valueLen)
|
||||
copy(value, data[offset:offset+int(valueLen)])
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
SequenceNumber: seqNum,
|
||||
Type: entryType,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the reader
|
||||
func (r *Reader) Close() error {
|
||||
return r.file.Close()
|
||||
}
|
||||
|
||||
// EntryHandler is a function that processes WAL entries during replay
|
||||
type EntryHandler func(*Entry) error
|
||||
|
||||
// FindWALFiles returns a list of WAL files in the given directory
|
||||
func FindWALFiles(dir string) ([]string, error) {
|
||||
pattern := filepath.Join(dir, "*.wal")
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to glob WAL files: %w", err)
|
||||
}
|
||||
|
||||
// Sort by filename (which should be timestamp-based)
|
||||
sort.Strings(matches)
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
// ReplayWALFile replays a single WAL file and calls the handler for each entry
|
||||
func ReplayWALFile(path string, handler EntryHandler) error {
|
||||
reader, err := OpenReader(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
for {
|
||||
entry, err := reader.ReadEntry()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("error reading entry from %s: %w", path, err)
|
||||
}
|
||||
|
||||
if err := handler(entry); err != nil {
|
||||
return fmt.Errorf("error handling entry: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplayWALDir replays all WAL files in the given directory in order
|
||||
func ReplayWALDir(dir string, handler EntryHandler) error {
|
||||
files, err := FindWALFiles(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if err := ReplayWALFile(file, handler); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
392
pkg/wal/wal.go
Normal file
392
pkg/wal/wal.go
Normal file
@ -0,0 +1,392 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// Record types
|
||||
RecordTypeFull = 1
|
||||
RecordTypeFirst = 2
|
||||
RecordTypeMiddle = 3
|
||||
RecordTypeLast = 4
|
||||
|
||||
// Operation types
|
||||
OpTypePut = 1
|
||||
OpTypeDelete = 2
|
||||
OpTypeMerge = 3
|
||||
OpTypeBatch = 4
|
||||
|
||||
// Header layout
|
||||
// - CRC (4 bytes)
|
||||
// - Length (2 bytes)
|
||||
// - Type (1 byte)
|
||||
HeaderSize = 7
|
||||
|
||||
// Maximum size of a record payload
|
||||
MaxRecordSize = 32 * 1024 // 32KB
|
||||
|
||||
// Default WAL file size
|
||||
DefaultWALFileSize = 64 * 1024 * 1024 // 64MB
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCorruptRecord = errors.New("corrupt record")
|
||||
ErrInvalidRecordType = errors.New("invalid record type")
|
||||
ErrInvalidOpType = errors.New("invalid operation type")
|
||||
ErrWALClosed = errors.New("WAL is closed")
|
||||
ErrWALFull = errors.New("WAL file is full")
|
||||
)
|
||||
|
||||
// Entry represents a logical entry in the WAL
|
||||
type Entry struct {
|
||||
SequenceNumber uint64
|
||||
Type uint8 // OpTypePut, OpTypeDelete, etc.
|
||||
Key []byte
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// WAL represents a write-ahead log
|
||||
type WAL struct {
|
||||
cfg *config.Config
|
||||
dir string
|
||||
file *os.File
|
||||
writer *bufio.Writer
|
||||
nextSequence uint64
|
||||
bytesWritten int64
|
||||
lastSync time.Time
|
||||
batchByteSize int64
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewWAL creates a new write-ahead log
|
||||
func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create WAL directory: %w", err)
|
||||
}
|
||||
|
||||
// Create a new WAL file
|
||||
filename := fmt.Sprintf("%020d.wal", time.Now().UnixNano())
|
||||
path := filepath.Join(dir, filename)
|
||||
|
||||
file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WAL file: %w", err)
|
||||
}
|
||||
|
||||
wal := &WAL{
|
||||
cfg: cfg,
|
||||
dir: dir,
|
||||
file: file,
|
||||
writer: bufio.NewWriterSize(file, 64*1024), // 64KB buffer
|
||||
nextSequence: 1,
|
||||
lastSync: time.Now(),
|
||||
}
|
||||
|
||||
return wal, nil
|
||||
}
|
||||
|
||||
// Append adds an entry to the WAL
|
||||
func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return 0, ErrWALClosed
|
||||
}
|
||||
|
||||
if entryType != OpTypePut && entryType != OpTypeDelete && entryType != OpTypeMerge {
|
||||
return 0, ErrInvalidOpType
|
||||
}
|
||||
|
||||
// Sequence number for this entry
|
||||
seqNum := w.nextSequence
|
||||
w.nextSequence++
|
||||
|
||||
// Encode the entry
|
||||
// Format: type(1) + seq(8) + keylen(4) + key + vallen(4) + val
|
||||
entrySize := 1 + 8 + 4 + len(key)
|
||||
if entryType != OpTypeDelete {
|
||||
entrySize += 4 + len(value)
|
||||
}
|
||||
|
||||
// Check if we need to split the record
|
||||
if entrySize <= MaxRecordSize {
|
||||
// Single record case
|
||||
recordType := uint8(RecordTypeFull)
|
||||
if err := w.writeRecord(recordType, entryType, seqNum, key, value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
// Split into multiple records
|
||||
if err := w.writeFragmentedRecord(entryType, seqNum, key, value); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the file if needed
|
||||
if err := w.maybeSync(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return seqNum, nil
|
||||
}
|
||||
|
||||
// Write a single record
|
||||
func (w *WAL) writeRecord(recordType uint8, entryType uint8, seqNum uint64, key, value []byte) error {
|
||||
// Calculate the record size
|
||||
payloadSize := 1 + 8 + 4 + len(key) // type + seq + keylen + key
|
||||
if entryType != OpTypeDelete {
|
||||
payloadSize += 4 + len(value) // vallen + value
|
||||
}
|
||||
|
||||
if payloadSize > MaxRecordSize {
|
||||
return fmt.Errorf("record too large: %d > %d", payloadSize, MaxRecordSize)
|
||||
}
|
||||
|
||||
// Prepare the header
|
||||
header := make([]byte, HeaderSize)
|
||||
binary.LittleEndian.PutUint16(header[4:6], uint16(payloadSize))
|
||||
header[6] = recordType
|
||||
|
||||
// Prepare the payload
|
||||
payload := make([]byte, payloadSize)
|
||||
offset := 0
|
||||
|
||||
// Write entry type
|
||||
payload[offset] = entryType
|
||||
offset++
|
||||
|
||||
// Write sequence number
|
||||
binary.LittleEndian.PutUint64(payload[offset:offset+8], seqNum)
|
||||
offset += 8
|
||||
|
||||
// Write key length and key
|
||||
binary.LittleEndian.PutUint32(payload[offset:offset+4], uint32(len(key)))
|
||||
offset += 4
|
||||
copy(payload[offset:], key)
|
||||
offset += len(key)
|
||||
|
||||
// Write value length and value (if applicable)
|
||||
if entryType != OpTypeDelete {
|
||||
binary.LittleEndian.PutUint32(payload[offset:offset+4], uint32(len(value)))
|
||||
offset += 4
|
||||
copy(payload[offset:], value)
|
||||
}
|
||||
|
||||
// Calculate CRC
|
||||
crc := crc32.ChecksumIEEE(payload)
|
||||
binary.LittleEndian.PutUint32(header[0:4], crc)
|
||||
|
||||
// Write the record
|
||||
if _, err := w.writer.Write(header); err != nil {
|
||||
return fmt.Errorf("failed to write record header: %w", err)
|
||||
}
|
||||
if _, err := w.writer.Write(payload); err != nil {
|
||||
return fmt.Errorf("failed to write record payload: %w", err)
|
||||
}
|
||||
|
||||
// Update bytes written
|
||||
w.bytesWritten += int64(HeaderSize + payloadSize)
|
||||
w.batchByteSize += int64(HeaderSize + payloadSize)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeRawRecord writes a raw record with provided data as payload
|
||||
func (w *WAL) writeRawRecord(recordType uint8, data []byte) error {
|
||||
if len(data) > MaxRecordSize {
|
||||
return fmt.Errorf("record too large: %d > %d", len(data), MaxRecordSize)
|
||||
}
|
||||
|
||||
// Prepare the header
|
||||
header := make([]byte, HeaderSize)
|
||||
binary.LittleEndian.PutUint16(header[4:6], uint16(len(data)))
|
||||
header[6] = recordType
|
||||
|
||||
// Calculate CRC
|
||||
crc := crc32.ChecksumIEEE(data)
|
||||
binary.LittleEndian.PutUint32(header[0:4], crc)
|
||||
|
||||
// Write the record
|
||||
if _, err := w.writer.Write(header); err != nil {
|
||||
return fmt.Errorf("failed to write record header: %w", err)
|
||||
}
|
||||
if _, err := w.writer.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write record payload: %w", err)
|
||||
}
|
||||
|
||||
// Update bytes written
|
||||
w.bytesWritten += int64(HeaderSize + len(data))
|
||||
w.batchByteSize += int64(HeaderSize + len(data))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write a fragmented record
|
||||
func (w *WAL) writeFragmentedRecord(entryType uint8, seqNum uint64, key, value []byte) error {
|
||||
// First fragment contains metadata: type, sequence, key length, and as much of the key as fits
|
||||
headerSize := 1 + 8 + 4 // type + seq + keylen
|
||||
|
||||
// Calculate how much of the key can fit in the first fragment
|
||||
maxKeyInFirst := MaxRecordSize - headerSize
|
||||
keyInFirst := min(len(key), maxKeyInFirst)
|
||||
|
||||
// Create the first fragment
|
||||
firstFragment := make([]byte, headerSize + keyInFirst)
|
||||
offset := 0
|
||||
|
||||
// Add metadata to first fragment
|
||||
firstFragment[offset] = entryType
|
||||
offset++
|
||||
|
||||
binary.LittleEndian.PutUint64(firstFragment[offset:offset+8], seqNum)
|
||||
offset += 8
|
||||
|
||||
binary.LittleEndian.PutUint32(firstFragment[offset:offset+4], uint32(len(key)))
|
||||
offset += 4
|
||||
|
||||
// Add as much of the key as fits
|
||||
copy(firstFragment[offset:], key[:keyInFirst])
|
||||
|
||||
// Write the first fragment
|
||||
if err := w.writeRawRecord(uint8(RecordTypeFirst), firstFragment); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare the remaining data
|
||||
var remaining []byte
|
||||
|
||||
// Add any remaining key bytes
|
||||
if keyInFirst < len(key) {
|
||||
remaining = append(remaining, key[keyInFirst:]...)
|
||||
}
|
||||
|
||||
// Add value data if this isn't a delete operation
|
||||
if entryType != OpTypeDelete {
|
||||
// Add value length
|
||||
valueLenBuf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(valueLenBuf, uint32(len(value)))
|
||||
remaining = append(remaining, valueLenBuf...)
|
||||
|
||||
// Add value
|
||||
remaining = append(remaining, value...)
|
||||
}
|
||||
|
||||
// Write middle fragments (all full-sized except possibly the last)
|
||||
for len(remaining) > MaxRecordSize {
|
||||
chunk := remaining[:MaxRecordSize]
|
||||
remaining = remaining[MaxRecordSize:]
|
||||
|
||||
if err := w.writeRawRecord(uint8(RecordTypeMiddle), chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write the last fragment if there's any remaining data
|
||||
if len(remaining) > 0 {
|
||||
if err := w.writeRawRecord(uint8(RecordTypeLast), remaining); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeSync syncs the WAL file if needed based on configuration
|
||||
func (w *WAL) maybeSync() error {
|
||||
needSync := false
|
||||
|
||||
switch w.cfg.WALSyncMode {
|
||||
case config.SyncImmediate:
|
||||
needSync = true
|
||||
case config.SyncBatch:
|
||||
// Sync if we've written enough bytes
|
||||
if w.batchByteSize >= w.cfg.WALSyncBytes {
|
||||
needSync = true
|
||||
}
|
||||
case config.SyncNone:
|
||||
// No syncing
|
||||
}
|
||||
|
||||
if needSync {
|
||||
// Use syncLocked since we're already holding the mutex
|
||||
if err := w.syncLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncLocked performs the sync operation assuming the mutex is already held
|
||||
func (w *WAL) syncLocked() error {
|
||||
if w.closed {
|
||||
return ErrWALClosed
|
||||
}
|
||||
|
||||
if err := w.writer.Flush(); err != nil {
|
||||
return fmt.Errorf("failed to flush WAL buffer: %w", err)
|
||||
}
|
||||
|
||||
if err := w.file.Sync(); err != nil {
|
||||
return fmt.Errorf("failed to sync WAL file: %w", err)
|
||||
}
|
||||
|
||||
w.lastSync = time.Now()
|
||||
w.batchByteSize = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync flushes all buffered data to disk
|
||||
func (w *WAL) Sync() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.syncLocked()
|
||||
}
|
||||
|
||||
// Close closes the WAL
|
||||
func (w *WAL) Close() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use syncLocked to flush and sync
|
||||
if err := w.syncLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close WAL file: %w", err)
|
||||
}
|
||||
|
||||
w.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
590
pkg/wal/wal_test.go
Normal file
590
pkg/wal/wal_test.go
Normal file
@ -0,0 +1,590 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.canoozie.net/jer/go-storage/pkg/config"
|
||||
)
|
||||
|
||||
func createTestConfig() *config.Config {
|
||||
return config.NewDefaultConfig("/tmp/gostorage_test")
|
||||
}
|
||||
|
||||
func createTempDir(t *testing.T) string {
|
||||
dir, err := os.MkdirTemp("", "wal_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestWALWrite(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Write some entries
|
||||
keys := []string{"key1", "key2", "key3"}
|
||||
values := []string{"value1", "value2", "value3"}
|
||||
|
||||
for i, key := range keys {
|
||||
seq, err := wal.Append(OpTypePut, []byte(key), []byte(values[i]))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
|
||||
if seq != uint64(i+1) {
|
||||
t.Errorf("Expected sequence %d, got %d", i+1, seq)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries by replaying
|
||||
entries := make(map[string]string)
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut {
|
||||
entries[string(entry.Key)] = string(entry.Value)
|
||||
} else if entry.Type == OpTypeDelete {
|
||||
delete(entries, string(entry.Key))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all entries are present
|
||||
for i, key := range keys {
|
||||
value, ok := entries[key]
|
||||
if !ok {
|
||||
t.Errorf("Entry for key %q not found", key)
|
||||
continue
|
||||
}
|
||||
|
||||
if value != values[i] {
|
||||
t.Errorf("Expected value %q for key %q, got %q", values[i], key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALDelete(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Write and delete
|
||||
key := []byte("key1")
|
||||
value := []byte("value1")
|
||||
|
||||
_, err = wal.Append(OpTypePut, key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append put entry: %v", err)
|
||||
}
|
||||
|
||||
_, err = wal.Append(OpTypeDelete, key, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append delete entry: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries by replaying
|
||||
var deleted bool
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut && bytes.Equal(entry.Key, key) {
|
||||
if deleted {
|
||||
deleted = false // Key was re-added
|
||||
}
|
||||
} else if entry.Type == OpTypeDelete && bytes.Equal(entry.Key, key) {
|
||||
deleted = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
t.Errorf("Expected key to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALLargeEntry(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create a large key and value (but not too large for a single record)
|
||||
key := make([]byte, 8*1024) // 8KB
|
||||
value := make([]byte, 16*1024) // 16KB
|
||||
|
||||
for i := range key {
|
||||
key[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
for i := range value {
|
||||
value[i] = byte((i * 2) % 256)
|
||||
}
|
||||
|
||||
// Append the large entry
|
||||
_, err = wal.Append(OpTypePut, key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append large entry: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify by replaying
|
||||
var foundLargeEntry bool
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut && len(entry.Key) == len(key) && len(entry.Value) == len(value) {
|
||||
// Verify key
|
||||
for i := range key {
|
||||
if key[i] != entry.Key[i] {
|
||||
t.Errorf("Key mismatch at position %d: expected %d, got %d", i, key[i], entry.Key[i])
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Verify value
|
||||
for i := range value {
|
||||
if value[i] != entry.Value[i] {
|
||||
t.Errorf("Value mismatch at position %d: expected %d, got %d", i, value[i], entry.Value[i])
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
foundLargeEntry = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
if !foundLargeEntry {
|
||||
t.Error("Large entry not found in replay")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALBatch(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create a batch
|
||||
batch := NewBatch()
|
||||
|
||||
keys := []string{"batch1", "batch2", "batch3"}
|
||||
values := []string{"value1", "value2", "value3"}
|
||||
|
||||
for i, key := range keys {
|
||||
batch.Put([]byte(key), []byte(values[i]))
|
||||
}
|
||||
|
||||
// Add a delete operation
|
||||
batch.Delete([]byte("batch2"))
|
||||
|
||||
// Write the batch
|
||||
if err := batch.Write(wal); err != nil {
|
||||
t.Fatalf("Failed to write batch: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify by replaying
|
||||
entries := make(map[string]string)
|
||||
batchCount := 0
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypeBatch {
|
||||
batchCount++
|
||||
|
||||
// Decode batch
|
||||
batch, err := DecodeBatch(entry)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode batch: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply batch operations
|
||||
for _, op := range batch.Operations {
|
||||
if op.Type == OpTypePut {
|
||||
entries[string(op.Key)] = string(op.Value)
|
||||
} else if op.Type == OpTypeDelete {
|
||||
delete(entries, string(op.Key))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify batch was replayed
|
||||
if batchCount != 1 {
|
||||
t.Errorf("Expected 1 batch, got %d", batchCount)
|
||||
}
|
||||
|
||||
// Verify entries
|
||||
expectedEntries := map[string]string{
|
||||
"batch1": "value1",
|
||||
"batch3": "value3",
|
||||
// batch2 should be deleted
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedEntries {
|
||||
value, ok := entries[key]
|
||||
if !ok {
|
||||
t.Errorf("Entry for key %q not found", key)
|
||||
continue
|
||||
}
|
||||
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected value %q for key %q, got %q", expectedValue, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify batch2 is deleted
|
||||
if _, ok := entries["batch2"]; ok {
|
||||
t.Errorf("Key batch2 should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALRecovery(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
|
||||
// Write some entries in the first WAL
|
||||
wal1, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
_, err = wal1.Append(OpTypePut, []byte("key1"), []byte("value1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
|
||||
if err := wal1.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create a second WAL file
|
||||
wal2, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
_, err = wal2.Append(OpTypePut, []byte("key2"), []byte("value2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
|
||||
if err := wal2.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries by replaying all WAL files in order
|
||||
entries := make(map[string]string)
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut {
|
||||
entries[string(entry.Key)] = string(entry.Value)
|
||||
} else if entry.Type == OpTypeDelete {
|
||||
delete(entries, string(entry.Key))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all entries are present
|
||||
expected := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
|
||||
for key, expectedValue := range expected {
|
||||
value, ok := entries[key]
|
||||
if !ok {
|
||||
t.Errorf("Entry for key %q not found", key)
|
||||
continue
|
||||
}
|
||||
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected value %q for key %q, got %q", expectedValue, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALSyncModes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
syncMode config.SyncMode
|
||||
}{
|
||||
{"SyncNone", config.SyncNone},
|
||||
{"SyncBatch", config.SyncBatch},
|
||||
{"SyncImmediate", config.SyncImmediate},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Create config with specific sync mode
|
||||
cfg := createTestConfig()
|
||||
cfg.WALSyncMode = tc.syncMode
|
||||
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Write some entries
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte(fmt.Sprintf("key%d", i))
|
||||
value := []byte(fmt.Sprintf("value%d", i))
|
||||
|
||||
_, err := wal.Append(OpTypePut, key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries by replaying
|
||||
count := 0
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut {
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
if count != 10 {
|
||||
t.Errorf("Expected 10 entries, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALFragmentation(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Create an entry that's guaranteed to be fragmented
|
||||
// Header size is 1 + 8 + 4 = 13 bytes, so allocate more than MaxRecordSize - 13 for the key
|
||||
keySize := MaxRecordSize - 10
|
||||
valueSize := MaxRecordSize * 2
|
||||
|
||||
key := make([]byte, keySize) // Just under MaxRecordSize to ensure key fragmentation
|
||||
value := make([]byte, valueSize) // Large value to ensure value fragmentation
|
||||
|
||||
// Fill with recognizable patterns
|
||||
for i := range key {
|
||||
key[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
for i := range value {
|
||||
value[i] = byte((i * 3) % 256)
|
||||
}
|
||||
|
||||
// Append the large entry - this should trigger fragmentation
|
||||
_, err = wal.Append(OpTypePut, key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append fragmented entry: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify by replaying
|
||||
var reconstructedKey []byte
|
||||
var reconstructedValue []byte
|
||||
var foundPut bool
|
||||
|
||||
err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut {
|
||||
foundPut = true
|
||||
reconstructedKey = entry.Key
|
||||
reconstructedValue = entry.Value
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
// Check that we found the entry
|
||||
if !foundPut {
|
||||
t.Fatal("Did not find PUT entry in replay")
|
||||
}
|
||||
|
||||
// Verify key length matches
|
||||
if len(reconstructedKey) != keySize {
|
||||
t.Errorf("Key length mismatch: expected %d, got %d", keySize, len(reconstructedKey))
|
||||
}
|
||||
|
||||
// Verify value length matches
|
||||
if len(reconstructedValue) != valueSize {
|
||||
t.Errorf("Value length mismatch: expected %d, got %d", valueSize, len(reconstructedValue))
|
||||
}
|
||||
|
||||
// Check key content (first 10 bytes)
|
||||
for i := 0; i < 10 && i < len(key); i++ {
|
||||
if key[i] != reconstructedKey[i] {
|
||||
t.Errorf("Key mismatch at position %d: expected %d, got %d", i, key[i], reconstructedKey[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Check key content (last 10 bytes)
|
||||
for i := 0; i < 10 && i < len(key); i++ {
|
||||
idx := len(key) - 1 - i
|
||||
if key[idx] != reconstructedKey[idx] {
|
||||
t.Errorf("Key mismatch at position %d: expected %d, got %d", idx, key[idx], reconstructedKey[idx])
|
||||
}
|
||||
}
|
||||
|
||||
// Check value content (first 10 bytes)
|
||||
for i := 0; i < 10 && i < len(value); i++ {
|
||||
if value[i] != reconstructedValue[i] {
|
||||
t.Errorf("Value mismatch at position %d: expected %d, got %d", i, value[i], reconstructedValue[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Check value content (last 10 bytes)
|
||||
for i := 0; i < 10 && i < len(value); i++ {
|
||||
idx := len(value) - 1 - i
|
||||
if value[idx] != reconstructedValue[idx] {
|
||||
t.Errorf("Value mismatch at position %d: expected %d, got %d", idx, value[idx], reconstructedValue[idx])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify random samples from the key and value
|
||||
for i := 0; i < 10; i++ {
|
||||
// Check random positions in the key
|
||||
keyPos := rand.Intn(keySize)
|
||||
if key[keyPos] != reconstructedKey[keyPos] {
|
||||
t.Errorf("Key mismatch at random position %d: expected %d, got %d", keyPos, key[keyPos], reconstructedKey[keyPos])
|
||||
}
|
||||
|
||||
// Check random positions in the value
|
||||
valuePos := rand.Intn(valueSize)
|
||||
if value[valuePos] != reconstructedValue[valuePos] {
|
||||
t.Errorf("Value mismatch at random position %d: expected %d, got %d", valuePos, value[valuePos], reconstructedValue[valuePos])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALErrorHandling(t *testing.T) {
|
||||
dir := createTempDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
cfg := createTestConfig()
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Write some entries
|
||||
_, err = wal.Append(OpTypePut, []byte("key1"), []byte("value1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Try to write after close
|
||||
_, err = wal.Append(OpTypePut, []byte("key2"), []byte("value2"))
|
||||
if err != ErrWALClosed {
|
||||
t.Errorf("Expected ErrWALClosed, got: %v", err)
|
||||
}
|
||||
|
||||
// Try to sync after close
|
||||
err = wal.Sync()
|
||||
if err != ErrWALClosed {
|
||||
t.Errorf("Expected ErrWALClosed, got: %v", err)
|
||||
}
|
||||
|
||||
// Try to replay a non-existent file
|
||||
nonExistentPath := filepath.Join(dir, "nonexistent.wal")
|
||||
err = ReplayWALFile(nonExistentPath, func(entry *Entry) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when replaying non-existent file")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user