Compare commits


2 Commits

Author SHA1 Message Date
e433b12930
feat: add a standard logger, and start on a replication manager to tie into wal hooks
All checks were successful
Go Tests / Run Tests (1.24.2) (pull_request) Successful in 9m37s
2025-04-20 21:05:49 -06:00
e1ea512864
feat: add a lamport clock
2025-04-20 20:20:42 -06:00
10 changed files with 1531 additions and 9 deletions

5
go.mod

@@ -7,4 +7,7 @@ require (
github.com/chzyer/readline v1.5.1
)

-require golang.org/x/sys v0.1.0 // indirect
+require (
github.com/google/uuid v1.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
)

2
go.sum

@@ -6,6 +6,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

180
pkg/common/clock/lamport.go Normal file

@@ -0,0 +1,180 @@
// Package clock provides logical clock implementations for distributed systems.
package clock
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"sync"
"time"
)
// NodeID represents a unique identifier for a node in the distributed system.
type NodeID [16]byte
// String returns a human-readable string representation of a NodeID.
func (id NodeID) String() string {
return hex.EncodeToString(id[:])
}
// Equal checks if two NodeIDs are equal.
func (id NodeID) Equal(other NodeID) bool {
return bytes.Equal(id[:], other[:])
}
// LamportClock implements a Lamport logical clock for event ordering in distributed systems.
// It maintains a monotonically increasing counter that is incremented on each local event.
// When receiving messages from other nodes, the counter is updated to the maximum of its
// current value and the received value, plus one.
type LamportClock struct {
counter uint64 // The logical clock counter
nodeID NodeID // This node's unique identifier
mu sync.RWMutex // Mutex to protect concurrent access
}
// Timestamp represents a Lamport timestamp with both the logical counter
// and the Node ID to break ties when logical counters are equal.
type Timestamp struct {
Counter uint64 // Logical counter value
Node NodeID // NodeID used to break ties
}
// NewLamportClock creates a new Lamport clock with the given NodeID.
func NewLamportClock(nodeID NodeID) *LamportClock {
return &LamportClock{
counter: 0,
nodeID: nodeID,
}
}
// Tick increments the local counter and returns a new timestamp.
// This should be called before generating any local event like writing to the WAL.
func (lc *LamportClock) Tick() Timestamp {
lc.mu.Lock()
defer lc.mu.Unlock()
lc.counter++
return Timestamp{
Counter: lc.counter,
Node: lc.nodeID,
}
}
// Update compares the provided timestamp with the local counter and
// updates the local counter to be at least as large as the received one.
// This should be called when processing events from other nodes.
func (lc *LamportClock) Update(ts Timestamp) Timestamp {
lc.mu.Lock()
defer lc.mu.Unlock()
if ts.Counter > lc.counter {
lc.counter = ts.Counter
}
// Increment the counter regardless to ensure causality
lc.counter++
return Timestamp{
Counter: lc.counter,
Node: lc.nodeID,
}
}
// GetCurrent returns the current timestamp without incrementing the counter.
func (lc *LamportClock) GetCurrent() Timestamp {
lc.mu.RLock()
defer lc.mu.RUnlock()
return Timestamp{
Counter: lc.counter,
Node: lc.nodeID,
}
}
// ManualSet sets the counter to a specific value if it's greater than the current value.
// This is useful for bootstrap or recovery scenarios.
func (lc *LamportClock) ManualSet(counter uint64) {
lc.mu.Lock()
defer lc.mu.Unlock()
if counter > lc.counter {
lc.counter = counter
}
}
// Compare compares two timestamps according to the Lamport ordering.
// Returns:
// -1 if ts1 < ts2
// 0 if ts1 = ts2
// +1 if ts1 > ts2
func Compare(ts1, ts2 Timestamp) int {
if ts1.Counter < ts2.Counter {
return -1
}
if ts1.Counter > ts2.Counter {
return 1
}
// Break ties using node ID comparison
return bytes.Compare(ts1.Node[:], ts2.Node[:])
}
// Less returns true if ts1 is less than ts2 according to Lamport ordering.
func Less(ts1, ts2 Timestamp) bool {
return Compare(ts1, ts2) < 0
}
// Equal returns true if the two timestamps are equal.
func Equal(ts1, ts2 Timestamp) bool {
return Compare(ts1, ts2) == 0
}
// String returns a human-readable representation of the timestamp.
func (ts Timestamp) String() string {
return fmt.Sprintf("%d@%s", ts.Counter, ts.Node.String()[:8])
}
// Bytes serializes the timestamp to a byte array for network transmission.
func (ts Timestamp) Bytes() []byte {
buf := make([]byte, 8+16) // 8 bytes for counter + 16 bytes for NodeID
binary.BigEndian.PutUint64(buf[:8], ts.Counter)
copy(buf[8:], ts.Node[:])
return buf
}
// TimestampFromBytes deserializes a timestamp from a byte array.
func TimestampFromBytes(data []byte) (Timestamp, error) {
if len(data) < 24 {
return Timestamp{}, fmt.Errorf("invalid timestamp data: expected 24 bytes, got %d", len(data))
}
ts := Timestamp{
Counter: binary.BigEndian.Uint64(data[:8]),
}
copy(ts.Node[:], data[8:24])
return ts, nil
}
// WithPhysicalTimestamp creates a hybrid logical clock timestamp that incorporates
// both the logical counter and a physical timestamp for enhanced ordering.
func WithPhysicalTimestamp(ts Timestamp) HybridTimestamp {
return HybridTimestamp{
Logical: ts,
Physical: time.Now().UnixNano(),
}
}
// HybridTimestamp combines a logical Lamport timestamp with a physical clock timestamp.
// This provides both the causal ordering benefits of Lamport clocks and the
// real-time approximation of physical clocks.
type HybridTimestamp struct {
Logical Timestamp
Physical int64 // nanoseconds since Unix epoch
}
// String returns a human-readable representation of the hybrid timestamp.
func (hts HybridTimestamp) String() string {
return fmt.Sprintf("%s@%d", hts.Logical.String(), hts.Physical)
}
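Read together, Tick, Update, and Compare give a total order over events across nodes. A minimal usage sketch, relying only on the API defined above (the hard-coded NodeID values are illustrative):

package main

import (
	"fmt"

	"github.com/jeremytregunna/kevo/pkg/common/clock"
)

func main() {
	// Two nodes with distinct, illustrative IDs.
	a := clock.NewLamportClock(clock.NodeID{1})
	b := clock.NewLamportClock(clock.NodeID{2})

	// Node A generates a few local events.
	var sent clock.Timestamp
	for i := 0; i < 3; i++ {
		sent = a.Tick()
	}

	// Node B merges A's timestamp; its next event is ordered after A's.
	got := b.Update(sent)
	fmt.Println(sent, got, clock.Less(sent, got)) // e.g. 3@01000000 4@02000000 true

	// Timestamps round-trip through the 24-byte wire format.
	decoded, _ := clock.TimestampFromBytes(got.Bytes())
	fmt.Println(clock.Equal(got, decoded)) // true
}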

pkg/common/clock/lamport_test.go Normal file

@@ -0,0 +1,155 @@
package clock
import (
"testing"
)
func TestLamportClock_Tick(t *testing.T) {
nodeID := NodeID{1, 2, 3}
clock := NewLamportClock(nodeID)
// Initial tick should return 1
ts1 := clock.Tick()
if ts1.Counter != 1 {
t.Errorf("Expected counter to be 1, got %d", ts1.Counter)
}
// Next tick should return 2
ts2 := clock.Tick()
if ts2.Counter != 2 {
t.Errorf("Expected counter to be 2, got %d", ts2.Counter)
}
// NodeID should be preserved
if !ts1.Node.Equal(nodeID) {
t.Errorf("NodeID not preserved in timestamp")
}
}
func TestLamportClock_Update(t *testing.T) {
nodeID1 := NodeID{1, 2, 3}
nodeID2 := NodeID{4, 5, 6}
clock1 := NewLamportClock(nodeID1)
clock2 := NewLamportClock(nodeID2)
// Advance clock1 to 5
for i := 0; i < 5; i++ {
clock1.Tick()
}
// Get current timestamp from clock1
ts1 := clock1.GetCurrent()
if ts1.Counter != 5 {
t.Errorf("Expected counter to be 5, got %d", ts1.Counter)
}
// Update clock2 with timestamp from clock1
ts2 := clock2.Update(ts1)
// Clock2 should now be at least as large as clock1's timestamp + 1
if ts2.Counter <= ts1.Counter {
t.Errorf("Expected clock2 counter > %d, got %d", ts1.Counter, ts2.Counter)
}
// Clock2's NodeID should still be nodeID2
if !ts2.Node.Equal(nodeID2) {
t.Errorf("NodeID not preserved after update")
}
}
func TestLamportClock_Ordering(t *testing.T) {
nodeID1 := NodeID{1, 2, 3}
nodeID2 := NodeID{4, 5, 6}
// Create timestamps with the same counter but different NodeIDs
ts1 := Timestamp{Counter: 5, Node: nodeID1}
ts2 := Timestamp{Counter: 5, Node: nodeID2}
// Compare should break ties using NodeID
if Compare(ts1, ts2) >= 0 {
t.Errorf("Expected ts1 < ts2 when counters are equal")
}
// Create timestamps with different counters
ts3 := Timestamp{Counter: 6, Node: nodeID1}
ts4 := Timestamp{Counter: 5, Node: nodeID2}
// Compare should prioritize counter values
if Compare(ts3, ts4) <= 0 {
t.Errorf("Expected ts3 > ts4 when ts3.Counter > ts4.Counter")
}
}
func TestLamportClock_Serialization(t *testing.T) {
nodeID := NodeID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
ts := Timestamp{Counter: 42, Node: nodeID}
// Serialize
bytes := ts.Bytes()
// Deserialize
ts2, err := TimestampFromBytes(bytes)
if err != nil {
t.Fatalf("Failed to deserialize timestamp: %v", err)
}
// Compare
if ts.Counter != ts2.Counter {
t.Errorf("Counter not preserved in serialization, expected %d got %d", ts.Counter, ts2.Counter)
}
if !ts.Node.Equal(ts2.Node) {
t.Errorf("NodeID not preserved in serialization")
}
}
func TestLamportClock_Concurrent(t *testing.T) {
nodeID := NodeID{1, 2, 3}
clock := NewLamportClock(nodeID)
// Run 100 concurrent ticks
done := make(chan struct{})
for i := 0; i < 100; i++ {
go func() {
clock.Tick()
done <- struct{}{}
}()
}
// Wait for all ticks to complete
for i := 0; i < 100; i++ {
<-done
}
// Counter should be 100
ts := clock.GetCurrent()
if ts.Counter != 100 {
t.Errorf("Expected counter to be 100 after 100 concurrent ticks, got %d", ts.Counter)
}
}
func TestLamportClock_ManualSet(t *testing.T) {
nodeID := NodeID{1, 2, 3}
clock := NewLamportClock(nodeID)
// Set to a high value
clock.ManualSet(1000)
ts := clock.GetCurrent()
if ts.Counter != 1000 {
t.Errorf("Expected counter to be 1000 after ManualSet, got %d", ts.Counter)
}
// Setting to a lower value should have no effect
clock.ManualSet(500)
ts = clock.GetCurrent()
if ts.Counter != 1000 {
t.Errorf("Expected counter to remain 1000 after lower ManualSet, got %d", ts.Counter)
}
// Tick should still increment
ts = clock.Tick()
if ts.Counter != 1001 {
t.Errorf("Expected counter to be 1001 after tick, got %d", ts.Counter)
}
}

267
pkg/common/log/logger.go Normal file

@@ -0,0 +1,267 @@
// Package log provides a common logging interface for Kevo components.
package log
import (
"fmt"
"io"
"os"
"sync"
"time"
)
// Level represents the logging level
type Level int
const (
// LevelDebug level for detailed troubleshooting information
LevelDebug Level = iota
// LevelInfo level for general operational information
LevelInfo
// LevelWarn level for potentially harmful situations
LevelWarn
// LevelError level for error events that might still allow the application to continue
LevelError
// LevelFatal level for severe error events that will lead the application to abort
LevelFatal
)
// String returns the string representation of the log level
func (l Level) String() string {
switch l {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
case LevelFatal:
return "FATAL"
default:
return fmt.Sprintf("LEVEL(%d)", l)
}
}
// Logger interface defines the methods for logging at different levels
type Logger interface {
// Debug logs a debug-level message
Debug(msg string, args ...interface{})
// Info logs an info-level message
Info(msg string, args ...interface{})
// Warn logs a warning-level message
Warn(msg string, args ...interface{})
// Error logs an error-level message
Error(msg string, args ...interface{})
// Fatal logs a fatal-level message and then calls os.Exit(1)
Fatal(msg string, args ...interface{})
// WithFields returns a new logger with the given fields added to the context
WithFields(fields map[string]interface{}) Logger
// WithField returns a new logger with the given field added to the context
WithField(key string, value interface{}) Logger
// GetLevel returns the current logging level
GetLevel() Level
// SetLevel sets the logging level
SetLevel(level Level)
}
// StandardLogger implements the Logger interface with a standard output format
type StandardLogger struct {
mu sync.Mutex
level Level
out io.Writer
fields map[string]interface{}
}
// NewStandardLogger creates a new StandardLogger with the given options
func NewStandardLogger(options ...LoggerOption) *StandardLogger {
logger := &StandardLogger{
level: LevelInfo, // Default level
out: os.Stdout,
fields: make(map[string]interface{}),
}
// Apply options
for _, option := range options {
option(logger)
}
return logger
}
// LoggerOption is a function that configures a StandardLogger
type LoggerOption func(*StandardLogger)
// WithLevel sets the logging level
func WithLevel(level Level) LoggerOption {
return func(l *StandardLogger) {
l.level = level
}
}
// WithOutput sets the output writer
func WithOutput(out io.Writer) LoggerOption {
return func(l *StandardLogger) {
l.out = out
}
}
// WithInitialFields sets initial fields for the logger
func WithInitialFields(fields map[string]interface{}) LoggerOption {
return func(l *StandardLogger) {
for k, v := range fields {
l.fields[k] = v
}
}
}
// log logs a message at the specified level
func (l *StandardLogger) log(level Level, msg string, args ...interface{}) {
if level < l.level {
return
}
l.mu.Lock()
defer l.mu.Unlock()
// Format the message
formattedMsg := msg
if len(args) > 0 {
formattedMsg = fmt.Sprintf(msg, args...)
}
// Format timestamp
timestamp := time.Now().Format("2006-01-02 15:04:05.000")
// Format fields
fieldsStr := ""
if len(l.fields) > 0 {
for k, v := range l.fields {
fieldsStr += fmt.Sprintf(" %s=%v", k, v)
}
}
// Write the log entry
fmt.Fprintf(l.out, "[%s] [%s]%s %s\n", timestamp, level.String(), fieldsStr, formattedMsg)
// Exit if fatal
if level == LevelFatal {
os.Exit(1)
}
}
// Debug logs a debug-level message
func (l *StandardLogger) Debug(msg string, args ...interface{}) {
l.log(LevelDebug, msg, args...)
}
// Info logs an info-level message
func (l *StandardLogger) Info(msg string, args ...interface{}) {
l.log(LevelInfo, msg, args...)
}
// Warn logs a warning-level message
func (l *StandardLogger) Warn(msg string, args ...interface{}) {
l.log(LevelWarn, msg, args...)
}
// Error logs an error-level message
func (l *StandardLogger) Error(msg string, args ...interface{}) {
l.log(LevelError, msg, args...)
}
// Fatal logs a fatal-level message and then calls os.Exit(1)
func (l *StandardLogger) Fatal(msg string, args ...interface{}) {
l.log(LevelFatal, msg, args...)
}
// WithFields returns a new logger with the given fields added to the context
func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger {
newLogger := &StandardLogger{
level: l.level,
out: l.out,
fields: make(map[string]interface{}, len(l.fields)+len(fields)),
}
// Copy existing fields
for k, v := range l.fields {
newLogger.fields[k] = v
}
// Add new fields
for k, v := range fields {
newLogger.fields[k] = v
}
return newLogger
}
// WithField returns a new logger with the given field added to the context
func (l *StandardLogger) WithField(key string, value interface{}) Logger {
return l.WithFields(map[string]interface{}{key: value})
}
// GetLevel returns the current logging level
func (l *StandardLogger) GetLevel() Level {
return l.level
}
// SetLevel sets the logging level
func (l *StandardLogger) SetLevel(level Level) {
l.level = level
}
// Default logger instance
var defaultLogger = NewStandardLogger()
// SetDefaultLogger sets the default logger instance
func SetDefaultLogger(logger *StandardLogger) {
defaultLogger = logger
}
// GetDefaultLogger returns the default logger instance
func GetDefaultLogger() *StandardLogger {
return defaultLogger
}
// These functions use the default logger
// Debug logs a debug-level message to the default logger
func Debug(msg string, args ...interface{}) {
defaultLogger.Debug(msg, args...)
}
// Info logs an info-level message to the default logger
func Info(msg string, args ...interface{}) {
defaultLogger.Info(msg, args...)
}
// Warn logs a warning-level message to the default logger
func Warn(msg string, args ...interface{}) {
defaultLogger.Warn(msg, args...)
}
// Error logs an error-level message to the default logger
func Error(msg string, args ...interface{}) {
defaultLogger.Error(msg, args...)
}
// Fatal logs a fatal-level message to the default logger and then calls os.Exit(1)
func Fatal(msg string, args ...interface{}) {
defaultLogger.Fatal(msg, args...)
}
// WithFields returns a new logger with the given fields added to the context
func WithFields(fields map[string]interface{}) Logger {
return defaultLogger.WithFields(fields)
}
// WithField returns a new logger with the given field added to the context
func WithField(key string, value interface{}) Logger {
return defaultLogger.WithField(key, value)
}
// SetLevel sets the logging level of the default logger
func SetLevel(level Level) {
defaultLogger.SetLevel(level)
}
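To make the API above concrete, here is a small usage sketch: it builds a logger against stderr, attaches structured fields, and installs it as the process-wide default. Only constructors and options defined in this package are used; the field values are illustrative.

package main

import (
	"os"

	"github.com/jeremytregunna/kevo/pkg/common/log"
)

func main() {
	// A debug-level logger writing to stderr.
	logger := log.NewStandardLogger(
		log.WithOutput(os.Stderr),
		log.WithLevel(log.LevelDebug),
	)

	// WithFields returns a child logger that stamps every entry.
	walLog := logger.WithFields(map[string]interface{}{
		"component": "wal",
	})
	walLog.Debug("opened segment %d", 42)

	// Install as the default so the package-level helpers route here.
	log.SetDefaultLogger(logger)
	log.Info("engine started")
}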

pkg/common/log/logger_test.go Normal file

@@ -0,0 +1,132 @@
package log
import (
"bytes"
"strings"
"testing"
)
func TestStandardLogger(t *testing.T) {
// Create a buffer to capture output
var buf bytes.Buffer
// Create a logger with the buffer as output
logger := NewStandardLogger(
WithOutput(&buf),
WithLevel(LevelDebug),
)
// Test debug level
logger.Debug("This is a debug message")
if !strings.Contains(buf.String(), "[DEBUG]") || !strings.Contains(buf.String(), "This is a debug message") {
t.Errorf("Debug logging failed, got: %s", buf.String())
}
buf.Reset()
// Test info level
logger.Info("This is an info message")
if !strings.Contains(buf.String(), "[INFO]") || !strings.Contains(buf.String(), "This is an info message") {
t.Errorf("Info logging failed, got: %s", buf.String())
}
buf.Reset()
// Test warn level
logger.Warn("This is a warning message")
if !strings.Contains(buf.String(), "[WARN]") || !strings.Contains(buf.String(), "This is a warning message") {
t.Errorf("Warn logging failed, got: %s", buf.String())
}
buf.Reset()
// Test error level
logger.Error("This is an error message")
if !strings.Contains(buf.String(), "[ERROR]") || !strings.Contains(buf.String(), "This is an error message") {
t.Errorf("Error logging failed, got: %s", buf.String())
}
buf.Reset()
// Test with fields
loggerWithFields := logger.WithFields(map[string]interface{}{
"component": "test",
"count": 123,
})
loggerWithFields.Info("Message with fields")
output := buf.String()
if !strings.Contains(output, "[INFO]") ||
!strings.Contains(output, "Message with fields") ||
!strings.Contains(output, "component=test") ||
!strings.Contains(output, "count=123") {
t.Errorf("Logging with fields failed, got: %s", output)
}
buf.Reset()
// Test with a single field
loggerWithField := logger.WithField("module", "logger")
loggerWithField.Info("Message with a field")
output = buf.String()
if !strings.Contains(output, "[INFO]") ||
!strings.Contains(output, "Message with a field") ||
!strings.Contains(output, "module=logger") {
t.Errorf("Logging with a field failed, got: %s", output)
}
buf.Reset()
// Test level filtering
logger.SetLevel(LevelError)
logger.Debug("This debug message should not appear")
logger.Info("This info message should not appear")
logger.Warn("This warning message should not appear")
logger.Error("This error message should appear")
output = buf.String()
if strings.Contains(output, "should not appear") ||
!strings.Contains(output, "This error message should appear") {
t.Errorf("Level filtering failed, got: %s", output)
}
buf.Reset()
// Test formatted messages
logger.SetLevel(LevelInfo)
logger.Info("Formatted %s with %d params", "message", 2)
if !strings.Contains(buf.String(), "Formatted message with 2 params") {
t.Errorf("Formatted message failed, got: %s", buf.String())
}
buf.Reset()
// Test GetLevel
if logger.GetLevel() != LevelInfo {
t.Errorf("GetLevel failed, expected LevelInfo, got: %v", logger.GetLevel())
}
}
func TestDefaultLogger(t *testing.T) {
// Save original default logger
originalLogger := defaultLogger
defer func() {
defaultLogger = originalLogger
}()
// Create a buffer to capture output
var buf bytes.Buffer
// Set a new default logger
SetDefaultLogger(NewStandardLogger(
WithOutput(&buf),
WithLevel(LevelInfo),
))
// Test global functions
Info("Global info message")
if !strings.Contains(buf.String(), "[INFO]") || !strings.Contains(buf.String(), "Global info message") {
t.Errorf("Global info logging failed, got: %s", buf.String())
}
buf.Reset()
// Test global with fields
WithField("global", true).Info("Global with field")
output := buf.String()
if !strings.Contains(output, "[INFO]") ||
!strings.Contains(output, "Global with field") ||
!strings.Contains(output, "global=true") {
t.Errorf("Global logging with field failed, got: %s", output)
}
buf.Reset()
}

pkg/engine/engine.go

@@ -10,6 +10,7 @@ import (
"sync/atomic"
"time"

"github.com/jeremytregunna/kevo/pkg/common/clock"
"github.com/jeremytregunna/kevo/pkg/common/iterator"
"github.com/jeremytregunna/kevo/pkg/compaction"
"github.com/jeremytregunna/kevo/pkg/config"

@@ -99,6 +100,11 @@ type Engine struct {
mu sync.RWMutex // Main lock for engine state
flushMu sync.Mutex // Lock for flushing operations
txLock sync.RWMutex // Lock for transaction isolation

// Replication fields
nodeID clock.NodeID // Unique identifier for this database instance
lamportClock *clock.LamportClock // Logical clock for event ordering
replicationMgr *ReplicationManager // Manager for replication operations
}

// NewEngine creates a new storage engine

@@ -141,6 +147,15 @@ func NewEngine(dataDir string) (*Engine, error) {
defer func() { wal.DisableRecoveryLogs = tempWasDisabled }()
}

// Load or create the node ID and initialize a Lamport clock
nodeID, err := loadOrCreateNodeID(dataDir)
if err != nil {
return nil, fmt.Errorf("failed to initialize node ID: %w", err)
}

// Create a Lamport clock for this instance
lamportClock := clock.NewLamportClock(nodeID)

// First try to reuse an existing WAL file
var walLogger *wal.WAL

@@ -150,12 +165,17 @@ func NewEngine(dataDir string) (*Engine, error) {
return nil, fmt.Errorf("failed to check for reusable WAL: %w", err)
}

-// If no suitable WAL found, create a new one
+// If no suitable WAL found, create a new one with our logical clock
if walLogger == nil {
-walLogger, err = wal.NewWAL(cfg, walDir)
+walLogger, err = wal.NewWALWithClock(cfg, walDir, lamportClock)
if err != nil {
return nil, fmt.Errorf("failed to create WAL: %w", err)
}
} else {
// If we're reusing a WAL, set the logical clock on it
if err := walLogger.SetLogicalClock(lamportClock); err != nil {
return nil, fmt.Errorf("failed to set logical clock on reused WAL: %w", err)
}
}

// Create the MemTable pool

@@ -172,6 +192,10 @@ func NewEngine(dataDir string) (*Engine, error) {
sstables: make([]*sstable.Reader, 0),
bgFlushCh: make(chan struct{}, 1),
nextFileNum: 1,

// Set replication fields
nodeID: nodeID,
lamportClock: lamportClock,
}

// Load existing SSTables

@@ -184,6 +208,11 @@ func NewEngine(dataDir string) (*Engine, error) {
return nil, fmt.Errorf("failed to recover from WAL: %w", err)
}

// Initialize replication
if err := e.initializeReplication(); err != nil {
return nil, fmt.Errorf("failed to initialize replication: %w", err)
}

// Start background flush goroutine
go e.backgroundFlush()

@@ -573,8 +602,8 @@ func (e *Engine) rotateWAL() error {
return fmt.Errorf("failed to close WAL: %w", err)
}

-// Create a new WAL
-wal, err := wal.NewWAL(e.cfg, e.walDir)
+// Create a new WAL with the existing logical clock
+wal, err := wal.NewWALWithClock(e.cfg, e.walDir, e.lamportClock)
if err != nil {
return fmt.Errorf("failed to create new WAL: %w", err)
}

@@ -702,8 +731,8 @@ func (e *Engine) recoverFromWAL() error {
}
}

-// Create a fresh WAL
-newWal, err := wal.NewWAL(e.cfg, e.walDir)
+// Create a fresh WAL with the existing logical clock
+newWal, err := wal.NewWALWithClock(e.cfg, e.walDir, e.lamportClock)
if err != nil {
return fmt.Errorf("failed to create new WAL after recovery: %w", err)
}

@@ -951,6 +980,13 @@ func (e *Engine) Close() error {
return fmt.Errorf("failed to shutdown compaction: %w", err)
}

// Shutdown replication manager if it exists
if e.replicationMgr != nil {
if err := e.replicationMgr.Close(); err != nil {
return fmt.Errorf("failed to shutdown replication manager: %w", err)
}
}

// Close WAL first
if err := e.wal.Close(); err != nil {
return fmt.Errorf("failed to close WAL: %w", err)

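The net effect of these engine changes: opening a database now assigns (or reloads) a node ID and threads a Lamport clock through every WAL write. A sketch of what a caller observes, using the exported accessors added in the replication file below and the Put/Get calls exercised by the tests in this change (the data directory is illustrative):

package main

import (
	"fmt"
	"log"

	"github.com/jeremytregunna/kevo/pkg/engine"
)

func main() {
	eng, err := engine.NewEngine("/tmp/kevo-demo") // illustrative data directory
	if err != nil {
		log.Fatal(err)
	}
	defer eng.Close()

	// The node ID is persisted under the data directory, so it is stable across restarts.
	fmt.Println("node:", eng.GetNodeID())

	// Each write ticks the engine's Lamport clock via the WAL.
	if err := eng.Put([]byte("greeting"), []byte("hello")); err != nil {
		log.Fatal(err)
	}
	fmt.Println("clock:", eng.GetLamportClock().GetCurrent())
}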
346
pkg/engine/replication.go Normal file

@@ -0,0 +1,346 @@
package engine
import (
"fmt"
"os"
"path/filepath"
"sync"
"github.com/google/uuid"
"github.com/jeremytregunna/kevo/pkg/common/clock"
"github.com/jeremytregunna/kevo/pkg/common/log"
"github.com/jeremytregunna/kevo/pkg/wal"
)
// NodeIDFile is the name of the file that stores the database's node ID
const NodeIDFile = "NODE_ID"
// loadNodeID attempts to load the node ID from the file system
// Returns the NodeID if successful or an error if the file doesn't exist or is invalid
func loadNodeID(dbDir string) (clock.NodeID, error) {
idPath := filepath.Join(dbDir, NodeIDFile)
data, err := os.ReadFile(idPath)
if err != nil {
return clock.NodeID{}, err // Return the original error for better diagnostics
}
// Validate length
if len(data) != 16 {
return clock.NodeID{}, fmt.Errorf("invalid node ID file format (expected 16 bytes, got %d)", len(data))
}
// Convert to NodeID
var nodeID clock.NodeID
copy(nodeID[:], data)
return nodeID, nil
}
// createNodeID generates a new node ID and persists it
func createNodeID(dbDir string) (clock.NodeID, error) {
var nodeID clock.NodeID
// Generate a UUID v4
id, err := uuid.NewRandom()
if err != nil {
return clock.NodeID{}, fmt.Errorf("failed to generate UUID: %w", err)
}
// Copy the UUID bytes to our NodeID
copy(nodeID[:], id[:])
// Ensure directory exists
if err := os.MkdirAll(dbDir, 0755); err != nil {
return clock.NodeID{}, fmt.Errorf("failed to create database directory: %w", err)
}
// Write the ID to the file
idPath := filepath.Join(dbDir, NodeIDFile)
if err := os.WriteFile(idPath, nodeID[:], 0644); err != nil {
return clock.NodeID{}, fmt.Errorf("failed to write node ID file: %w", err)
}
return nodeID, nil
}
// loadOrCreateNodeID first tries to load an existing node ID, and if that fails,
// creates a new one
func loadOrCreateNodeID(dbDir string) (clock.NodeID, error) {
nodeID, err := loadNodeID(dbDir)
if err == nil {
return nodeID, nil
}
if os.IsNotExist(err) {
// File doesn't exist, create new ID
return createNodeID(dbDir)
}
// Some other error occurred while loading
return clock.NodeID{}, fmt.Errorf("failed to access node ID file: %w", err)
}
// initializeReplication sets up replication components for the engine
// This should be called during engine initialization
func (e *Engine) initializeReplication() error {
// Load or create node ID
nodeID, err := loadOrCreateNodeID(e.dataDir)
if err != nil {
return fmt.Errorf("failed to initialize node ID: %w", err)
}
// Create Lamport clock with this node ID
e.lamportClock = clock.NewLamportClock(nodeID)
e.nodeID = nodeID
// Set the clock on the WAL
if err := e.wal.SetLogicalClock(e.lamportClock); err != nil {
return fmt.Errorf("failed to set logical clock on WAL: %w", err)
}
// Create the replication manager
e.replicationMgr = NewReplicationManager(e)
// Set the replication hook on the WAL
e.wal.SetReplicationHook(e.replicationMgr)
return nil
}
// GetNodeID returns the database instance's stable node ID
func (e *Engine) GetNodeID() clock.NodeID {
return e.nodeID
}
// GetLamportClock returns the engine's Lamport clock
func (e *Engine) GetLamportClock() *clock.LamportClock {
return e.lamportClock
}
// ReplicationManager handles replication operations for the engine
type ReplicationManager struct {
engine *Engine
mu sync.RWMutex
isLeader bool
isReplica bool
replicaIDs map[string]clock.NodeID // Maps replica IDs to their NodeIDs
replicaPositions map[string]clock.Timestamp // Tracks the latest position for each replica
logChan chan *ReplicationLogEntry // Channel for log entries to be processed
stopChan chan struct{} // Channel to signal stopping
logger log.Logger // Logger interface for replication events
}
// ReplicationLogEntry represents a WAL entry that needs to be replicated
type ReplicationLogEntry struct {
entry *wal.Entry // The WAL entry
timestamp clock.Timestamp // The Lamport timestamp
batch bool // Whether this is part of a batch
batchSize int // Size of batch if part of batch
}
// NewReplicationManager creates a new replication manager for the engine
func NewReplicationManager(engine *Engine) *ReplicationManager {
logger := log.GetDefaultLogger().WithField("component", "replication")
rm := &ReplicationManager{
engine: engine,
replicaIDs: make(map[string]clock.NodeID),
replicaPositions: make(map[string]clock.Timestamp),
logChan: make(chan *ReplicationLogEntry, 1000), // Buffer for 1000 entries
stopChan: make(chan struct{}),
logger: logger,
}
// Start the replication processor goroutine
go rm.processReplicationEntries()
logger.Info("Replication manager initialized")
return rm
}
// processReplicationEntries handles replication entries in the background
func (rm *ReplicationManager) processReplicationEntries() {
for {
select {
case entry := <-rm.logChan:
rm.handleLogEntry(entry)
case <-rm.stopChan:
rm.logger.Info("Stopping replication log processor")
return
}
}
}
// handleLogEntry processes a replication log entry
func (rm *ReplicationManager) handleLogEntry(entry *ReplicationLogEntry) {
// Skip processing if we're a replica
if rm.isReplica {
return
}
// TODO: Implement actual replication to remote nodes
// For now, we'll just log the event
if entry.batch {
rm.logger.Debug("Processing batch entry for replication: timestamp=%s, batch_size=%d",
entry.timestamp.String(), entry.batchSize)
} else {
rm.logger.Debug("Processing single entry for replication: type=%d, timestamp=%s, key_size=%d",
entry.entry.Type, entry.timestamp.String(), len(entry.entry.Key))
}
}
// SetLeader sets this node as the leader for replication
func (rm *ReplicationManager) SetLeader(isLeader bool) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.isLeader = isLeader
rm.isReplica = !isLeader
if isLeader {
rm.logger.Info("Node set as replication leader")
} else {
rm.logger.Info("Node set as replica")
}
}
// IsLeader returns whether this node is the leader
func (rm *ReplicationManager) IsLeader() bool {
rm.mu.RLock()
defer rm.mu.RUnlock()
return rm.isLeader
}
// IsReplica returns whether this node is a replica
func (rm *ReplicationManager) IsReplica() bool {
rm.mu.RLock()
defer rm.mu.RUnlock()
return rm.isReplica
}
// AddReplica adds a replica node to the replication system
func (rm *ReplicationManager) AddReplica(id string, nodeID clock.NodeID) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.replicaIDs[id] = nodeID
rm.logger.Info("Added replica: %s with NodeID: %s", id, nodeID.String())
}
// RemoveReplica removes a replica node from the replication system
func (rm *ReplicationManager) RemoveReplica(id string) {
rm.mu.Lock()
defer rm.mu.Unlock()
delete(rm.replicaIDs, id)
delete(rm.replicaPositions, id)
rm.logger.Info("Removed replica: %s", id)
}
// GetReplicaIDs returns the list of replica IDs
func (rm *ReplicationManager) GetReplicaIDs() []string {
rm.mu.RLock()
defer rm.mu.RUnlock()
ids := make([]string, 0, len(rm.replicaIDs))
for id := range rm.replicaIDs {
ids = append(ids, id)
}
return ids
}
// GetReplicaPosition returns the current replication position for a replica
func (rm *ReplicationManager) GetReplicaPosition(replicaID string) (clock.Timestamp, bool) {
rm.mu.RLock()
defer rm.mu.RUnlock()
pos, exists := rm.replicaPositions[replicaID]
return pos, exists
}
// UpdateReplicaPosition updates the position for a specific replica
func (rm *ReplicationManager) UpdateReplicaPosition(replicaID string, position clock.Timestamp) {
rm.mu.Lock()
defer rm.mu.Unlock()
current, exists := rm.replicaPositions[replicaID]
if !exists || clock.Compare(position, current) > 0 {
rm.replicaPositions[replicaID] = position
rm.logger.Debug("Updated replica %s position to %s", replicaID, position.String())
}
}
// OnEntryWritten implements the wal.ReplicationHook interface
// It's called when a single entry is written to the WAL
func (rm *ReplicationManager) OnEntryWritten(entry *wal.Entry, timestamp clock.Timestamp) error {
// Skip processing if we're a replica to avoid replication loops
if rm.isReplica {
return nil
}
// Queue the entry for processing
select {
case rm.logChan <- &ReplicationLogEntry{
entry: entry,
timestamp: timestamp,
batch: false,
}:
// Successfully queued
default:
// Channel is full; log the error but don't block the write path
rm.logger.Error("Replication queue is full, dropping entry")
}
return nil
}
// OnBatchWritten implements the wal.ReplicationHook interface
// It's called when a batch of entries is written to the WAL
func (rm *ReplicationManager) OnBatchWritten(entries []*wal.Entry, startTimestamp clock.Timestamp) error {
// Skip processing if we're a replica to avoid replication loops
if rm.isReplica {
return nil
}
// Process each entry in the batch
for i, entry := range entries {
// Calculate timestamp for this entry based on batch start
entryTimestamp := clock.Timestamp{
Counter: startTimestamp.Counter + uint64(i),
Node: startTimestamp.Node,
}
// Queue the entry for processing
select {
case rm.logChan <- &ReplicationLogEntry{
entry: entry,
timestamp: entryTimestamp,
batch: true,
batchSize: len(entries),
}:
// Successfully queued
default:
// Channel is full; log the error but don't block the write path
rm.logger.Error("Replication queue is full, dropping batch entry %d/%d",
i+1, len(entries))
}
}
return nil
}
// Close shuts down the replication manager
func (rm *ReplicationManager) Close() error {
close(rm.stopChan)
rm.logger.Info("Replication manager shut down")
return nil
}
// SetReplicationHook sets a custom replication hook on the WAL
// This is mainly used for testing or when you need to override the default hook
func (e *Engine) SetReplicationHook(hook interface{}) error {
walHook, ok := hook.(wal.ReplicationHook)
if !ok {
return fmt.Errorf("invalid replication hook type: %T", hook)
}
e.wal.SetReplicationHook(walHook)
return nil
}
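While the ReplicationManager is wired in automatically, SetReplicationHook accepts any implementation of wal.ReplicationHook, which is handy for tests or for shipping entries to an external consumer. A sketch with a hypothetical counting hook (the type and its field names are not part of this change):

package main

import (
	"log"
	"sync/atomic"

	"github.com/jeremytregunna/kevo/pkg/common/clock"
	"github.com/jeremytregunna/kevo/pkg/engine"
	"github.com/jeremytregunna/kevo/pkg/wal"
)

// countingHook satisfies wal.ReplicationHook and simply tallies what it sees.
type countingHook struct {
	entries atomic.Int64
}

func (h *countingHook) OnEntryWritten(e *wal.Entry, ts clock.Timestamp) error {
	h.entries.Add(1)
	return nil
}

func (h *countingHook) OnBatchWritten(es []*wal.Entry, start clock.Timestamp) error {
	h.entries.Add(int64(len(es)))
	return nil
}

func main() {
	eng, err := engine.NewEngine("/tmp/kevo-hook-demo") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer eng.Close()

	hook := &countingHook{}
	if err := eng.SetReplicationHook(hook); err != nil {
		log.Fatal(err)
	}

	_ = eng.Put([]byte("k"), []byte("v"))

	// The WAL invokes hooks from a goroutine after the write is durable,
	// so the count may lag the Put by a moment.
	log.Printf("entries observed so far: %d", hook.entries.Load())
}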

pkg/engine/replication_test.go Normal file

@@ -0,0 +1,251 @@
package engine
import (
"os"
"testing"
"github.com/jeremytregunna/kevo/pkg/common/clock"
"github.com/jeremytregunna/kevo/pkg/common/log"
"github.com/jeremytregunna/kevo/pkg/wal"
)
// TestReplicationHooks tests that the replication hooks are properly called
func TestReplicationHooks(t *testing.T) {
// Set log level to avoid noise in tests
log.SetLevel(log.LevelError)
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "replication-test")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create a new engine
engine, err := NewEngine(tempDir)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
// Make sure replication manager was created
if engine.replicationMgr == nil {
t.Fatal("Replication manager was not created")
}
// Verify NodeID was assigned
if engine.nodeID == (clock.NodeID{}) {
t.Fatal("NodeID was not assigned")
}
// Verify Lamport clock was created
if engine.lamportClock == nil {
t.Fatal("Lamport clock was not created")
}
// Test adding and removing replicas
replicaID := "test-replica"
fakeNodeID := clock.NodeID{}
engine.replicationMgr.AddReplica(replicaID, fakeNodeID)
replicas := engine.replicationMgr.GetReplicaIDs()
if len(replicas) != 1 || replicas[0] != replicaID {
t.Fatalf("Expected replica ID %s, got %v", replicaID, replicas)
}
engine.replicationMgr.RemoveReplica(replicaID)
replicas = engine.replicationMgr.GetReplicaIDs()
if len(replicas) != 0 {
t.Fatalf("Expected no replicas, got %v", replicas)
}
// Test setting leader/replica status
if engine.replicationMgr.IsLeader() {
t.Fatal("Replication manager should not be leader by default")
}
engine.replicationMgr.SetLeader(true)
if !engine.replicationMgr.IsLeader() {
t.Fatal("Replication manager should be leader")
}
if engine.replicationMgr.IsReplica() {
t.Fatal("Replication manager should not be replica when it's a leader")
}
engine.replicationMgr.SetLeader(false)
if !engine.replicationMgr.IsReplica() {
t.Fatal("Replication manager should be replica when it's not a leader")
}
// Close the engine
if err := engine.Close(); err != nil {
t.Fatalf("Failed to close engine: %v", err)
}
}
// TestReplicationIntegration tests that the engine works with replication hooks
func TestReplicationIntegration(t *testing.T) {
// Set log level to avoid noise in tests
log.SetLevel(log.LevelError)
// Create a temporary directory for testing
tempDir, err := os.MkdirTemp("", "replication-integration-test")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create a new engine
engine, err := NewEngine(tempDir)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
// Verify the replication manager is initialized
if engine.replicationMgr == nil {
t.Fatal("Replication manager was not initialized")
}
// Set as leader to ensure replication hooks are called
engine.replicationMgr.SetLeader(true)
// Perform some operations
testKey := []byte("test-key")
testValue := []byte("test-value")
// Put operation
if err := engine.Put(testKey, testValue); err != nil {
t.Fatalf("Failed to put: %v", err)
}
// Get operation
value, err := engine.Get(testKey)
if err != nil {
t.Fatalf("Failed to get: %v", err)
}
// Verify value
if string(value) != string(testValue) {
t.Fatalf("Expected value %q, got %q", testValue, value)
}
// Close the engine
if err := engine.Close(); err != nil {
t.Fatalf("Failed to close engine: %v", err)
}
}
// TestReplicationInterface verifies that the replication hook interface works correctly
func TestReplicationInterface(t *testing.T) {
// Set log level to avoid noise in tests
log.SetLevel(log.LevelError)
// Create a test hook
callbackCounts := struct {
singleEntries int
batchEntries int
}{}
testHook := &syncTestReplicationHook{
onEntryCallback: func(entry *wal.Entry, ts clock.Timestamp) error {
callbackCounts.singleEntries++
return nil
},
onBatchCallback: func(entries []*wal.Entry, ts clock.Timestamp) error {
callbackCounts.batchEntries += len(entries)
return nil
},
}
// Create a test entry
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
// Create a test batch
batch := []*wal.Entry{
{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
},
{
SequenceNumber: 3,
Type: wal.OpTypePut,
Key: []byte("key3"),
Value: []byte("value3"),
},
}
// Create a timestamp
nodeID := clock.NodeID{}
ts := clock.Timestamp{
Counter: 1,
Node: nodeID,
}
// Call the hook methods directly
if err := testHook.OnEntryWritten(entry, ts); err != nil {
t.Fatalf("OnEntryWritten failed: %v", err)
}
if err := testHook.OnBatchWritten(batch, ts); err != nil {
t.Fatalf("OnBatchWritten failed: %v", err)
}
// Verify callbacks were called
if callbackCounts.singleEntries != 1 {
t.Errorf("Expected 1 single entry callback, got %d", callbackCounts.singleEntries)
}
if callbackCounts.batchEntries != 2 {
t.Errorf("Expected 2 batch entry callbacks, got %d", callbackCounts.batchEntries)
}
}
// Test helper: a synchronous mock replication hook for testing
type syncTestReplicationHook struct {
onEntryCallback func(*wal.Entry, clock.Timestamp) error
onBatchCallback func([]*wal.Entry, clock.Timestamp) error
}
func (h *syncTestReplicationHook) OnEntryWritten(entry *wal.Entry, ts clock.Timestamp) error {
if h.onEntryCallback != nil {
return h.onEntryCallback(entry, ts)
}
return nil
}
func (h *syncTestReplicationHook) OnBatchWritten(entries []*wal.Entry, ts clock.Timestamp) error {
if h.onBatchCallback != nil {
return h.onBatchCallback(entries, ts)
}
return nil
}
// Test helper: an asynchronous mock replication hook for testing
type testReplicationHook struct {
onEntryCallback func(*wal.Entry, clock.Timestamp) error
onBatchCallback func([]*wal.Entry, clock.Timestamp) error
}
func (h *testReplicationHook) OnEntryWritten(entry *wal.Entry, ts clock.Timestamp) error {
if h.onEntryCallback != nil {
return h.onEntryCallback(entry, ts)
}
return nil
}
func (h *testReplicationHook) OnBatchWritten(entries []*wal.Entry, ts clock.Timestamp) error {
if h.onBatchCallback != nil {
return h.onBatchCallback(entries, ts)
}
return nil
}

pkg/wal/wal.go

@@ -11,6 +11,7 @@ import (
"sync"
"time"

"github.com/jeremytregunna/kevo/pkg/common/clock"
"github.com/jeremytregunna/kevo/pkg/config"
)

@@ -54,6 +55,22 @@ type Entry struct {
Type uint8 // OpTypePut, OpTypeDelete, etc.
Key []byte
Value []byte
// Lamport timestamp for replication ordering
Timestamp clock.Timestamp
}
// GetTimestamp returns the Lamport timestamp of this entry
func (e *Entry) GetTimestamp() clock.Timestamp {
return e.Timestamp
}
// ReplicationHook defines the interface for replication hooks
type ReplicationHook interface {
// OnEntryWritten is called when a single entry is written to the WAL
OnEntryWritten(entry *Entry, timestamp clock.Timestamp) error
// OnBatchWritten is called when a batch of entries is written to the WAL
OnBatchWritten(entries []*Entry, startTimestamp clock.Timestamp) error
}

// Global variable to control whether to print recovery logs

@@ -71,9 +88,54 @@ type WAL struct {
batchByteSize int64
closed bool
mu sync.Mutex
// Replication fields
logicalClock *clock.LamportClock // Logical clock for replication ordering
replicationHook ReplicationHook // Hook for replication events
replicationMu sync.Mutex // Separate mutex for replication operations
}

-// NewWAL creates a new write-ahead log
+// NewWALWithClock creates a new WAL with the provided logical clock
// This is the recommended constructor for systems using replication
func NewWALWithClock(cfg *config.Config, dir string, logicalClock *clock.LamportClock) (*WAL, error) {
if cfg == nil {
return nil, errors.New("config cannot be nil")
}
if logicalClock == nil {
return nil, errors.New("logicalClock cannot be nil")
}
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create WAL directory: %w", err)
}
// Create a new WAL file
filename := fmt.Sprintf("%020d.wal", time.Now().UnixNano())
path := filepath.Join(dir, filename)
file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
return nil, fmt.Errorf("failed to create WAL file: %w", err)
}
wal := &WAL{
cfg: cfg,
dir: dir,
file: file,
writer: bufio.NewWriterSize(file, 64*1024), // 64KB buffer
nextSequence: 1,
lastSync: time.Now(),
// Use the provided logical clock
logicalClock: logicalClock,
}
return wal, nil
}
// NewWAL creates a new write-ahead log with a temporary logical clock
// For production systems with replication, use NewWALWithClock instead
func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
if cfg == nil {
return nil, errors.New("config cannot be nil")

@@ -174,6 +236,34 @@ func ReuseWAL(cfg *config.Config, dir string, nextSeq uint64) (*WAL, error) {
return wal, nil
}
// SetReplicationHook sets the replication hook for the WAL
func (w *WAL) SetReplicationHook(hook ReplicationHook) {
w.replicationMu.Lock()
defer w.replicationMu.Unlock()
w.replicationHook = hook
}
// SetLogicalClock sets the Lamport clock for this WAL
// This should be called by the engine before using the WAL for replication
func (w *WAL) SetLogicalClock(logicalClock *clock.LamportClock) error {
if logicalClock == nil {
return errors.New("logicalClock cannot be nil")
}
w.replicationMu.Lock()
defer w.replicationMu.Unlock()
w.logicalClock = logicalClock
return nil
}
// GetLogicalClock returns the current logical clock
func (w *WAL) GetLogicalClock() *clock.LamportClock {
w.replicationMu.Lock()
defer w.replicationMu.Unlock()
return w.logicalClock
}
// Append adds an entry to the WAL
func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
w.mu.Lock()

@@ -217,6 +307,33 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
return 0, err
}
// Create the entry object
entry := &Entry{
SequenceNumber: seqNum,
Type: entryType,
Key: key,
Value: value,
}
// If we have a logical clock, use it to timestamp the entry
if w.logicalClock != nil {
// Generate a new timestamp for this entry
timestamp := w.logicalClock.Tick()
entry.Timestamp = timestamp
// Trigger replication hook if set
if w.replicationHook != nil {
// Call in a goroutine to avoid blocking the write path
// This is safe because the entry is fully written and synced at this point
go func() {
if err := w.replicationHook.OnEntryWritten(entry, timestamp); err != nil {
// Just log the error, don't fail the operation
fmt.Printf("Replication hook error: %v\n", err)
}
}()
}
}
return seqNum, nil
}

@@ -498,6 +615,39 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
return 0, err
}
// If we have a logical clock, generate timestamps for the batch
if w.logicalClock != nil {
// Create a base timestamp for the batch
batchTimestamp := w.logicalClock.Tick()
// Only proceed with creating entries if we have a replication hook
if w.replicationHook != nil {
// Assign timestamps to each entry in the batch
entriesWithTimestamps := make([]*Entry, len(entries))
for i, entry := range entries {
// Clone the entry to avoid modifying the original
entriesWithTimestamps[i] = &Entry{
SequenceNumber: startSeqNum + uint64(i),
Type: entry.Type,
Key: entry.Key,
Value: entry.Value,
Timestamp: clock.Timestamp{
Counter: batchTimestamp.Counter + uint64(i),
Node: batchTimestamp.Node,
},
}
}
// Call the replication hook in a goroutine
go func() {
if err := w.replicationHook.OnBatchWritten(entriesWithTimestamps, batchTimestamp); err != nil {
// Just log the error, don't fail the operation
fmt.Printf("Replication hook batch error: %v\n", err)
}
}()
}
}
return startSeqNum, nil
}
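One detail worth calling out in AppendBatch: the clock is ticked once per batch, and per-entry timestamps are derived by offsetting the counter from that single tick, which is exactly what OnBatchWritten consumers receive. A standalone sketch of that derivation (the node ID and batch size are illustrative):

package main

import (
	"fmt"

	"github.com/jeremytregunna/kevo/pkg/common/clock"
)

func main() {
	lc := clock.NewLamportClock(clock.NodeID{7}) // illustrative node ID

	// One tick covers the whole batch...
	batchStart := lc.Tick()

	// ...and entry i is stamped with batchStart.Counter + i on the same node,
	// mirroring the loop in AppendBatch and in ReplicationManager.OnBatchWritten.
	const batchSize = 3
	for i := 0; i < batchSize; i++ {
		ts := clock.Timestamp{
			Counter: batchStart.Counter + uint64(i),
			Node:    batchStart.Node,
		}
		fmt.Printf("entry %d -> %s\n", i, ts)
	}
}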