diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..f43f217 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build Commands +- Build: `go build ./...` +- Run tests: `go test ./...` +- Run single test: `go test ./pkg/path/to/package -run TestName` +- Benchmark: `go test ./pkg/path/to/package -bench .` +- Race detector: `go test -race ./...` + +## Linting/Formatting +- Format code: `go fmt ./...` +- Static analysis: `go vet ./...` +- Install golangci-lint: `go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest` +- Run linter: `golangci-lint run` + +## Code Style Guidelines +- Follow Go standard project layout in pkg/ and internal/ directories +- Use descriptive error types with context wrapping +- Implement single-writer architecture for write paths +- Allow concurrent reads via snapshots +- Use interfaces for component boundaries +- Follow idiomatic Go practices +- Add appropriate validation, especially for checksums +- All exported functions must have documentation comments +- For transaction management, use WAL for durability/atomicity + +## Version Control +- Use git for version control +- All commit messages must use semantic commit messages +- All commit messages must not reference code being generated or co-authored by Claude diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..57cbbb8 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,199 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sync" +) + +const ( + DefaultManifestFileName = "MANIFEST" + CurrentManifestVersion = 1 +) + +var ( + ErrInvalidConfig = errors.New("invalid configuration") + ErrManifestNotFound = errors.New("manifest not found") + ErrInvalidManifest = errors.New("invalid manifest") +) + +type SyncMode int + +const ( + SyncNone SyncMode = iota + 
SyncBatch + SyncImmediate +) + +type Config struct { + Version int `json:"version"` + + // WAL configuration + WALDir string `json:"wal_dir"` + WALSyncMode SyncMode `json:"wal_sync_mode"` + WALSyncBytes int64 `json:"wal_sync_bytes"` + + // MemTable configuration + MemTableSize int64 `json:"memtable_size"` + MaxMemTables int `json:"max_memtables"` + MaxMemTableAge int64 `json:"max_memtable_age"` + MemTablePoolCap int `json:"memtable_pool_cap"` + + // SSTable configuration + SSTDir string `json:"sst_dir"` + SSTableBlockSize int `json:"sstable_block_size"` + SSTableIndexSize int `json:"sstable_index_size"` + SSTableMaxSize int64 `json:"sstable_max_size"` + SSTableRestartSize int `json:"sstable_restart_size"` + + // Compaction configuration + CompactionLevels int `json:"compaction_levels"` + CompactionRatio float64 `json:"compaction_ratio"` + CompactionThreads int `json:"compaction_threads"` + CompactionInterval int64 `json:"compaction_interval"` + + mu sync.RWMutex +} + +// NewDefaultConfig creates a Config with recommended default values +func NewDefaultConfig(dbPath string) *Config { + walDir := filepath.Join(dbPath, "wal") + sstDir := filepath.Join(dbPath, "sst") + + return &Config{ + Version: CurrentManifestVersion, + + // WAL defaults + WALDir: walDir, + WALSyncMode: SyncBatch, + WALSyncBytes: 1024 * 1024, // 1MB + + // MemTable defaults + MemTableSize: 32 * 1024 * 1024, // 32MB + MaxMemTables: 4, + MaxMemTableAge: 600, // 10 minutes + MemTablePoolCap: 4, + + // SSTable defaults + SSTDir: sstDir, + SSTableBlockSize: 16 * 1024, // 16KB + SSTableIndexSize: 64 * 1024, // 64KB + SSTableMaxSize: 64 * 1024 * 1024, // 64MB + SSTableRestartSize: 16, // Restart points every 16 keys + + // Compaction defaults + CompactionLevels: 7, + CompactionRatio: 10, + CompactionThreads: 2, + CompactionInterval: 30, // 30 seconds + } +} + +// Validate checks if the configuration is valid +func (c *Config) Validate() error { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Version <= 0 { 
+		return fmt.Errorf("%w: invalid version %d", ErrInvalidConfig, c.Version)
+	}
+
+	if c.WALDir == "" {
+		return fmt.Errorf("%w: WAL directory not specified", ErrInvalidConfig)
+	}
+
+	if c.SSTDir == "" {
+		return fmt.Errorf("%w: SSTable directory not specified", ErrInvalidConfig)
+	}
+
+	if c.MemTableSize <= 0 {
+		return fmt.Errorf("%w: MemTable size must be positive", ErrInvalidConfig)
+	}
+
+	if c.MaxMemTables <= 0 {
+		return fmt.Errorf("%w: Max MemTables must be positive", ErrInvalidConfig)
+	}
+
+	if c.SSTableBlockSize <= 0 {
+		return fmt.Errorf("%w: SSTable block size must be positive", ErrInvalidConfig)
+	}
+
+	if c.SSTableIndexSize <= 0 {
+		return fmt.Errorf("%w: SSTable index size must be positive", ErrInvalidConfig)
+	}
+
+	if c.CompactionLevels <= 0 {
+		return fmt.Errorf("%w: Compaction levels must be positive", ErrInvalidConfig)
+	}
+
+	if c.CompactionRatio <= 1.0 {
+		return fmt.Errorf("%w: Compaction ratio must be greater than 1.0", ErrInvalidConfig)
+	}
+
+	return nil
+}
+
+// LoadConfigFromManifest reads the manifest file in dbPath and returns the
+// validated configuration stored in it.
+//
+// Two writers share DefaultManifestFileName: Config.SaveManifest stores a bare
+// Config object, while Manifest.Save stores a JSON array of ManifestEntry
+// values. Both layouts are accepted; for the array layout the config of the
+// last entry is returned.
+//
+// It returns ErrManifestNotFound when the file does not exist, an error
+// wrapping ErrInvalidManifest when the contents cannot be decoded in either
+// layout, and the error from Validate when the decoded configuration is
+// rejected.
+func LoadConfigFromManifest(dbPath string) (*Config, error) {
+	manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
+	data, err := os.ReadFile(manifestPath)
+	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			// Returned unwrapped so callers may compare with ==.
+			return nil, ErrManifestNotFound
+		}
+		return nil, fmt.Errorf("failed to read manifest: %w", err)
+	}
+
+	cfg := new(Config)
+	if err := json.Unmarshal(data, cfg); err != nil {
+		// Not a bare Config object: try the entry-array layout before
+		// giving up, and report the original decode error if that fails too.
+		var entries []ManifestEntry
+		if json.Unmarshal(data, &entries) != nil || len(entries) == 0 || entries[len(entries)-1].Config == nil {
+			return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
+		}
+		cfg = entries[len(entries)-1].Config
+	}
+
+	if err := cfg.Validate(); err != nil {
+		return nil, err
+	}
+
+	return cfg, nil
+}
+
+// SaveManifest validates the configuration and writes it as indented JSON to
+// the manifest file inside dbPath, creating the directory when necessary. The
+// data is written to a ".tmp" sibling first and then renamed over the
+// manifest.
+//
+// Validation and serialization are two separate critical sections, so an
+// Update that lands between them is saved without having been validated.
+func (c *Config) SaveManifest(dbPath string) error {
+	// Validate acquires c.mu on its own, so it has to run before this method
+	// takes the lock: read-locking a sync.RWMutex recursively can deadlock as
+	// soon as a writer (for example Update) queues up between the two RLock
+	// calls.
+	if err := c.Validate(); err != nil {
+		return err
+	}
+
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if err := os.MkdirAll(dbPath, 0755); err != nil {
+		return fmt.Errorf("failed to create directory: %w", err)
+	}
+
+	manifestPath := 
filepath.Join(dbPath, DefaultManifestFileName) + tempPath := manifestPath + ".tmp" + + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + if err := os.WriteFile(tempPath, data, 0644); err != nil { + return fmt.Errorf("failed to write manifest: %w", err) + } + + if err := os.Rename(tempPath, manifestPath); err != nil { + return fmt.Errorf("failed to rename manifest: %w", err) + } + + return nil +} + +// Update applies the given function to modify the configuration +func (c *Config) Update(fn func(*Config)) { + c.mu.Lock() + defer c.mu.Unlock() + fn(c) +} \ No newline at end of file diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..ec34db8 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,167 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestNewDefaultConfig(t *testing.T) { + dbPath := "/tmp/testdb" + cfg := NewDefaultConfig(dbPath) + + if cfg.Version != CurrentManifestVersion { + t.Errorf("expected version %d, got %d", CurrentManifestVersion, cfg.Version) + } + + if cfg.WALDir != filepath.Join(dbPath, "wal") { + t.Errorf("expected WAL dir %s, got %s", filepath.Join(dbPath, "wal"), cfg.WALDir) + } + + if cfg.SSTDir != filepath.Join(dbPath, "sst") { + t.Errorf("expected SST dir %s, got %s", filepath.Join(dbPath, "sst"), cfg.SSTDir) + } + + // Test default values + if cfg.WALSyncMode != SyncBatch { + t.Errorf("expected WAL sync mode %d, got %d", SyncBatch, cfg.WALSyncMode) + } + + if cfg.MemTableSize != 32*1024*1024 { + t.Errorf("expected memtable size %d, got %d", 32*1024*1024, cfg.MemTableSize) + } +} + +func TestConfigValidate(t *testing.T) { + cfg := NewDefaultConfig("/tmp/testdb") + + // Valid config + if err := cfg.Validate(); err != nil { + t.Errorf("expected valid config, got error: %v", err) + } + + // Test invalid configs + testCases := []struct { + name string + mutate func(*Config) + 
expected string + }{ + { + name: "invalid version", + mutate: func(c *Config) { + c.Version = 0 + }, + expected: "invalid configuration: invalid version 0", + }, + { + name: "empty WAL dir", + mutate: func(c *Config) { + c.WALDir = "" + }, + expected: "invalid configuration: WAL directory not specified", + }, + { + name: "empty SST dir", + mutate: func(c *Config) { + c.SSTDir = "" + }, + expected: "invalid configuration: SSTable directory not specified", + }, + { + name: "zero memtable size", + mutate: func(c *Config) { + c.MemTableSize = 0 + }, + expected: "invalid configuration: MemTable size must be positive", + }, + { + name: "negative max memtables", + mutate: func(c *Config) { + c.MaxMemTables = -1 + }, + expected: "invalid configuration: Max MemTables must be positive", + }, + { + name: "zero block size", + mutate: func(c *Config) { + c.SSTableBlockSize = 0 + }, + expected: "invalid configuration: SSTable block size must be positive", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cfg := NewDefaultConfig("/tmp/testdb") + tc.mutate(cfg) + + err := cfg.Validate() + if err == nil { + t.Fatal("expected error, got nil") + } + + if err.Error() != tc.expected { + t.Errorf("expected error %q, got %q", tc.expected, err.Error()) + } + }) + } +} + +func TestConfigManifestSaveLoad(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "config_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a config and save it + cfg := NewDefaultConfig(tempDir) + cfg.MemTableSize = 16 * 1024 * 1024 // 16MB + cfg.CompactionThreads = 4 + + if err := cfg.SaveManifest(tempDir); err != nil { + t.Fatalf("failed to save manifest: %v", err) + } + + // Load the config + loadedCfg, err := LoadConfigFromManifest(tempDir) + if err != nil { + t.Fatalf("failed to load manifest: %v", err) + } + + // Verify loaded config + if loadedCfg.MemTableSize != 
cfg.MemTableSize { + t.Errorf("expected memtable size %d, got %d", cfg.MemTableSize, loadedCfg.MemTableSize) + } + + if loadedCfg.CompactionThreads != cfg.CompactionThreads { + t.Errorf("expected compaction threads %d, got %d", cfg.CompactionThreads, loadedCfg.CompactionThreads) + } + + // Test loading non-existent manifest + nonExistentDir := filepath.Join(tempDir, "nonexistent") + _, err = LoadConfigFromManifest(nonExistentDir) + if err != ErrManifestNotFound { + t.Errorf("expected ErrManifestNotFound, got %v", err) + } +} + +func TestConfigUpdate(t *testing.T) { + cfg := NewDefaultConfig("/tmp/testdb") + + // Update config + cfg.Update(func(c *Config) { + c.MemTableSize = 64 * 1024 * 1024 // 64MB + c.MaxMemTables = 8 + }) + + // Verify update + if cfg.MemTableSize != 64*1024*1024 { + t.Errorf("expected memtable size %d, got %d", 64*1024*1024, cfg.MemTableSize) + } + + if cfg.MaxMemTables != 8 { + t.Errorf("expected max memtables %d, got %d", 8, cfg.MaxMemTables) + } +} \ No newline at end of file diff --git a/pkg/config/manifest.go b/pkg/config/manifest.go new file mode 100644 index 0000000..85c1813 --- /dev/null +++ b/pkg/config/manifest.go @@ -0,0 +1,214 @@ +package config + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" +) + +type ManifestEntry struct { + Timestamp int64 `json:"timestamp"` + Version int `json:"version"` + Config *Config `json:"config"` + FileSystem map[string]int64 `json:"filesystem,omitempty"` // Map of file paths to sequence numbers +} + +type Manifest struct { + DBPath string + Entries []ManifestEntry + Current *ManifestEntry + LastUpdate time.Time + mu sync.RWMutex +} + +// NewManifest creates a new manifest for the given database path +func NewManifest(dbPath string, config *Config) (*Manifest, error) { + if config == nil { + config = NewDefaultConfig(dbPath) + } + + if err := config.Validate(); err != nil { + return nil, err + } + + entry := ManifestEntry{ + Timestamp: time.Now().Unix(), + Version: 
CurrentManifestVersion,
+		Config:    config,
+	}
+
+	m := &Manifest{
+		DBPath:     dbPath,
+		Entries:    []ManifestEntry{entry},
+		LastUpdate: time.Now(),
+	}
+	// Current must alias the element stored in Entries rather than the local
+	// entry variable: AddFile and RemoveFile mutate through Current while Save
+	// serializes Entries, so a pointer to the local copy would make file
+	// registrations invisible to Save.
+	m.Current = &m.Entries[0]
+
+	return m, nil
+}
+
+// LoadManifest reads the manifest file in dbPath, decodes it as a JSON array
+// of entries and returns a Manifest whose Current points at the last entry.
+//
+// It returns ErrManifestNotFound when the file does not exist, an error
+// wrapping ErrInvalidManifest when the file cannot be decoded, contains no
+// entries, or its last entry carries no config, and the error from Validate
+// when that config is rejected.
+func LoadManifest(dbPath string) (*Manifest, error) {
+	manifestPath := filepath.Join(dbPath, DefaultManifestFileName)
+	file, err := os.Open(manifestPath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, ErrManifestNotFound
+		}
+		return nil, fmt.Errorf("failed to open manifest: %w", err)
+	}
+	defer file.Close()
+
+	data, err := io.ReadAll(file)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read manifest: %w", err)
+	}
+
+	var entries []ManifestEntry
+	if err := json.Unmarshal(data, &entries); err != nil {
+		return nil, fmt.Errorf("%w: %v", ErrInvalidManifest, err)
+	}
+
+	if len(entries) == 0 {
+		return nil, fmt.Errorf("%w: no entries in manifest", ErrInvalidManifest)
+	}
+
+	current := &entries[len(entries)-1]
+	// An entry with a null or missing "config" decodes to a nil pointer, and
+	// Validate would dereference it; report the manifest as invalid instead.
+	if current.Config == nil {
+		return nil, fmt.Errorf("%w: latest entry has no config", ErrInvalidManifest)
+	}
+	if err := current.Config.Validate(); err != nil {
+		return nil, err
+	}
+
+	m := &Manifest{
+		DBPath:     dbPath,
+		Entries:    entries,
+		Current:    current,
+		LastUpdate: time.Now(),
+	}
+
+	return m, nil
+}
+
+// Save validates the current config and writes all entries as indented JSON
+// to the manifest file in DBPath, creating the directory when necessary. The
+// data goes to a ".tmp" sibling first and is then renamed over the manifest.
+// The write lock is held for the whole call because LastUpdate is modified on
+// success.
+func (m *Manifest) Save() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if err := m.Current.Config.Validate(); err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(m.DBPath, 0755); err != nil {
+		return fmt.Errorf("failed to create directory: %w", err)
+	}
+
+	manifestPath := filepath.Join(m.DBPath, DefaultManifestFileName)
+	tempPath := manifestPath + ".tmp"
+
+	data, err := json.MarshalIndent(m.Entries, "", " ")
+	if err != nil {
+		return fmt.Errorf("failed to marshal manifest: %w", err)
+	}
+
+	if err := os.WriteFile(tempPath, data, 0644); err != nil {
+		return fmt.Errorf("failed to write manifest: %w", err)
+	}
+
+	if err := os.Rename(tempPath, manifestPath); err != nil {
+		// Best effort: do not leave the stale temp file behind. The rename
+		// error is the one worth reporting, so the Remove result is dropped.
+		_ = os.Remove(tempPath)
+		return fmt.Errorf("failed to rename manifest: %w", err)
+ 
}
+
+	m.LastUpdate = time.Now()
+	return nil
+}
+
+// UpdateConfig appends a new manifest entry whose config is a copy of the
+// current one with fn applied, and makes that entry current.
+//
+// fn receives a private copy (made through a JSON round trip), so the
+// previous entry's config is left untouched. The files registered on the
+// current entry are copied onto the new entry so that GetFiles keeps
+// reporting them after the switch. Nothing is changed when fn is nil or when
+// the modified config fails Validate; the corresponding error is returned.
+func (m *Manifest) UpdateConfig(fn func(*Config)) error {
+	if fn == nil {
+		// Calling a nil func would panic while holding the lock.
+		return fmt.Errorf("%w: nil update function", ErrInvalidConfig)
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// Create a copy of the current config
+	currentJSON, err := json.Marshal(m.Current.Config)
+	if err != nil {
+		return fmt.Errorf("failed to marshal current config: %w", err)
+	}
+
+	var newConfig Config
+	if err := json.Unmarshal(currentJSON, &newConfig); err != nil {
+		return fmt.Errorf("failed to unmarshal config: %w", err)
+	}
+
+	// Apply the update function
+	fn(&newConfig)
+
+	// Validate the new config
+	if err := newConfig.Validate(); err != nil {
+		return err
+	}
+
+	// Create a new entry
+	entry := ManifestEntry{
+		Timestamp: time.Now().Unix(),
+		Version:   CurrentManifestVersion,
+		Config:    &newConfig,
+	}
+
+	// Carry the file registrations forward. Without this the new current
+	// entry starts with no FileSystem map and every tracked file disappears
+	// from GetFiles as soon as the config is updated. The map is copied so
+	// the older entry keeps its own snapshot.
+	if len(m.Current.FileSystem) > 0 {
+		entry.FileSystem = make(map[string]int64, len(m.Current.FileSystem))
+		for path, seq := range m.Current.FileSystem {
+			entry.FileSystem[path] = seq
+		}
+	}
+
+	m.Entries = append(m.Entries, entry)
+	m.Current = &m.Entries[len(m.Entries)-1]
+
+	return nil
+}
+
+// AddFile records path with sequence number seqNum on the current entry,
+// replacing any sequence number previously stored for the same path. The
+// returned error is always nil.
+func (m *Manifest) AddFile(path string, seqNum int64) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// The map is created lazily; writing to a nil map would panic.
+	if m.Current.FileSystem == nil {
+		m.Current.FileSystem = make(map[string]int64)
+	}
+
+	m.Current.FileSystem[path] = seqNum
+	return nil
+}
+
+// RemoveFile deletes path from the current entry's file map. Removing a path
+// that is not registered is a no-op. The returned error is always nil.
+func (m *Manifest) RemoveFile(path string) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.Current.FileSystem == nil {
+		return nil
+	}
+
+	delete(m.Current.FileSystem, path)
+	return nil
+}
+
+// GetConfig returns the config of the current entry. The pointer stored in
+// the entry is returned as is, not a copy, so changes made through it are
+// visible to the manifest.
+func (m *Manifest) GetConfig() *Config {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.Current.Config
+}
+
+// GetFiles returns all files registered in the manifest
+func (m *Manifest) GetFiles() map[string]int64 {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	if m.Current.FileSystem == nil {
+		return make(map[string]int64)
+	}
+
+	// Return a copy to prevent concurrent map access
+	files := make(map[string]int64, len(m.Current.FileSystem))
+	for k, v := range m.Current.FileSystem 
{ + files[k] = v + } + + return files +} \ No newline at end of file diff --git a/pkg/config/manifest_test.go b/pkg/config/manifest_test.go new file mode 100644 index 0000000..c06f850 --- /dev/null +++ b/pkg/config/manifest_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "os" + "testing" +) + +func TestNewManifest(t *testing.T) { + dbPath := "/tmp/testdb" + cfg := NewDefaultConfig(dbPath) + + manifest, err := NewManifest(dbPath, cfg) + if err != nil { + t.Fatalf("failed to create manifest: %v", err) + } + + if manifest.DBPath != dbPath { + t.Errorf("expected DBPath %s, got %s", dbPath, manifest.DBPath) + } + + if len(manifest.Entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(manifest.Entries)) + } + + if manifest.Current == nil { + t.Error("current entry is nil") + } else if manifest.Current.Config != cfg { + t.Error("current config does not match the provided config") + } +} + +func TestManifestUpdateConfig(t *testing.T) { + dbPath := "/tmp/testdb" + cfg := NewDefaultConfig(dbPath) + + manifest, err := NewManifest(dbPath, cfg) + if err != nil { + t.Fatalf("failed to create manifest: %v", err) + } + + // Update config + err = manifest.UpdateConfig(func(c *Config) { + c.MemTableSize = 64 * 1024 * 1024 // 64MB + c.MaxMemTables = 8 + }) + if err != nil { + t.Fatalf("failed to update config: %v", err) + } + + // Verify entries count + if len(manifest.Entries) != 2 { + t.Errorf("expected 2 entries, got %d", len(manifest.Entries)) + } + + // Verify updated config + current := manifest.GetConfig() + if current.MemTableSize != 64*1024*1024 { + t.Errorf("expected memtable size %d, got %d", 64*1024*1024, current.MemTableSize) + } + if current.MaxMemTables != 8 { + t.Errorf("expected max memtables %d, got %d", 8, current.MaxMemTables) + } +} + +func TestManifestFileTracking(t *testing.T) { + dbPath := "/tmp/testdb" + cfg := NewDefaultConfig(dbPath) + + manifest, err := NewManifest(dbPath, cfg) + if err != nil { + t.Fatalf("failed to create manifest: %v", err) 
+ } + + // Add files + err = manifest.AddFile("sst/000001.sst", 1) + if err != nil { + t.Fatalf("failed to add file: %v", err) + } + + err = manifest.AddFile("sst/000002.sst", 2) + if err != nil { + t.Fatalf("failed to add file: %v", err) + } + + // Verify files + files := manifest.GetFiles() + if len(files) != 2 { + t.Errorf("expected 2 files, got %d", len(files)) + } + + if files["sst/000001.sst"] != 1 { + t.Errorf("expected sequence number 1, got %d", files["sst/000001.sst"]) + } + + if files["sst/000002.sst"] != 2 { + t.Errorf("expected sequence number 2, got %d", files["sst/000002.sst"]) + } + + // Remove file + err = manifest.RemoveFile("sst/000001.sst") + if err != nil { + t.Fatalf("failed to remove file: %v", err) + } + + // Verify files after removal + files = manifest.GetFiles() + if len(files) != 1 { + t.Errorf("expected 1 file, got %d", len(files)) + } + + if _, exists := files["sst/000001.sst"]; exists { + t.Error("file should have been removed") + } +} + +func TestManifestSaveLoad(t *testing.T) { + // Create a temporary directory for the test + tempDir, err := os.MkdirTemp("", "manifest_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a manifest + cfg := NewDefaultConfig(tempDir) + manifest, err := NewManifest(tempDir, cfg) + if err != nil { + t.Fatalf("failed to create manifest: %v", err) + } + + // Update config + err = manifest.UpdateConfig(func(c *Config) { + c.MemTableSize = 64 * 1024 * 1024 // 64MB + }) + if err != nil { + t.Fatalf("failed to update config: %v", err) + } + + // Add some files + err = manifest.AddFile("sst/000001.sst", 1) + if err != nil { + t.Fatalf("failed to add file: %v", err) + } + + // Save the manifest + if err := manifest.Save(); err != nil { + t.Fatalf("failed to save manifest: %v", err) + } + + // Load the manifest + loadedManifest, err := LoadManifest(tempDir) + if err != nil { + t.Fatalf("failed to load manifest: %v", err) + } + + // Verify 
entries count + if len(loadedManifest.Entries) != len(manifest.Entries) { + t.Errorf("expected %d entries, got %d", len(manifest.Entries), len(loadedManifest.Entries)) + } + + // Verify config + loadedConfig := loadedManifest.GetConfig() + if loadedConfig.MemTableSize != 64*1024*1024 { + t.Errorf("expected memtable size %d, got %d", 64*1024*1024, loadedConfig.MemTableSize) + } + + // Verify files + loadedFiles := loadedManifest.GetFiles() + if len(loadedFiles) != 1 { + t.Errorf("expected 1 file, got %d", len(loadedFiles)) + } + + if loadedFiles["sst/000001.sst"] != 1 { + t.Errorf("expected sequence number 1, got %d", loadedFiles["sst/000001.sst"]) + } +} \ No newline at end of file diff --git a/pkg/wal/batch.go b/pkg/wal/batch.go new file mode 100644 index 0000000..dcd969b --- /dev/null +++ b/pkg/wal/batch.go @@ -0,0 +1,244 @@ +package wal + +import ( + "encoding/binary" + "errors" + "fmt" +) + +const ( + BatchHeaderSize = 12 // count(4) + seq(8) +) + +var ( + ErrEmptyBatch = errors.New("batch is empty") + ErrBatchTooLarge = errors.New("batch too large") +) + +// BatchOperation represents a single operation in a batch +type BatchOperation struct { + Type uint8 // OpTypePut, OpTypeDelete, etc. 
+ Key []byte + Value []byte +} + +// Batch represents a collection of operations to be performed atomically +type Batch struct { + Operations []BatchOperation + Seq uint64 // Base sequence number +} + +// NewBatch creates a new empty batch +func NewBatch() *Batch { + return &Batch{ + Operations: make([]BatchOperation, 0, 16), + } +} + +// Put adds a Put operation to the batch +func (b *Batch) Put(key, value []byte) { + b.Operations = append(b.Operations, BatchOperation{ + Type: OpTypePut, + Key: key, + Value: value, + }) +} + +// Delete adds a Delete operation to the batch +func (b *Batch) Delete(key []byte) { + b.Operations = append(b.Operations, BatchOperation{ + Type: OpTypeDelete, + Key: key, + }) +} + +// Count returns the number of operations in the batch +func (b *Batch) Count() int { + return len(b.Operations) +} + +// Reset clears all operations from the batch +func (b *Batch) Reset() { + b.Operations = b.Operations[:0] + b.Seq = 0 +} + +// Size estimates the size of the batch in the WAL +func (b *Batch) Size() int { + size := BatchHeaderSize // count + seq + + for _, op := range b.Operations { + // Type(1) + KeyLen(4) + Key + size += 1 + 4 + len(op.Key) + + // ValueLen(4) + Value for Put operations + if op.Type != OpTypeDelete { + size += 4 + len(op.Value) + } + } + + return size +} + +// Write writes the batch to the WAL +func (b *Batch) Write(w *WAL) error { + if len(b.Operations) == 0 { + return ErrEmptyBatch + } + + // Estimate batch size + size := b.Size() + if size > MaxRecordSize { + return fmt.Errorf("%w: %d > %d", ErrBatchTooLarge, size, MaxRecordSize) + } + + // Serialize batch + data := make([]byte, size) + offset := 0 + + // Write count + binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(b.Operations))) + offset += 4 + + // Write sequence base (will be set by WAL.AppendBatch) + offset += 8 + + // Write operations + for _, op := range b.Operations { + // Write type + data[offset] = op.Type + offset++ + + // Write key length + 
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(op.Key)))
+		offset += 4
+
+		// Write key
+		copy(data[offset:], op.Key)
+		offset += len(op.Key)
+
+		// Write value for non-delete operations
+		if op.Type != OpTypeDelete {
+			// Write value length
+			binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(op.Value)))
+			offset += 4
+
+			// Write value
+			copy(data[offset:], op.Value)
+			offset += len(op.Value)
+		}
+	}
+
+	// Append to WAL
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	if w.closed {
+		return ErrWALClosed
+	}
+
+	// Set the sequence number
+	b.Seq = w.nextSequence
+	binary.LittleEndian.PutUint64(data[4:12], b.Seq)
+
+	// Increment sequence for future operations
+	w.nextSequence += uint64(len(b.Operations))
+
+	// Write as a batch entry
+	if err := w.writeRecord(uint8(RecordTypeFull), OpTypeBatch, b.Seq, data, nil); err != nil {
+		return err
+	}
+
+	// Sync if needed
+	return w.maybeSync()
+}
+
+// DecodeBatch rebuilds a Batch from a WAL entry of type OpTypeBatch.
+//
+// The serialized batch is taken from entry.Key (Batch.Write passes it to
+// writeRecord as the key): a 4-byte little-endian operation count, an 8-byte
+// base sequence number, then one record per operation consisting of the type
+// byte, a length-prefixed key and, except for deletes, a length-prefixed
+// value. Keys and values are copied out of the entry.
+//
+// It returns an error wrapping ErrCorruptRecord when entry is nil or the data
+// is truncated or inconsistent, an error wrapping ErrInvalidOpType for an
+// unknown operation type, and a plain error when entry is not a batch entry.
+func DecodeBatch(entry *Entry) (*Batch, error) {
+	if entry == nil {
+		return nil, fmt.Errorf("%w: nil entry", ErrCorruptRecord)
+	}
+	if entry.Type != OpTypeBatch {
+		return nil, fmt.Errorf("not a batch entry: type %d", entry.Type)
+	}
+
+	// For batch entries, the batch data is in the Key field, not Value
+	data := entry.Key
+	if len(data) < BatchHeaderSize {
+		return nil, fmt.Errorf("%w: batch header too small", ErrCorruptRecord)
+	}
+
+	// Read count and sequence
+	count := binary.LittleEndian.Uint32(data[0:4])
+	seq := binary.LittleEndian.Uint64(data[4:12])
+
+	// The count comes straight from disk and is used as a slice capacity
+	// below, so a corrupt value could request a multi-gigabyte allocation.
+	// Every operation occupies at least type(1) + keyLen(4) bytes, which
+	// bounds how many operations the payload can actually hold.
+	const minOpSize = 5
+	if uint64(count) > uint64(len(data)-BatchHeaderSize)/minOpSize {
+		return nil, fmt.Errorf("%w: operation count %d exceeds batch data", ErrCorruptRecord, count)
+	}
+
+	batch := &Batch{
+		Operations: make([]BatchOperation, 0, count),
+		Seq:        seq,
+	}
+
+	offset := BatchHeaderSize
+
+	// Read operations
+	for i := uint32(0); i < count; i++ {
+		// Check if we have enough data for type
+		if offset >= len(data) {
+			return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord)
+		}
+
+		// Read type
+		opType := data[offset]
+		offset++
+
+		// Validate operation type
+		if opType != OpTypePut && opType != OpTypeDelete && opType != OpTypeMerge {
+			return nil, fmt.Errorf("%w: %d", 
ErrInvalidOpType, opType) + } + + // Check if we have enough data for key length + if offset+4 > len(data) { + return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord) + } + + // Read key length + keyLen := binary.LittleEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Validate key length + if offset+int(keyLen) > len(data) { + return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen) + } + + // Read key + key := make([]byte, keyLen) + copy(key, data[offset:offset+int(keyLen)]) + offset += int(keyLen) + + var value []byte + if opType != OpTypeDelete { + // Check if we have enough data for value length + if offset+4 > len(data) { + return nil, fmt.Errorf("%w: unexpected end of batch data", ErrCorruptRecord) + } + + // Read value length + valueLen := binary.LittleEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Validate value length + if offset+int(valueLen) > len(data) { + return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen) + } + + // Read value + value = make([]byte, valueLen) + copy(value, data[offset:offset+int(valueLen)]) + offset += int(valueLen) + } + + batch.Operations = append(batch.Operations, BatchOperation{ + Type: opType, + Key: key, + Value: value, + }) + } + + return batch, nil +} \ No newline at end of file diff --git a/pkg/wal/batch_test.go b/pkg/wal/batch_test.go new file mode 100644 index 0000000..5e9c440 --- /dev/null +++ b/pkg/wal/batch_test.go @@ -0,0 +1,187 @@ +package wal + +import ( + "bytes" + "fmt" + "os" + "testing" +) + +func TestBatchOperations(t *testing.T) { + batch := NewBatch() + + // Test initially empty + if batch.Count() != 0 { + t.Errorf("Expected empty batch, got count %d", batch.Count()) + } + + // Add operations + batch.Put([]byte("key1"), []byte("value1")) + batch.Put([]byte("key2"), []byte("value2")) + batch.Delete([]byte("key3")) + + // Check count + if batch.Count() != 3 { + t.Errorf("Expected batch with 3 operations, got %d", 
batch.Count()) + } + + // Check size calculation + expectedSize := BatchHeaderSize // count + seq + expectedSize += 1 + 4 + 4 + len("key1") + len("value1") // type + keylen + vallen + key + value + expectedSize += 1 + 4 + 4 + len("key2") + len("value2") // type + keylen + vallen + key + value + expectedSize += 1 + 4 + len("key3") // type + keylen + key (no value for delete) + + if batch.Size() != expectedSize { + t.Errorf("Expected batch size %d, got %d", expectedSize, batch.Size()) + } + + // Test reset + batch.Reset() + if batch.Count() != 0 { + t.Errorf("Expected empty batch after reset, got count %d", batch.Count()) + } +} + +func TestBatchEncoding(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create and write a batch + batch := NewBatch() + batch.Put([]byte("key1"), []byte("value1")) + batch.Put([]byte("key2"), []byte("value2")) + batch.Delete([]byte("key3")) + + if err := batch.Write(wal); err != nil { + t.Fatalf("Failed to write batch: %v", err) + } + + // Check sequence + if batch.Seq == 0 { + t.Errorf("Batch sequence number not set") + } + + // Close WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Replay and decode + var decodedBatch *Batch + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypeBatch { + var err error + decodedBatch, err = DecodeBatch(entry) + if err != nil { + return err + } + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + if decodedBatch == nil { + t.Fatal("No batch found in replay") + } + + // Verify decoded batch + if decodedBatch.Count() != 3 { + t.Errorf("Expected 3 operations, got %d", decodedBatch.Count()) + } + + if decodedBatch.Seq != batch.Seq { + t.Errorf("Expected sequence %d, got %d", batch.Seq, decodedBatch.Seq) + } + + // Verify operations + ops := 
decodedBatch.Operations + + if ops[0].Type != OpTypePut || !bytes.Equal(ops[0].Key, []byte("key1")) || !bytes.Equal(ops[0].Value, []byte("value1")) { + t.Errorf("First operation mismatch") + } + + if ops[1].Type != OpTypePut || !bytes.Equal(ops[1].Key, []byte("key2")) || !bytes.Equal(ops[1].Value, []byte("value2")) { + t.Errorf("Second operation mismatch") + } + + if ops[2].Type != OpTypeDelete || !bytes.Equal(ops[2].Key, []byte("key3")) { + t.Errorf("Third operation mismatch") + } +} + +func TestEmptyBatch(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create empty batch + batch := NewBatch() + + // Try to write empty batch + err = batch.Write(wal) + if err != ErrEmptyBatch { + t.Errorf("Expected ErrEmptyBatch, got: %v", err) + } + + // Close WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } +} + +func TestLargeBatch(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create a batch that will exceed the maximum record size + batch := NewBatch() + + // Add many large key-value pairs + largeValue := make([]byte, 4096) // 4KB + for i := 0; i < 20; i++ { + key := []byte(fmt.Sprintf("key%d", i)) + batch.Put(key, largeValue) + } + + // Verify the batch is too large + if batch.Size() <= MaxRecordSize { + t.Fatalf("Expected batch size > %d, got %d", MaxRecordSize, batch.Size()) + } + + // Try to write the large batch + err = batch.Write(wal) + if err == nil { + t.Error("Expected error when writing large batch") + } + + // Check that the error is ErrBatchTooLarge + if err != nil && !bytes.Contains([]byte(err.Error()), []byte("batch too large")) { + t.Errorf("Expected ErrBatchTooLarge, got: %v", err) + } + + // Close WAL + if err := 
wal.Close(); err != nil {
+		t.Fatalf("Failed to close WAL: %v", err)
+	}
+}
\ No newline at end of file
diff --git a/pkg/wal/reader.go b/pkg/wal/reader.go
new file mode 100644
index 0000000..7f9eb84
--- /dev/null
+++ b/pkg/wal/reader.go
@@ -0,0 +1,283 @@
+package wal
+
+import (
+	"bufio"
+	"encoding/binary"
+	"fmt"
+	"hash/crc32"
+	"io"
+	"os"
+	"path/filepath"
+	"sort"
+)
+
+// Reader reads logical entries from a single WAL file, reassembling entries
+// that were split across several physical records.
+type Reader struct {
+	file      *os.File      // underlying WAL file, closed by Close
+	reader    *bufio.Reader // buffered view of file used for all reads
+	buffer    []byte        // MaxRecordSize scratch space allocated by OpenReader
+	fragments [][]byte      // payloads of the fragmented entry being assembled
+	currType  uint8         // first payload byte of the latest first fragment
+}
+
+// OpenReader opens the WAL file at path for reading and returns a Reader
+// positioned at its start. The caller is responsible for calling Close.
+func OpenReader(path string) (*Reader, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open WAL file: %w", err)
+	}
+
+	return &Reader{
+		file:      file,
+		reader:    bufio.NewReaderSize(file, 64*1024), // 64KB buffer
+		buffer:    make([]byte, MaxRecordSize),
+		fragments: make([][]byte, 0),
+	}, nil
+}
+
+// ReadEntry returns the next logical entry from the WAL.
+//
+// Physical records are read until a complete entry is available: a full
+// record is parsed directly, while first/middle/last records are collected
+// and parsed once the last one arrives. It returns io.EOF when the file ends
+// cleanly between entries, a descriptive error when it ends while fragments
+// are still pending, and an error wrapping ErrCorruptRecord or
+// ErrInvalidRecordType for malformed input.
+func (r *Reader) ReadEntry() (*Entry, error) {
+	// Loop until we have a complete entry
+	for {
+		// Read a record
+		record, err := r.readRecord()
+		if err != nil {
+			if err == io.EOF {
+				// If we have fragments, this is unexpected EOF
+				if len(r.fragments) > 0 {
+					return nil, fmt.Errorf("unexpected EOF with %d fragments", len(r.fragments))
+				}
+				return nil, io.EOF
+			}
+			return nil, err
+		}
+
+		// Process based on record type
+		switch record.recordType {
+		case RecordTypeFull:
+			// Single record, parse directly
+			return r.parseEntryData(record.data)
+
+		case RecordTypeFirst:
+			// Start of a fragmented entry. The first payload byte is read
+			// below, so an empty payload has to be rejected here instead of
+			// panicking on data[0].
+			if len(record.data) == 0 {
+				return nil, fmt.Errorf("%w: empty first fragment", ErrCorruptRecord)
+			}
+			r.fragments = append(r.fragments, record.data)
+			r.currType = record.data[0] // Save the operation type
+
+		case RecordTypeMiddle:
+			// Middle fragment
+			if len(r.fragments) == 0 {
+				return nil, fmt.Errorf("%w: middle fragment without first fragment", ErrCorruptRecord)
+			}
+			r.fragments = append(r.fragments, record.data)
+
+		case RecordTypeLast:
+			// Last fragment
+			if len(r.fragments) == 0 
{ + return nil, fmt.Errorf("%w: last fragment without previous fragments", ErrCorruptRecord) + } + r.fragments = append(r.fragments, record.data) + + // Combine fragments into a single entry + entry, err := r.processFragments() + if err != nil { + return nil, err + } + return entry, nil + + default: + return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, record.recordType) + } + } +} + +// Record represents a physical record in the WAL +type record struct { + recordType uint8 + data []byte +} + +// readRecord reads a single physical record from the WAL +func (r *Reader) readRecord() (*record, error) { + // Read header + header := make([]byte, HeaderSize) + if _, err := io.ReadFull(r.reader, header); err != nil { + return nil, err + } + + // Parse header + crc := binary.LittleEndian.Uint32(header[0:4]) + length := binary.LittleEndian.Uint16(header[4:6]) + recordType := header[6] + + // Validate record type + if recordType < RecordTypeFull || recordType > RecordTypeLast { + return nil, fmt.Errorf("%w: %d", ErrInvalidRecordType, recordType) + } + + // Read payload + data := make([]byte, length) + if _, err := io.ReadFull(r.reader, data); err != nil { + return nil, err + } + + // Verify CRC + computedCRC := crc32.ChecksumIEEE(data) + if computedCRC != crc { + return nil, fmt.Errorf("%w: expected CRC %d, got %d", ErrCorruptRecord, crc, computedCRC) + } + + return &record{ + recordType: recordType, + data: data, + }, nil +} + +// processFragments combines fragments into a single entry +func (r *Reader) processFragments() (*Entry, error) { + // Determine total size + totalSize := 0 + for _, frag := range r.fragments { + totalSize += len(frag) + } + + // Combine fragments + combined := make([]byte, totalSize) + offset := 0 + for _, frag := range r.fragments { + copy(combined[offset:], frag) + offset += len(frag) + } + + // Reset fragments + r.fragments = r.fragments[:0] + + // Parse the combined data into an entry + return r.parseEntryData(combined) +} + +// 
parseEntryData parses the binary data into an Entry structure +func (r *Reader) parseEntryData(data []byte) (*Entry, error) { + if len(data) < 13 { // Minimum size: type(1) + seq(8) + keylen(4) + return nil, fmt.Errorf("%w: entry too small, %d bytes", ErrCorruptRecord, len(data)) + } + + offset := 0 + + // Read entry type + entryType := data[offset] + offset++ + + // Validate entry type + if entryType != OpTypePut && entryType != OpTypeDelete && entryType != OpTypeMerge && entryType != OpTypeBatch { + return nil, fmt.Errorf("%w: %d", ErrInvalidOpType, entryType) + } + + // Read sequence number + seqNum := binary.LittleEndian.Uint64(data[offset : offset+8]) + offset += 8 + + // Read key length + keyLen := binary.LittleEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Validate key length + if offset+int(keyLen) > len(data) { + return nil, fmt.Errorf("%w: invalid key length %d", ErrCorruptRecord, keyLen) + } + + // Read key + key := make([]byte, keyLen) + copy(key, data[offset:offset+int(keyLen)]) + offset += int(keyLen) + + // Read value if applicable + var value []byte + if entryType != OpTypeDelete { + // Check if there's enough data for value length + if offset+4 > len(data) { + return nil, fmt.Errorf("%w: missing value length", ErrCorruptRecord) + } + + // Read value length + valueLen := binary.LittleEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Validate value length + if offset+int(valueLen) > len(data) { + return nil, fmt.Errorf("%w: invalid value length %d", ErrCorruptRecord, valueLen) + } + + // Read value + value = make([]byte, valueLen) + copy(value, data[offset:offset+int(valueLen)]) + } + + return &Entry{ + SequenceNumber: seqNum, + Type: entryType, + Key: key, + Value: value, + }, nil +} + +// Close closes the reader +func (r *Reader) Close() error { + return r.file.Close() +} + +// EntryHandler is a function that processes WAL entries during replay +type EntryHandler func(*Entry) error + +// FindWALFiles returns a list of WAL 
files in the given directory +func FindWALFiles(dir string) ([]string, error) { + pattern := filepath.Join(dir, "*.wal") + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob WAL files: %w", err) + } + + // Sort by filename (which should be timestamp-based) + sort.Strings(matches) + return matches, nil +} + +// ReplayWALFile replays a single WAL file and calls the handler for each entry +func ReplayWALFile(path string, handler EntryHandler) error { + reader, err := OpenReader(path) + if err != nil { + return err + } + defer reader.Close() + + for { + entry, err := reader.ReadEntry() + if err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("error reading entry from %s: %w", path, err) + } + + if err := handler(entry); err != nil { + return fmt.Errorf("error handling entry: %w", err) + } + } + + return nil +} + +// ReplayWALDir replays all WAL files in the given directory in order +func ReplayWALDir(dir string, handler EntryHandler) error { + files, err := FindWALFiles(dir) + if err != nil { + return err + } + + for _, file := range files { + if err := ReplayWALFile(file, handler); err != nil { + return err + } + } + + return nil +} \ No newline at end of file diff --git a/pkg/wal/wal.go b/pkg/wal/wal.go new file mode 100644 index 0000000..e5189b3 --- /dev/null +++ b/pkg/wal/wal.go @@ -0,0 +1,392 @@ +package wal + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "os" + "path/filepath" + "sync" + "time" + + "git.canoozie.net/jer/go-storage/pkg/config" +) + +const ( + // Record types + RecordTypeFull = 1 + RecordTypeFirst = 2 + RecordTypeMiddle = 3 + RecordTypeLast = 4 + + // Operation types + OpTypePut = 1 + OpTypeDelete = 2 + OpTypeMerge = 3 + OpTypeBatch = 4 + + // Header layout + // - CRC (4 bytes) + // - Length (2 bytes) + // - Type (1 byte) + HeaderSize = 7 + + // Maximum size of a record payload + MaxRecordSize = 32 * 1024 // 32KB + + // Default WAL file size + 
DefaultWALFileSize = 64 * 1024 * 1024 // 64MB +) + +var ( + ErrCorruptRecord = errors.New("corrupt record") + ErrInvalidRecordType = errors.New("invalid record type") + ErrInvalidOpType = errors.New("invalid operation type") + ErrWALClosed = errors.New("WAL is closed") + ErrWALFull = errors.New("WAL file is full") +) + +// Entry represents a logical entry in the WAL +type Entry struct { + SequenceNumber uint64 + Type uint8 // OpTypePut, OpTypeDelete, etc. + Key []byte + Value []byte +} + +// WAL represents a write-ahead log +type WAL struct { + cfg *config.Config + dir string + file *os.File + writer *bufio.Writer + nextSequence uint64 + bytesWritten int64 + lastSync time.Time + batchByteSize int64 + closed bool + mu sync.Mutex +} + +// NewWAL creates a new write-ahead log +func NewWAL(cfg *config.Config, dir string) (*WAL, error) { + if cfg == nil { + return nil, errors.New("config cannot be nil") + } + + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create WAL directory: %w", err) + } + + // Create a new WAL file + filename := fmt.Sprintf("%020d.wal", time.Now().UnixNano()) + path := filepath.Join(dir, filename) + + file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + return nil, fmt.Errorf("failed to create WAL file: %w", err) + } + + wal := &WAL{ + cfg: cfg, + dir: dir, + file: file, + writer: bufio.NewWriterSize(file, 64*1024), // 64KB buffer + nextSequence: 1, + lastSync: time.Now(), + } + + return wal, nil +} + +// Append adds an entry to the WAL +func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.closed { + return 0, ErrWALClosed + } + + if entryType != OpTypePut && entryType != OpTypeDelete && entryType != OpTypeMerge { + return 0, ErrInvalidOpType + } + + // Sequence number for this entry + seqNum := w.nextSequence + w.nextSequence++ + + // Encode the entry + // Format: type(1) + seq(8) + keylen(4) + key + 
vallen(4) + val + entrySize := 1 + 8 + 4 + len(key) + if entryType != OpTypeDelete { + entrySize += 4 + len(value) + } + + // Check if we need to split the record + if entrySize <= MaxRecordSize { + // Single record case + recordType := uint8(RecordTypeFull) + if err := w.writeRecord(recordType, entryType, seqNum, key, value); err != nil { + return 0, err + } + } else { + // Split into multiple records + if err := w.writeFragmentedRecord(entryType, seqNum, key, value); err != nil { + return 0, err + } + } + + // Sync the file if needed + if err := w.maybeSync(); err != nil { + return 0, err + } + + return seqNum, nil +} + +// Write a single record +func (w *WAL) writeRecord(recordType uint8, entryType uint8, seqNum uint64, key, value []byte) error { + // Calculate the record size + payloadSize := 1 + 8 + 4 + len(key) // type + seq + keylen + key + if entryType != OpTypeDelete { + payloadSize += 4 + len(value) // vallen + value + } + + if payloadSize > MaxRecordSize { + return fmt.Errorf("record too large: %d > %d", payloadSize, MaxRecordSize) + } + + // Prepare the header + header := make([]byte, HeaderSize) + binary.LittleEndian.PutUint16(header[4:6], uint16(payloadSize)) + header[6] = recordType + + // Prepare the payload + payload := make([]byte, payloadSize) + offset := 0 + + // Write entry type + payload[offset] = entryType + offset++ + + // Write sequence number + binary.LittleEndian.PutUint64(payload[offset:offset+8], seqNum) + offset += 8 + + // Write key length and key + binary.LittleEndian.PutUint32(payload[offset:offset+4], uint32(len(key))) + offset += 4 + copy(payload[offset:], key) + offset += len(key) + + // Write value length and value (if applicable) + if entryType != OpTypeDelete { + binary.LittleEndian.PutUint32(payload[offset:offset+4], uint32(len(value))) + offset += 4 + copy(payload[offset:], value) + } + + // Calculate CRC + crc := crc32.ChecksumIEEE(payload) + binary.LittleEndian.PutUint32(header[0:4], crc) + + // Write the record + if _, 
err := w.writer.Write(header); err != nil { + return fmt.Errorf("failed to write record header: %w", err) + } + if _, err := w.writer.Write(payload); err != nil { + return fmt.Errorf("failed to write record payload: %w", err) + } + + // Update bytes written + w.bytesWritten += int64(HeaderSize + payloadSize) + w.batchByteSize += int64(HeaderSize + payloadSize) + + return nil +} + +// writeRawRecord writes a raw record with provided data as payload +func (w *WAL) writeRawRecord(recordType uint8, data []byte) error { + if len(data) > MaxRecordSize { + return fmt.Errorf("record too large: %d > %d", len(data), MaxRecordSize) + } + + // Prepare the header + header := make([]byte, HeaderSize) + binary.LittleEndian.PutUint16(header[4:6], uint16(len(data))) + header[6] = recordType + + // Calculate CRC + crc := crc32.ChecksumIEEE(data) + binary.LittleEndian.PutUint32(header[0:4], crc) + + // Write the record + if _, err := w.writer.Write(header); err != nil { + return fmt.Errorf("failed to write record header: %w", err) + } + if _, err := w.writer.Write(data); err != nil { + return fmt.Errorf("failed to write record payload: %w", err) + } + + // Update bytes written + w.bytesWritten += int64(HeaderSize + len(data)) + w.batchByteSize += int64(HeaderSize + len(data)) + + return nil +} + +// Write a fragmented record +func (w *WAL) writeFragmentedRecord(entryType uint8, seqNum uint64, key, value []byte) error { + // First fragment contains metadata: type, sequence, key length, and as much of the key as fits + headerSize := 1 + 8 + 4 // type + seq + keylen + + // Calculate how much of the key can fit in the first fragment + maxKeyInFirst := MaxRecordSize - headerSize + keyInFirst := min(len(key), maxKeyInFirst) + + // Create the first fragment + firstFragment := make([]byte, headerSize + keyInFirst) + offset := 0 + + // Add metadata to first fragment + firstFragment[offset] = entryType + offset++ + + binary.LittleEndian.PutUint64(firstFragment[offset:offset+8], seqNum) + 
offset += 8 + + binary.LittleEndian.PutUint32(firstFragment[offset:offset+4], uint32(len(key))) + offset += 4 + + // Add as much of the key as fits + copy(firstFragment[offset:], key[:keyInFirst]) + + // Write the first fragment + if err := w.writeRawRecord(uint8(RecordTypeFirst), firstFragment); err != nil { + return err + } + + // Prepare the remaining data + var remaining []byte + + // Add any remaining key bytes + if keyInFirst < len(key) { + remaining = append(remaining, key[keyInFirst:]...) + } + + // Add value data if this isn't a delete operation + if entryType != OpTypeDelete { + // Add value length + valueLenBuf := make([]byte, 4) + binary.LittleEndian.PutUint32(valueLenBuf, uint32(len(value))) + remaining = append(remaining, valueLenBuf...) + + // Add value + remaining = append(remaining, value...) + } + + // Write middle fragments (all full-sized except possibly the last) + for len(remaining) > MaxRecordSize { + chunk := remaining[:MaxRecordSize] + remaining = remaining[MaxRecordSize:] + + if err := w.writeRawRecord(uint8(RecordTypeMiddle), chunk); err != nil { + return err + } + } + + // Write the last fragment if there's any remaining data + if len(remaining) > 0 { + if err := w.writeRawRecord(uint8(RecordTypeLast), remaining); err != nil { + return err + } + } + + return nil +} + +// maybeSync syncs the WAL file if needed based on configuration +func (w *WAL) maybeSync() error { + needSync := false + + switch w.cfg.WALSyncMode { + case config.SyncImmediate: + needSync = true + case config.SyncBatch: + // Sync if we've written enough bytes + if w.batchByteSize >= w.cfg.WALSyncBytes { + needSync = true + } + case config.SyncNone: + // No syncing + } + + if needSync { + // Use syncLocked since we're already holding the mutex + if err := w.syncLocked(); err != nil { + return err + } + } + + return nil +} + +// syncLocked performs the sync operation assuming the mutex is already held +func (w *WAL) syncLocked() error { + if w.closed { + return 
ErrWALClosed + } + + if err := w.writer.Flush(); err != nil { + return fmt.Errorf("failed to flush WAL buffer: %w", err) + } + + if err := w.file.Sync(); err != nil { + return fmt.Errorf("failed to sync WAL file: %w", err) + } + + w.lastSync = time.Now() + w.batchByteSize = 0 + + return nil +} + +// Sync flushes all buffered data to disk +func (w *WAL) Sync() error { + w.mu.Lock() + defer w.mu.Unlock() + + return w.syncLocked() +} + +// Close closes the WAL +func (w *WAL) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.closed { + return nil + } + + // Use syncLocked to flush and sync + if err := w.syncLocked(); err != nil { + return err + } + + if err := w.file.Close(); err != nil { + return fmt.Errorf("failed to close WAL file: %w", err) + } + + w.closed = true + return nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} \ No newline at end of file diff --git a/pkg/wal/wal_test.go b/pkg/wal/wal_test.go new file mode 100644 index 0000000..964434b --- /dev/null +++ b/pkg/wal/wal_test.go @@ -0,0 +1,590 @@ +package wal + +import ( + "bytes" + "fmt" + "math/rand" + "os" + "path/filepath" + "testing" + + "git.canoozie.net/jer/go-storage/pkg/config" +) + +func createTestConfig() *config.Config { + return config.NewDefaultConfig("/tmp/gostorage_test") +} + +func createTempDir(t *testing.T) string { + dir, err := os.MkdirTemp("", "wal_test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + return dir +} + +func TestWALWrite(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Write some entries + keys := []string{"key1", "key2", "key3"} + values := []string{"value1", "value2", "value3"} + + for i, key := range keys { + seq, err := wal.Append(OpTypePut, []byte(key), []byte(values[i])) + if err != nil { + t.Fatalf("Failed to append entry: %v", err) + } + + if 
seq != uint64(i+1) { + t.Errorf("Expected sequence %d, got %d", i+1, seq) + } + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify entries by replaying + entries := make(map[string]string) + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut { + entries[string(entry.Key)] = string(entry.Value) + } else if entry.Type == OpTypeDelete { + delete(entries, string(entry.Key)) + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + // Verify all entries are present + for i, key := range keys { + value, ok := entries[key] + if !ok { + t.Errorf("Entry for key %q not found", key) + continue + } + + if value != values[i] { + t.Errorf("Expected value %q for key %q, got %q", values[i], key, value) + } + } +} + +func TestWALDelete(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Write and delete + key := []byte("key1") + value := []byte("value1") + + _, err = wal.Append(OpTypePut, key, value) + if err != nil { + t.Fatalf("Failed to append put entry: %v", err) + } + + _, err = wal.Append(OpTypeDelete, key, nil) + if err != nil { + t.Fatalf("Failed to append delete entry: %v", err) + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify entries by replaying + var deleted bool + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut && bytes.Equal(entry.Key, key) { + if deleted { + deleted = false // Key was re-added + } + } else if entry.Type == OpTypeDelete && bytes.Equal(entry.Key, key) { + deleted = true + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + if !deleted { + t.Errorf("Expected key to be deleted") + } +} + +func TestWALLargeEntry(t *testing.T) 
{ + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create a large key and value (but not too large for a single record) + key := make([]byte, 8*1024) // 8KB + value := make([]byte, 16*1024) // 16KB + + for i := range key { + key[i] = byte(i % 256) + } + + for i := range value { + value[i] = byte((i * 2) % 256) + } + + // Append the large entry + _, err = wal.Append(OpTypePut, key, value) + if err != nil { + t.Fatalf("Failed to append large entry: %v", err) + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify by replaying + var foundLargeEntry bool + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut && len(entry.Key) == len(key) && len(entry.Value) == len(value) { + // Verify key + for i := range key { + if key[i] != entry.Key[i] { + t.Errorf("Key mismatch at position %d: expected %d, got %d", i, key[i], entry.Key[i]) + return nil + } + } + + // Verify value + for i := range value { + if value[i] != entry.Value[i] { + t.Errorf("Value mismatch at position %d: expected %d, got %d", i, value[i], entry.Value[i]) + return nil + } + } + + foundLargeEntry = true + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + if !foundLargeEntry { + t.Error("Large entry not found in replay") + } +} + +func TestWALBatch(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create a batch + batch := NewBatch() + + keys := []string{"batch1", "batch2", "batch3"} + values := []string{"value1", "value2", "value3"} + + for i, key := range keys { + batch.Put([]byte(key), []byte(values[i])) + } + + // Add a delete operation + batch.Delete([]byte("batch2")) + + // Write the 
batch + if err := batch.Write(wal); err != nil { + t.Fatalf("Failed to write batch: %v", err) + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify by replaying + entries := make(map[string]string) + batchCount := 0 + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypeBatch { + batchCount++ + + // Decode batch + batch, err := DecodeBatch(entry) + if err != nil { + t.Errorf("Failed to decode batch: %v", err) + return nil + } + + // Apply batch operations + for _, op := range batch.Operations { + if op.Type == OpTypePut { + entries[string(op.Key)] = string(op.Value) + } else if op.Type == OpTypeDelete { + delete(entries, string(op.Key)) + } + } + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + // Verify batch was replayed + if batchCount != 1 { + t.Errorf("Expected 1 batch, got %d", batchCount) + } + + // Verify entries + expectedEntries := map[string]string{ + "batch1": "value1", + "batch3": "value3", + // batch2 should be deleted + } + + for key, expectedValue := range expectedEntries { + value, ok := entries[key] + if !ok { + t.Errorf("Entry for key %q not found", key) + continue + } + + if value != expectedValue { + t.Errorf("Expected value %q for key %q, got %q", expectedValue, key, value) + } + } + + // Verify batch2 is deleted + if _, ok := entries["batch2"]; ok { + t.Errorf("Key batch2 should be deleted") + } +} + +func TestWALRecovery(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + + // Write some entries in the first WAL + wal1, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + _, err = wal1.Append(OpTypePut, []byte("key1"), []byte("value1")) + if err != nil { + t.Fatalf("Failed to append entry: %v", err) + } + + if err := wal1.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Create a second WAL 
file + wal2, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + _, err = wal2.Append(OpTypePut, []byte("key2"), []byte("value2")) + if err != nil { + t.Fatalf("Failed to append entry: %v", err) + } + + if err := wal2.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify entries by replaying all WAL files in order + entries := make(map[string]string) + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut { + entries[string(entry.Key)] = string(entry.Value) + } else if entry.Type == OpTypeDelete { + delete(entries, string(entry.Key)) + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + // Verify all entries are present + expected := map[string]string{ + "key1": "value1", + "key2": "value2", + } + + for key, expectedValue := range expected { + value, ok := entries[key] + if !ok { + t.Errorf("Entry for key %q not found", key) + continue + } + + if value != expectedValue { + t.Errorf("Expected value %q for key %q, got %q", expectedValue, key, value) + } + } +} + +func TestWALSyncModes(t *testing.T) { + testCases := []struct { + name string + syncMode config.SyncMode + }{ + {"SyncNone", config.SyncNone}, + {"SyncBatch", config.SyncBatch}, + {"SyncImmediate", config.SyncImmediate}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + // Create config with specific sync mode + cfg := createTestConfig() + cfg.WALSyncMode = tc.syncMode + + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Write some entries + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key%d", i)) + value := []byte(fmt.Sprintf("value%d", i)) + + _, err := wal.Append(OpTypePut, key, value) + if err != nil { + t.Fatalf("Failed to append entry: %v", err) + } + } + + // Close the WAL + if err := wal.Close(); err != nil { + 
t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify entries by replaying + count := 0 + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut { + count++ + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + if count != 10 { + t.Errorf("Expected 10 entries, got %d", count) + } + }) + } +} + +func TestWALFragmentation(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Create an entry that's guaranteed to be fragmented + // Header size is 1 + 8 + 4 = 13 bytes, so allocate more than MaxRecordSize - 13 for the key + keySize := MaxRecordSize - 10 + valueSize := MaxRecordSize * 2 + + key := make([]byte, keySize) // Just under MaxRecordSize to ensure key fragmentation + value := make([]byte, valueSize) // Large value to ensure value fragmentation + + // Fill with recognizable patterns + for i := range key { + key[i] = byte(i % 256) + } + + for i := range value { + value[i] = byte((i * 3) % 256) + } + + // Append the large entry - this should trigger fragmentation + _, err = wal.Append(OpTypePut, key, value) + if err != nil { + t.Fatalf("Failed to append fragmented entry: %v", err) + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Verify by replaying + var reconstructedKey []byte + var reconstructedValue []byte + var foundPut bool + + err = ReplayWALDir(dir, func(entry *Entry) error { + if entry.Type == OpTypePut { + foundPut = true + reconstructedKey = entry.Key + reconstructedValue = entry.Value + } + return nil + }) + + if err != nil { + t.Fatalf("Failed to replay WAL: %v", err) + } + + // Check that we found the entry + if !foundPut { + t.Fatal("Did not find PUT entry in replay") + } + + // Verify key length matches + if len(reconstructedKey) != keySize { + t.Errorf("Key length 
mismatch: expected %d, got %d", keySize, len(reconstructedKey)) + } + + // Verify value length matches + if len(reconstructedValue) != valueSize { + t.Errorf("Value length mismatch: expected %d, got %d", valueSize, len(reconstructedValue)) + } + + // Check key content (first 10 bytes) + for i := 0; i < 10 && i < len(key); i++ { + if key[i] != reconstructedKey[i] { + t.Errorf("Key mismatch at position %d: expected %d, got %d", i, key[i], reconstructedKey[i]) + } + } + + // Check key content (last 10 bytes) + for i := 0; i < 10 && i < len(key); i++ { + idx := len(key) - 1 - i + if key[idx] != reconstructedKey[idx] { + t.Errorf("Key mismatch at position %d: expected %d, got %d", idx, key[idx], reconstructedKey[idx]) + } + } + + // Check value content (first 10 bytes) + for i := 0; i < 10 && i < len(value); i++ { + if value[i] != reconstructedValue[i] { + t.Errorf("Value mismatch at position %d: expected %d, got %d", i, value[i], reconstructedValue[i]) + } + } + + // Check value content (last 10 bytes) + for i := 0; i < 10 && i < len(value); i++ { + idx := len(value) - 1 - i + if value[idx] != reconstructedValue[idx] { + t.Errorf("Value mismatch at position %d: expected %d, got %d", idx, value[idx], reconstructedValue[idx]) + } + } + + // Verify random samples from the key and value + for i := 0; i < 10; i++ { + // Check random positions in the key + keyPos := rand.Intn(keySize) + if key[keyPos] != reconstructedKey[keyPos] { + t.Errorf("Key mismatch at random position %d: expected %d, got %d", keyPos, key[keyPos], reconstructedKey[keyPos]) + } + + // Check random positions in the value + valuePos := rand.Intn(valueSize) + if value[valuePos] != reconstructedValue[valuePos] { + t.Errorf("Value mismatch at random position %d: expected %d, got %d", valuePos, value[valuePos], reconstructedValue[valuePos]) + } + } +} + +func TestWALErrorHandling(t *testing.T) { + dir := createTempDir(t) + defer os.RemoveAll(dir) + + cfg := createTestConfig() + wal, err := NewWAL(cfg, dir) + 
if err != nil { + t.Fatalf("Failed to create WAL: %v", err) + } + + // Write some entries + _, err = wal.Append(OpTypePut, []byte("key1"), []byte("value1")) + if err != nil { + t.Fatalf("Failed to append entry: %v", err) + } + + // Close the WAL + if err := wal.Close(); err != nil { + t.Fatalf("Failed to close WAL: %v", err) + } + + // Try to write after close + _, err = wal.Append(OpTypePut, []byte("key2"), []byte("value2")) + if err != ErrWALClosed { + t.Errorf("Expected ErrWALClosed, got: %v", err) + } + + // Try to sync after close + err = wal.Sync() + if err != ErrWALClosed { + t.Errorf("Expected ErrWALClosed, got: %v", err) + } + + // Try to replay a non-existent file + nonExistentPath := filepath.Join(dir, "nonexistent.wal") + err = ReplayWALFile(nonExistentPath, func(entry *Entry) error { + return nil + }) + + if err == nil { + t.Error("Expected error when replaying non-existent file") + } +} \ No newline at end of file