Compare commits
2 Commits: 87285d931e...e433b12930

Author | SHA1 | Date
---|---|---
| e433b12930 |
| e1ea512864 |
go.mod (5 changed lines)
@@ -7,4 +7,7 @@ require (
 	github.com/chzyer/readline v1.5.1
 )
 
-require golang.org/x/sys v0.1.0 // indirect
+require (
+	github.com/google/uuid v1.6.0 // indirect
+	golang.org/x/sys v0.1.0 // indirect
+)
go.sum (2 changed lines)
@@ -6,6 +6,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
 github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
 github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
 github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
 golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
pkg/common/clock/lamport.go (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
// Package clock provides logical clock implementations for distributed systems.
|
||||
package clock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NodeID represents a unique identifier for a node in the distributed system.
|
||||
type NodeID [16]byte
|
||||
|
||||
// String returns a human-readable string representation of a NodeID.
|
||||
func (id NodeID) String() string {
|
||||
return hex.EncodeToString(id[:])
|
||||
}
|
||||
|
||||
// Equal checks if two NodeIDs are equal.
|
||||
func (id NodeID) Equal(other NodeID) bool {
|
||||
return bytes.Equal(id[:], other[:])
|
||||
}
|
||||
|
||||
// LamportClock implements a Lamport logical clock for event ordering in distributed systems.
|
||||
// It maintains a monotonically increasing counter that is incremented on each local event.
|
||||
// When receiving messages from other nodes, the counter is updated to the maximum of its
|
||||
// current value and the received value, plus one.
|
||||
type LamportClock struct {
|
||||
counter uint64 // The logical clock counter
|
||||
nodeID NodeID // This node's unique identifier
|
||||
mu sync.RWMutex // Mutex to protect concurrent access
|
||||
}
|
||||
|
||||
// Timestamp represents a Lamport timestamp with both the logical counter
|
||||
// and the Node ID to break ties when logical counters are equal.
|
||||
type Timestamp struct {
|
||||
Counter uint64 // Logical counter value
|
||||
Node NodeID // NodeID used to break ties
|
||||
}
|
||||
|
||||
// NewLamportClock creates a new Lamport clock with the given NodeID.
|
||||
func NewLamportClock(nodeID NodeID) *LamportClock {
|
||||
return &LamportClock{
|
||||
counter: 0,
|
||||
nodeID: nodeID,
|
||||
}
|
||||
}
|
||||
|
||||
// Tick increments the local counter and returns a new timestamp.
|
||||
// This should be called before generating any local event like writing to the WAL.
|
||||
func (lc *LamportClock) Tick() Timestamp {
|
||||
lc.mu.Lock()
|
||||
defer lc.mu.Unlock()
|
||||
|
||||
lc.counter++
|
||||
return Timestamp{
|
||||
Counter: lc.counter,
|
||||
Node: lc.nodeID,
|
||||
}
|
||||
}
|
||||
|
||||
// Update compares the provided timestamp with the local counter and
|
||||
// updates the local counter to be at least as large as the received one.
|
||||
// This should be called when processing events from other nodes.
|
||||
func (lc *LamportClock) Update(ts Timestamp) Timestamp {
|
||||
lc.mu.Lock()
|
||||
defer lc.mu.Unlock()
|
||||
|
||||
if ts.Counter > lc.counter {
|
||||
lc.counter = ts.Counter
|
||||
}
|
||||
|
||||
// Increment the counter regardless to ensure causality
|
||||
lc.counter++
|
||||
|
||||
return Timestamp{
|
||||
Counter: lc.counter,
|
||||
Node: lc.nodeID,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrent returns the current timestamp without incrementing the counter.
|
||||
func (lc *LamportClock) GetCurrent() Timestamp {
|
||||
lc.mu.RLock()
|
||||
defer lc.mu.RUnlock()
|
||||
|
||||
return Timestamp{
|
||||
Counter: lc.counter,
|
||||
Node: lc.nodeID,
|
||||
}
|
||||
}
|
||||
|
||||
// ManualSet sets the counter to a specific value if it's greater than the current value.
|
||||
// This is useful for bootstrap or recovery scenarios.
|
||||
func (lc *LamportClock) ManualSet(counter uint64) {
|
||||
lc.mu.Lock()
|
||||
defer lc.mu.Unlock()
|
||||
|
||||
if counter > lc.counter {
|
||||
lc.counter = counter
|
||||
}
|
||||
}
|
||||
|
||||
// Compare compares two timestamps according to the Lamport ordering.
|
||||
// Returns:
|
||||
// -1 if ts1 < ts2
|
||||
// 0 if ts1 = ts2
|
||||
// +1 if ts1 > ts2
|
||||
func Compare(ts1, ts2 Timestamp) int {
|
||||
if ts1.Counter < ts2.Counter {
|
||||
return -1
|
||||
}
|
||||
if ts1.Counter > ts2.Counter {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Break ties using node ID comparison
|
||||
return bytes.Compare(ts1.Node[:], ts2.Node[:])
|
||||
}
|
||||
|
||||
// Less returns true if ts1 is less than ts2 according to Lamport ordering.
|
||||
func Less(ts1, ts2 Timestamp) bool {
|
||||
return Compare(ts1, ts2) < 0
|
||||
}
|
||||
|
||||
// Equal returns true if the two timestamps are equal.
|
||||
func Equal(ts1, ts2 Timestamp) bool {
|
||||
return Compare(ts1, ts2) == 0
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the timestamp.
|
||||
func (ts Timestamp) String() string {
|
||||
return fmt.Sprintf("%d@%s", ts.Counter, ts.Node.String()[:8])
|
||||
}
|
||||
|
||||
// Bytes serializes the timestamp to a byte array for network transmission.
|
||||
func (ts Timestamp) Bytes() []byte {
|
||||
buf := make([]byte, 8+16) // 8 bytes for counter + 16 bytes for NodeID
|
||||
binary.BigEndian.PutUint64(buf[:8], ts.Counter)
|
||||
copy(buf[8:], ts.Node[:])
|
||||
return buf
|
||||
}
|
||||
|
||||
// TimestampFromBytes deserializes a timestamp from a byte array.
|
||||
func TimestampFromBytes(data []byte) (Timestamp, error) {
|
||||
if len(data) < 24 {
|
||||
return Timestamp{}, fmt.Errorf("invalid timestamp data: expected 24 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
ts := Timestamp{
|
||||
Counter: binary.BigEndian.Uint64(data[:8]),
|
||||
}
|
||||
copy(ts.Node[:], data[8:24])
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
// WithPhysicalTimestamp creates a hybrid logical clock timestamp that incorporates
|
||||
// both the logical counter and a physical timestamp for enhanced ordering.
|
||||
func WithPhysicalTimestamp(ts Timestamp) HybridTimestamp {
|
||||
return HybridTimestamp{
|
||||
Logical: ts,
|
||||
Physical: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
// HybridTimestamp combines a logical Lamport timestamp with a physical clock timestamp.
|
||||
// This provides both the causal ordering benefits of Lamport clocks and the
|
||||
// real-time approximation of physical clocks.
|
||||
type HybridTimestamp struct {
|
||||
Logical Timestamp
|
||||
Physical int64 // nanoseconds since Unix epoch
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the hybrid timestamp.
|
||||
func (hts HybridTimestamp) String() string {
|
||||
return fmt.Sprintf("%s@%d", hts.Logical.String(), hts.Physical)
|
||||
}
|
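The clock package above is self-contained, so a short usage sketch may help make the tick/update rule in the doc comments concrete. This is illustrative only and not part of the commit; the node ID values are arbitrary:

package main

import (
	"fmt"

	"github.com/jeremytregunna/kevo/pkg/common/clock"
)

func main() {
	// Two nodes with distinct (arbitrary) IDs.
	a := clock.NewLamportClock(clock.NodeID{1})
	b := clock.NewLamportClock(clock.NodeID{2})

	// Node A records two local events; the second timestamp travels with a message.
	a.Tick()
	sent := a.Tick() // Counter == 2

	// Node B merges the received timestamp: its counter becomes
	// max(local, received) + 1, so the receive is ordered after the send.
	recv := b.Update(sent)

	fmt.Println(clock.Less(sent, recv)) // true
	fmt.Println(recv)                   // "3@02000000"
}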
pkg/common/clock/lamport_test.go (new file, 155 lines)
@@ -0,0 +1,155 @@
|
||||
package clock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLamportClock_Tick(t *testing.T) {
|
||||
nodeID := NodeID{1, 2, 3}
|
||||
clock := NewLamportClock(nodeID)
|
||||
|
||||
// Initial tick should return 1
|
||||
ts1 := clock.Tick()
|
||||
if ts1.Counter != 1 {
|
||||
t.Errorf("Expected counter to be 1, got %d", ts1.Counter)
|
||||
}
|
||||
|
||||
// Next tick should return 2
|
||||
ts2 := clock.Tick()
|
||||
if ts2.Counter != 2 {
|
||||
t.Errorf("Expected counter to be 2, got %d", ts2.Counter)
|
||||
}
|
||||
|
||||
// NodeID should be preserved
|
||||
if !ts1.Node.Equal(nodeID) {
|
||||
t.Errorf("NodeID not preserved in timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLamportClock_Update(t *testing.T) {
|
||||
nodeID1 := NodeID{1, 2, 3}
|
||||
nodeID2 := NodeID{4, 5, 6}
|
||||
|
||||
clock1 := NewLamportClock(nodeID1)
|
||||
clock2 := NewLamportClock(nodeID2)
|
||||
|
||||
// Advance clock1 to 5
|
||||
for i := 0; i < 5; i++ {
|
||||
clock1.Tick()
|
||||
}
|
||||
|
||||
// Get current timestamp from clock1
|
||||
ts1 := clock1.GetCurrent()
|
||||
if ts1.Counter != 5 {
|
||||
t.Errorf("Expected counter to be 5, got %d", ts1.Counter)
|
||||
}
|
||||
|
||||
// Update clock2 with timestamp from clock1
|
||||
ts2 := clock2.Update(ts1)
|
||||
|
||||
// Clock2 should now be at least as large as clock1's timestamp + 1
|
||||
if ts2.Counter <= ts1.Counter {
|
||||
t.Errorf("Expected clock2 counter > %d, got %d", ts1.Counter, ts2.Counter)
|
||||
}
|
||||
|
||||
// Clock2's NodeID should still be nodeID2
|
||||
if !ts2.Node.Equal(nodeID2) {
|
||||
t.Errorf("NodeID not preserved after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLamportClock_Ordering(t *testing.T) {
|
||||
nodeID1 := NodeID{1, 2, 3}
|
||||
nodeID2 := NodeID{4, 5, 6}
|
||||
|
||||
// Create timestamps with the same counter but different NodeIDs
|
||||
ts1 := Timestamp{Counter: 5, Node: nodeID1}
|
||||
ts2 := Timestamp{Counter: 5, Node: nodeID2}
|
||||
|
||||
// Compare should break ties using NodeID
|
||||
if Compare(ts1, ts2) >= 0 {
|
||||
t.Errorf("Expected ts1 < ts2 when counters are equal")
|
||||
}
|
||||
|
||||
// Create timestamps with different counters
|
||||
ts3 := Timestamp{Counter: 6, Node: nodeID1}
|
||||
ts4 := Timestamp{Counter: 5, Node: nodeID2}
|
||||
|
||||
// Compare should prioritize counter values
|
||||
if Compare(ts3, ts4) <= 0 {
|
||||
t.Errorf("Expected ts3 > ts4 when ts3.Counter > ts4.Counter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLamportClock_Serialization(t *testing.T) {
|
||||
nodeID := NodeID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
ts := Timestamp{Counter: 42, Node: nodeID}
|
||||
|
||||
// Serialize
|
||||
bytes := ts.Bytes()
|
||||
|
||||
// Deserialize
|
||||
ts2, err := TimestampFromBytes(bytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize timestamp: %v", err)
|
||||
}
|
||||
|
||||
// Compare
|
||||
if ts.Counter != ts2.Counter {
|
||||
t.Errorf("Counter not preserved in serialization, expected %d got %d", ts.Counter, ts2.Counter)
|
||||
}
|
||||
|
||||
if !ts.Node.Equal(ts2.Node) {
|
||||
t.Errorf("NodeID not preserved in serialization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLamportClock_Concurrent(t *testing.T) {
|
||||
nodeID := NodeID{1, 2, 3}
|
||||
clock := NewLamportClock(nodeID)
|
||||
|
||||
// Run 100 concurrent ticks
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
clock.Tick()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all ticks to complete
|
||||
for i := 0; i < 100; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Counter should be 100
|
||||
ts := clock.GetCurrent()
|
||||
if ts.Counter != 100 {
|
||||
t.Errorf("Expected counter to be 100 after 100 concurrent ticks, got %d", ts.Counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLamportClock_ManualSet(t *testing.T) {
|
||||
nodeID := NodeID{1, 2, 3}
|
||||
clock := NewLamportClock(nodeID)
|
||||
|
||||
// Set to a high value
|
||||
clock.ManualSet(1000)
|
||||
ts := clock.GetCurrent()
|
||||
if ts.Counter != 1000 {
|
||||
t.Errorf("Expected counter to be 1000 after ManualSet, got %d", ts.Counter)
|
||||
}
|
||||
|
||||
// Setting to a lower value should have no effect
|
||||
clock.ManualSet(500)
|
||||
ts = clock.GetCurrent()
|
||||
if ts.Counter != 1000 {
|
||||
t.Errorf("Expected counter to remain 1000 after lower ManualSet, got %d", ts.Counter)
|
||||
}
|
||||
|
||||
// Tick should still increment
|
||||
ts = clock.Tick()
|
||||
if ts.Counter != 1001 {
|
||||
t.Errorf("Expected counter to be 1001 after tick, got %d", ts.Counter)
|
||||
}
|
||||
}
|
pkg/common/log/logger.go (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
// Package log provides a common logging interface for Kevo components.
|
||||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Level represents the logging level
|
||||
type Level int
|
||||
|
||||
const (
|
||||
// LevelDebug level for detailed troubleshooting information
|
||||
LevelDebug Level = iota
|
||||
// LevelInfo level for general operational information
|
||||
LevelInfo
|
||||
// LevelWarn level for potentially harmful situations
|
||||
LevelWarn
|
||||
// LevelError level for error events that might still allow the application to continue
|
||||
LevelError
|
||||
// LevelFatal level for severe error events that will lead the application to abort
|
||||
LevelFatal
|
||||
)
|
||||
|
||||
// String returns the string representation of the log level
|
||||
func (l Level) String() string {
|
||||
switch l {
|
||||
case LevelDebug:
|
||||
return "DEBUG"
|
||||
case LevelInfo:
|
||||
return "INFO"
|
||||
case LevelWarn:
|
||||
return "WARN"
|
||||
case LevelError:
|
||||
return "ERROR"
|
||||
case LevelFatal:
|
||||
return "FATAL"
|
||||
default:
|
||||
return fmt.Sprintf("LEVEL(%d)", l)
|
||||
}
|
||||
}
|
||||
|
||||
// Logger interface defines the methods for logging at different levels
|
||||
type Logger interface {
|
||||
// Debug logs a debug-level message
|
||||
Debug(msg string, args ...interface{})
|
||||
// Info logs an info-level message
|
||||
Info(msg string, args ...interface{})
|
||||
// Warn logs a warning-level message
|
||||
Warn(msg string, args ...interface{})
|
||||
// Error logs an error-level message
|
||||
Error(msg string, args ...interface{})
|
||||
// Fatal logs a fatal-level message and then calls os.Exit(1)
|
||||
Fatal(msg string, args ...interface{})
|
||||
// WithFields returns a new logger with the given fields added to the context
|
||||
WithFields(fields map[string]interface{}) Logger
|
||||
// WithField returns a new logger with the given field added to the context
|
||||
WithField(key string, value interface{}) Logger
|
||||
// GetLevel returns the current logging level
|
||||
GetLevel() Level
|
||||
// SetLevel sets the logging level
|
||||
SetLevel(level Level)
|
||||
}
|
||||
|
||||
// StandardLogger implements the Logger interface with a standard output format
|
||||
type StandardLogger struct {
|
||||
mu sync.Mutex
|
||||
level Level
|
||||
out io.Writer
|
||||
fields map[string]interface{}
|
||||
}
|
||||
|
||||
// NewStandardLogger creates a new StandardLogger with the given options
|
||||
func NewStandardLogger(options ...LoggerOption) *StandardLogger {
|
||||
logger := &StandardLogger{
|
||||
level: LevelInfo, // Default level
|
||||
out: os.Stdout,
|
||||
fields: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, option := range options {
|
||||
option(logger)
|
||||
}
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// LoggerOption is a function that configures a StandardLogger
|
||||
type LoggerOption func(*StandardLogger)
|
||||
|
||||
// WithLevel sets the logging level
|
||||
func WithLevel(level Level) LoggerOption {
|
||||
return func(l *StandardLogger) {
|
||||
l.level = level
|
||||
}
|
||||
}
|
||||
|
||||
// WithOutput sets the output writer
|
||||
func WithOutput(out io.Writer) LoggerOption {
|
||||
return func(l *StandardLogger) {
|
||||
l.out = out
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitialFields sets initial fields for the logger
|
||||
func WithInitialFields(fields map[string]interface{}) LoggerOption {
|
||||
return func(l *StandardLogger) {
|
||||
for k, v := range fields {
|
||||
l.fields[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// log logs a message at the specified level
|
||||
func (l *StandardLogger) log(level Level, msg string, args ...interface{}) {
|
||||
if level < l.level {
|
||||
return
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Format the message
|
||||
formattedMsg := msg
|
||||
if len(args) > 0 {
|
||||
formattedMsg = fmt.Sprintf(msg, args...)
|
||||
}
|
||||
|
||||
// Format timestamp
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05.000")
|
||||
|
||||
// Format fields
|
||||
fieldsStr := ""
|
||||
if len(l.fields) > 0 {
|
||||
for k, v := range l.fields {
|
||||
fieldsStr += fmt.Sprintf(" %s=%v", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Write the log entry
|
||||
fmt.Fprintf(l.out, "[%s] [%s]%s %s\n", timestamp, level.String(), fieldsStr, formattedMsg)
|
||||
|
||||
// Exit if fatal
|
||||
if level == LevelFatal {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a debug-level message
|
||||
func (l *StandardLogger) Debug(msg string, args ...interface{}) {
|
||||
l.log(LevelDebug, msg, args...)
|
||||
}
|
||||
|
||||
// Info logs an info-level message
|
||||
func (l *StandardLogger) Info(msg string, args ...interface{}) {
|
||||
l.log(LevelInfo, msg, args...)
|
||||
}
|
||||
|
||||
// Warn logs a warning-level message
|
||||
func (l *StandardLogger) Warn(msg string, args ...interface{}) {
|
||||
l.log(LevelWarn, msg, args...)
|
||||
}
|
||||
|
||||
// Error logs an error-level message
|
||||
func (l *StandardLogger) Error(msg string, args ...interface{}) {
|
||||
l.log(LevelError, msg, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal-level message and then calls os.Exit(1)
|
||||
func (l *StandardLogger) Fatal(msg string, args ...interface{}) {
|
||||
l.log(LevelFatal, msg, args...)
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with the given fields added to the context
|
||||
func (l *StandardLogger) WithFields(fields map[string]interface{}) Logger {
|
||||
newLogger := &StandardLogger{
|
||||
level: l.level,
|
||||
out: l.out,
|
||||
fields: make(map[string]interface{}, len(l.fields)+len(fields)),
|
||||
}
|
||||
|
||||
// Copy existing fields
|
||||
for k, v := range l.fields {
|
||||
newLogger.fields[k] = v
|
||||
}
|
||||
|
||||
// Add new fields
|
||||
for k, v := range fields {
|
||||
newLogger.fields[k] = v
|
||||
}
|
||||
|
||||
return newLogger
|
||||
}
|
||||
|
||||
// WithField returns a new logger with the given field added to the context
|
||||
func (l *StandardLogger) WithField(key string, value interface{}) Logger {
|
||||
return l.WithFields(map[string]interface{}{key: value})
|
||||
}
|
||||
|
||||
// GetLevel returns the current logging level
|
||||
func (l *StandardLogger) GetLevel() Level {
|
||||
return l.level
|
||||
}
|
||||
|
||||
// SetLevel sets the logging level
|
||||
func (l *StandardLogger) SetLevel(level Level) {
|
||||
l.level = level
|
||||
}
|
||||
|
||||
// Default logger instance
|
||||
var defaultLogger = NewStandardLogger()
|
||||
|
||||
// SetDefaultLogger sets the default logger instance
|
||||
func SetDefaultLogger(logger *StandardLogger) {
|
||||
defaultLogger = logger
|
||||
}
|
||||
|
||||
// GetDefaultLogger returns the default logger instance
|
||||
func GetDefaultLogger() *StandardLogger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// These functions use the default logger
|
||||
|
||||
// Debug logs a debug-level message to the default logger
|
||||
func Debug(msg string, args ...interface{}) {
|
||||
defaultLogger.Debug(msg, args...)
|
||||
}
|
||||
|
||||
// Info logs an info-level message to the default logger
|
||||
func Info(msg string, args ...interface{}) {
|
||||
defaultLogger.Info(msg, args...)
|
||||
}
|
||||
|
||||
// Warn logs a warning-level message to the default logger
|
||||
func Warn(msg string, args ...interface{}) {
|
||||
defaultLogger.Warn(msg, args...)
|
||||
}
|
||||
|
||||
// Error logs an error-level message to the default logger
|
||||
func Error(msg string, args ...interface{}) {
|
||||
defaultLogger.Error(msg, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal-level message to the default logger and then calls os.Exit(1)
|
||||
func Fatal(msg string, args ...interface{}) {
|
||||
defaultLogger.Fatal(msg, args...)
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with the given fields added to the context
|
||||
func WithFields(fields map[string]interface{}) Logger {
|
||||
return defaultLogger.WithFields(fields)
|
||||
}
|
||||
|
||||
// WithField returns a new logger with the given field added to the context
|
||||
func WithField(key string, value interface{}) Logger {
|
||||
return defaultLogger.WithField(key, value)
|
||||
}
|
||||
|
||||
// SetLevel sets the logging level of the default logger
|
||||
func SetLevel(level Level) {
|
||||
defaultLogger.SetLevel(level)
|
||||
}
|
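For reference, a small sketch of how the functional options and the package-level helpers above compose; the field names, writer, and messages here are invented for illustration and are not part of the commit:

package main

import (
	"os"

	"github.com/jeremytregunna/kevo/pkg/common/log"
)

func main() {
	// A logger that writes WARN and above to stderr, tagging every line
	// with a component field.
	logger := log.NewStandardLogger(
		log.WithLevel(log.LevelWarn),
		log.WithOutput(os.Stderr),
		log.WithInitialFields(map[string]interface{}{"component": "wal"}),
	)

	logger.Debug("dropped by the level filter")
	logger.Warn("fsync took %d ms", 37)

	// Package-level helpers route through the default logger.
	log.SetDefaultLogger(logger)
	log.WithField("segment", 12).Error("checksum mismatch")
}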
pkg/common/log/logger_test.go (new file, 132 lines)
@@ -0,0 +1,132 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStandardLogger(t *testing.T) {
|
||||
// Create a buffer to capture output
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Create a logger with the buffer as output
|
||||
logger := NewStandardLogger(
|
||||
WithOutput(&buf),
|
||||
WithLevel(LevelDebug),
|
||||
)
|
||||
|
||||
// Test debug level
|
||||
logger.Debug("This is a debug message")
|
||||
if !strings.Contains(buf.String(), "[DEBUG]") || !strings.Contains(buf.String(), "This is a debug message") {
|
||||
t.Errorf("Debug logging failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test info level
|
||||
logger.Info("This is an info message")
|
||||
if !strings.Contains(buf.String(), "[INFO]") || !strings.Contains(buf.String(), "This is an info message") {
|
||||
t.Errorf("Info logging failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test warn level
|
||||
logger.Warn("This is a warning message")
|
||||
if !strings.Contains(buf.String(), "[WARN]") || !strings.Contains(buf.String(), "This is a warning message") {
|
||||
t.Errorf("Warn logging failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test error level
|
||||
logger.Error("This is an error message")
|
||||
if !strings.Contains(buf.String(), "[ERROR]") || !strings.Contains(buf.String(), "This is an error message") {
|
||||
t.Errorf("Error logging failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test with fields
|
||||
loggerWithFields := logger.WithFields(map[string]interface{}{
|
||||
"component": "test",
|
||||
"count": 123,
|
||||
})
|
||||
loggerWithFields.Info("Message with fields")
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[INFO]") ||
|
||||
!strings.Contains(output, "Message with fields") ||
|
||||
!strings.Contains(output, "component=test") ||
|
||||
!strings.Contains(output, "count=123") {
|
||||
t.Errorf("Logging with fields failed, got: %s", output)
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test with a single field
|
||||
loggerWithField := logger.WithField("module", "logger")
|
||||
loggerWithField.Info("Message with a field")
|
||||
output = buf.String()
|
||||
if !strings.Contains(output, "[INFO]") ||
|
||||
!strings.Contains(output, "Message with a field") ||
|
||||
!strings.Contains(output, "module=logger") {
|
||||
t.Errorf("Logging with a field failed, got: %s", output)
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test level filtering
|
||||
logger.SetLevel(LevelError)
|
||||
logger.Debug("This debug message should not appear")
|
||||
logger.Info("This info message should not appear")
|
||||
logger.Warn("This warning message should not appear")
|
||||
logger.Error("This error message should appear")
|
||||
output = buf.String()
|
||||
if strings.Contains(output, "should not appear") ||
|
||||
!strings.Contains(output, "This error message should appear") {
|
||||
t.Errorf("Level filtering failed, got: %s", output)
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test formatted messages
|
||||
logger.SetLevel(LevelInfo)
|
||||
logger.Info("Formatted %s with %d params", "message", 2)
|
||||
if !strings.Contains(buf.String(), "Formatted message with 2 params") {
|
||||
t.Errorf("Formatted message failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test GetLevel
|
||||
if logger.GetLevel() != LevelInfo {
|
||||
t.Errorf("GetLevel failed, expected LevelInfo, got: %v", logger.GetLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLogger(t *testing.T) {
|
||||
// Save original default logger
|
||||
originalLogger := defaultLogger
|
||||
defer func() {
|
||||
defaultLogger = originalLogger
|
||||
}()
|
||||
|
||||
// Create a buffer to capture output
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Set a new default logger
|
||||
SetDefaultLogger(NewStandardLogger(
|
||||
WithOutput(&buf),
|
||||
WithLevel(LevelInfo),
|
||||
))
|
||||
|
||||
// Test global functions
|
||||
Info("Global info message")
|
||||
if !strings.Contains(buf.String(), "[INFO]") || !strings.Contains(buf.String(), "Global info message") {
|
||||
t.Errorf("Global info logging failed, got: %s", buf.String())
|
||||
}
|
||||
buf.Reset()
|
||||
|
||||
// Test global with fields
|
||||
WithField("global", true).Info("Global with field")
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[INFO]") ||
|
||||
!strings.Contains(output, "Global with field") ||
|
||||
!strings.Contains(output, "global=true") {
|
||||
t.Errorf("Global logging with field failed, got: %s", output)
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
@@ -10,6 +10,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/jeremytregunna/kevo/pkg/common/clock"
|
||||
"github.com/jeremytregunna/kevo/pkg/common/iterator"
|
||||
"github.com/jeremytregunna/kevo/pkg/compaction"
|
||||
"github.com/jeremytregunna/kevo/pkg/config"
|
||||
@@ -99,6 +100,11 @@ type Engine struct {
|
||||
mu sync.RWMutex // Main lock for engine state
|
||||
flushMu sync.Mutex // Lock for flushing operations
|
||||
txLock sync.RWMutex // Lock for transaction isolation
|
||||
|
||||
// Replication fields
|
||||
nodeID clock.NodeID // Unique identifier for this database instance
|
||||
lamportClock *clock.LamportClock // Logical clock for event ordering
|
||||
replicationMgr *ReplicationManager // Manager for replication operations
|
||||
}
|
||||
|
||||
// NewEngine creates a new storage engine
|
||||
@@ -141,6 +147,15 @@ func NewEngine(dataDir string) (*Engine, error) {
|
||||
defer func() { wal.DisableRecoveryLogs = tempWasDisabled }()
|
||||
}
|
||||
|
||||
// Load or create the node ID and initialize a Lamport clock
|
||||
nodeID, err := loadOrCreateNodeID(dataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize node ID: %w", err)
|
||||
}
|
||||
|
||||
// Create a Lamport clock for this instance
|
||||
lamportClock := clock.NewLamportClock(nodeID)
|
||||
|
||||
// First try to reuse an existing WAL file
|
||||
var walLogger *wal.WAL
|
||||
|
||||
@@ -150,12 +165,17 @@ func NewEngine(dataDir string) (*Engine, error) {
|
||||
return nil, fmt.Errorf("failed to check for reusable WAL: %w", err)
|
||||
}
|
||||
|
||||
// If no suitable WAL found, create a new one
|
||||
// If no suitable WAL found, create a new one with our logical clock
|
||||
if walLogger == nil {
|
||||
walLogger, err = wal.NewWAL(cfg, walDir)
|
||||
walLogger, err = wal.NewWALWithClock(cfg, walDir, lamportClock)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WAL: %w", err)
|
||||
}
|
||||
} else {
|
||||
// If we're reusing a WAL, set the logical clock on it
|
||||
if err := walLogger.SetLogicalClock(lamportClock); err != nil {
|
||||
return nil, fmt.Errorf("failed to set logical clock on reused WAL: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the MemTable pool
|
||||
@@ -172,6 +192,10 @@ func NewEngine(dataDir string) (*Engine, error) {
|
||||
sstables: make([]*sstable.Reader, 0),
|
||||
bgFlushCh: make(chan struct{}, 1),
|
||||
nextFileNum: 1,
|
||||
|
||||
// Set replication fields
|
||||
nodeID: nodeID,
|
||||
lamportClock: lamportClock,
|
||||
}
|
||||
|
||||
// Load existing SSTables
|
||||
@@ -184,6 +208,11 @@ func NewEngine(dataDir string) (*Engine, error) {
|
||||
return nil, fmt.Errorf("failed to recover from WAL: %w", err)
|
||||
}
|
||||
|
||||
// Initialize replication
|
||||
if err := e.initializeReplication(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize replication: %w", err)
|
||||
}
|
||||
|
||||
// Start background flush goroutine
|
||||
go e.backgroundFlush()
|
||||
|
||||
@@ -573,8 +602,8 @@ func (e *Engine) rotateWAL() error {
|
||||
return fmt.Errorf("failed to close WAL: %w", err)
|
||||
}
|
||||
|
||||
// Create a new WAL
|
||||
wal, err := wal.NewWAL(e.cfg, e.walDir)
|
||||
// Create a new WAL with the existing logical clock
|
||||
wal, err := wal.NewWALWithClock(e.cfg, e.walDir, e.lamportClock)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new WAL: %w", err)
|
||||
}
|
||||
@@ -702,8 +731,8 @@ func (e *Engine) recoverFromWAL() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Create a fresh WAL
|
||||
newWal, err := wal.NewWAL(e.cfg, e.walDir)
|
||||
// Create a fresh WAL with the existing logical clock
|
||||
newWal, err := wal.NewWALWithClock(e.cfg, e.walDir, e.lamportClock)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new WAL after recovery: %w", err)
|
||||
}
|
||||
@@ -951,6 +980,13 @@ func (e *Engine) Close() error {
|
||||
return fmt.Errorf("failed to shutdown compaction: %w", err)
|
||||
}
|
||||
|
||||
// Shutdown replication manager if it exists
|
||||
if e.replicationMgr != nil {
|
||||
if err := e.replicationMgr.Close(); err != nil {
|
||||
return fmt.Errorf("failed to shutdown replication manager: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close WAL first
|
||||
if err := e.wal.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close WAL: %w", err)
|
||||
|
pkg/engine/replication.go (new file, 346 lines)
@@ -0,0 +1,346 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jeremytregunna/kevo/pkg/common/clock"
|
||||
"github.com/jeremytregunna/kevo/pkg/common/log"
|
||||
"github.com/jeremytregunna/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
// NodeIDFile is the name of the file that stores the database's node ID
|
||||
const NodeIDFile = "NODE_ID"
|
||||
|
||||
// loadNodeID attempts to load the node ID from the file system
|
||||
// Returns the NodeID if successful or an error if the file doesn't exist or is invalid
|
||||
func loadNodeID(dbDir string) (clock.NodeID, error) {
|
||||
idPath := filepath.Join(dbDir, NodeIDFile)
|
||||
|
||||
data, err := os.ReadFile(idPath)
|
||||
if err != nil {
|
||||
return clock.NodeID{}, err // Return the original error for better diagnostics
|
||||
}
|
||||
|
||||
// Validate length
|
||||
if len(data) != 16 {
|
||||
return clock.NodeID{}, fmt.Errorf("invalid node ID file format (expected 16 bytes, got %d)", len(data))
|
||||
}
|
||||
|
||||
// Convert to NodeID
|
||||
var nodeID clock.NodeID
|
||||
copy(nodeID[:], data)
|
||||
return nodeID, nil
|
||||
}
|
||||
|
||||
// createNodeID generates a new node ID and persists it
|
||||
func createNodeID(dbDir string) (clock.NodeID, error) {
|
||||
var nodeID clock.NodeID
|
||||
|
||||
// Generate a UUID v4
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return clock.NodeID{}, fmt.Errorf("failed to generate UUID: %w", err)
|
||||
}
|
||||
|
||||
// Copy the UUID bytes to our NodeID
|
||||
copy(nodeID[:], id[:])
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return clock.NodeID{}, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Write the ID to the file
|
||||
idPath := filepath.Join(dbDir, NodeIDFile)
|
||||
if err := os.WriteFile(idPath, nodeID[:], 0644); err != nil {
|
||||
return clock.NodeID{}, fmt.Errorf("failed to write node ID file: %w", err)
|
||||
}
|
||||
|
||||
return nodeID, nil
|
||||
}
|
||||
|
||||
// loadOrCreateNodeID first tries to load an existing node ID, and if that fails,
|
||||
// creates a new one
|
||||
func loadOrCreateNodeID(dbDir string) (clock.NodeID, error) {
|
||||
nodeID, err := loadNodeID(dbDir)
|
||||
if err == nil {
|
||||
return nodeID, nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
// File doesn't exist, create new ID
|
||||
return createNodeID(dbDir)
|
||||
}
|
||||
|
||||
// Some other error occurred while loading
|
||||
return clock.NodeID{}, fmt.Errorf("failed to access node ID file: %w", err)
|
||||
}
|
||||
|
||||
// initializeReplication sets up replication components for the engine
|
||||
// This should be called during engine initialization
|
||||
func (e *Engine) initializeReplication() error {
|
||||
// Load or create node ID
|
||||
nodeID, err := loadOrCreateNodeID(e.dataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize node ID: %w", err)
|
||||
}
|
||||
|
||||
// Create Lamport clock with this node ID
|
||||
e.lamportClock = clock.NewLamportClock(nodeID)
|
||||
e.nodeID = nodeID
|
||||
|
||||
// Set the clock on the WAL
|
||||
if err := e.wal.SetLogicalClock(e.lamportClock); err != nil {
|
||||
return fmt.Errorf("failed to set logical clock on WAL: %w", err)
|
||||
}
|
||||
|
||||
// Create the replication manager
|
||||
e.replicationMgr = NewReplicationManager(e)
|
||||
|
||||
// Set the replication hook on the WAL
|
||||
e.wal.SetReplicationHook(e.replicationMgr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNodeID returns the database instance's stable node ID
|
||||
func (e *Engine) GetNodeID() clock.NodeID {
|
||||
return e.nodeID
|
||||
}
|
||||
|
||||
// GetLamportClock returns the engine's Lamport clock
|
||||
func (e *Engine) GetLamportClock() *clock.LamportClock {
|
||||
return e.lamportClock
|
||||
}
|
||||
|
||||
// ReplicationManager handles replication operations for the engine
|
||||
type ReplicationManager struct {
|
||||
engine *Engine
|
||||
mu sync.RWMutex
|
||||
isLeader bool
|
||||
isReplica bool
|
||||
replicaIDs map[string]clock.NodeID // Maps replica IDs to their NodeIDs
|
||||
replicaPositions map[string]clock.Timestamp // Tracks the latest position for each replica
|
||||
logChan chan *ReplicationLogEntry // Channel for log entries to be processed
|
||||
stopChan chan struct{} // Channel to signal stopping
|
||||
logger log.Logger // Logger interface for replication events
|
||||
}
|
||||
|
||||
// ReplicationLogEntry represents a WAL entry that needs to be replicated
|
||||
type ReplicationLogEntry struct {
|
||||
entry *wal.Entry // The WAL entry
|
||||
timestamp clock.Timestamp // The Lamport timestamp
|
||||
batch bool // Whether this is part of a batch
|
||||
batchSize int // Size of batch if part of batch
|
||||
}
|
||||
|
||||
// NewReplicationManager creates a new replication manager for the engine
|
||||
func NewReplicationManager(engine *Engine) *ReplicationManager {
|
||||
logger := log.GetDefaultLogger().WithField("component", "replication")
|
||||
|
||||
rm := &ReplicationManager{
|
||||
engine: engine,
|
||||
replicaIDs: make(map[string]clock.NodeID),
|
||||
replicaPositions: make(map[string]clock.Timestamp),
|
||||
logChan: make(chan *ReplicationLogEntry, 1000), // Buffer for 1000 entries
|
||||
stopChan: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start the replication processor goroutine
|
||||
go rm.processReplicationEntries()
|
||||
|
||||
logger.Info("Replication manager initialized")
|
||||
return rm
|
||||
}
|
||||
|
||||
// processReplicationEntries handles replication entries in the background
|
||||
func (rm *ReplicationManager) processReplicationEntries() {
|
||||
for {
|
||||
select {
|
||||
case entry := <-rm.logChan:
|
||||
rm.handleLogEntry(entry)
|
||||
case <-rm.stopChan:
|
||||
rm.logger.Info("Stopping replication log processor")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogEntry processes a replication log entry
|
||||
func (rm *ReplicationManager) handleLogEntry(entry *ReplicationLogEntry) {
|
||||
// Skip processing if we're a replica
|
||||
if rm.isReplica {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Implement actual replication to remote nodes
|
||||
// For now, we'll just log the event
|
||||
if entry.batch {
|
||||
rm.logger.Debug("Processing batch entry for replication: timestamp=%s, batch_size=%d",
|
||||
entry.timestamp.String(), entry.batchSize)
|
||||
} else {
|
||||
rm.logger.Debug("Processing single entry for replication: type=%d, timestamp=%s, key_size=%d",
|
||||
entry.entry.Type, entry.timestamp.String(), len(entry.entry.Key))
|
||||
}
|
||||
}
|
||||
|
||||
// SetLeader sets this node as the leader for replication
|
||||
func (rm *ReplicationManager) SetLeader(isLeader bool) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.isLeader = isLeader
|
||||
rm.isReplica = !isLeader
|
||||
|
||||
if isLeader {
|
||||
rm.logger.Info("Node set as replication leader")
|
||||
} else {
|
||||
rm.logger.Info("Node set as replica")
|
||||
}
|
||||
}
|
||||
|
||||
// IsLeader returns whether this node is the leader
|
||||
func (rm *ReplicationManager) IsLeader() bool {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
return rm.isLeader
|
||||
}
|
||||
|
||||
// IsReplica returns whether this node is a replica
|
||||
func (rm *ReplicationManager) IsReplica() bool {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
return rm.isReplica
|
||||
}
|
||||
|
||||
// AddReplica adds a replica node to the replication system
|
||||
func (rm *ReplicationManager) AddReplica(id string, nodeID clock.NodeID) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.replicaIDs[id] = nodeID
|
||||
rm.logger.Info("Added replica: %s with NodeID: %s", id, nodeID.String())
|
||||
}
|
||||
|
||||
// RemoveReplica removes a replica node from the replication system
|
||||
func (rm *ReplicationManager) RemoveReplica(id string) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
delete(rm.replicaIDs, id)
|
||||
delete(rm.replicaPositions, id)
|
||||
rm.logger.Info("Removed replica: %s", id)
|
||||
}
|
||||
|
||||
// GetReplicaIDs returns the list of replica IDs
|
||||
func (rm *ReplicationManager) GetReplicaIDs() []string {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
|
||||
ids := make([]string, 0, len(rm.replicaIDs))
|
||||
for id := range rm.replicaIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetReplicaPosition returns the current replication position for a replica
|
||||
func (rm *ReplicationManager) GetReplicaPosition(replicaID string) (clock.Timestamp, bool) {
|
||||
rm.mu.RLock()
|
||||
defer rm.mu.RUnlock()
|
||||
|
||||
pos, exists := rm.replicaPositions[replicaID]
|
||||
return pos, exists
|
||||
}
|
||||
|
||||
// UpdateReplicaPosition updates the position for a specific replica
|
||||
func (rm *ReplicationManager) UpdateReplicaPosition(replicaID string, position clock.Timestamp) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
current, exists := rm.replicaPositions[replicaID]
|
||||
if !exists || clock.Compare(position, current) > 0 {
|
||||
rm.replicaPositions[replicaID] = position
|
||||
rm.logger.Debug("Updated replica %s position to %s", replicaID, position.String())
|
||||
}
|
||||
}
|
||||
|
||||
// OnEntryWritten implements the wal.ReplicationHook interface
|
||||
// It's called when a single entry is written to the WAL
|
||||
func (rm *ReplicationManager) OnEntryWritten(entry *wal.Entry, timestamp clock.Timestamp) error {
|
||||
// Skip processing if we're a replica to avoid replication loops
|
||||
if rm.isReplica {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Queue the entry for processing
|
||||
select {
|
||||
case rm.logChan <- &ReplicationLogEntry{
|
||||
entry: entry,
|
||||
timestamp: timestamp,
|
||||
batch: false,
|
||||
}:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Channel is full, log warning but don't block the write path
|
||||
rm.logger.Error("Replication queue is full, dropping entry")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnBatchWritten implements the wal.ReplicationHook interface
|
||||
// It's called when a batch of entries is written to the WAL
|
||||
func (rm *ReplicationManager) OnBatchWritten(entries []*wal.Entry, startTimestamp clock.Timestamp) error {
|
||||
// Skip processing if we're a replica to avoid replication loops
|
||||
if rm.isReplica {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process each entry in the batch
|
||||
for i, entry := range entries {
|
||||
// Calculate timestamp for this entry based on batch start
|
||||
entryTimestamp := clock.Timestamp{
|
||||
Counter: startTimestamp.Counter + uint64(i),
|
||||
Node: startTimestamp.Node,
|
||||
}
|
||||
|
||||
// Queue the entry for processing
|
||||
select {
|
||||
case rm.logChan <- &ReplicationLogEntry{
|
||||
entry: entry,
|
||||
timestamp: entryTimestamp,
|
||||
batch: true,
|
||||
batchSize: len(entries),
|
||||
}:
|
||||
// Successfully queued
|
||||
default:
|
||||
// Channel is full, log warning but don't block the write path
|
||||
rm.logger.Error("Replication queue is full, dropping batch entry %d/%d",
|
||||
i+1, len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the replication manager
|
||||
func (rm *ReplicationManager) Close() error {
|
||||
close(rm.stopChan)
|
||||
rm.logger.Info("Replication manager shut down")
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReplicationHook sets a custom replication hook on the WAL
|
||||
// This is mainly used for testing or when you need to override the default hook
|
||||
func (e *Engine) SetReplicationHook(hook interface{}) error {
|
||||
walHook, ok := hook.(wal.ReplicationHook)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid replication hook type: %T", hook)
|
||||
}
|
||||
|
||||
e.wal.SetReplicationHook(walHook)
|
||||
return nil
|
||||
}
|
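ReplicationManager is itself the wal.ReplicationHook installed by initializeReplication, but Engine.SetReplicationHook allows swapping in another implementation (the test file below does exactly that with mock hooks). As a hedged sketch, a custom hook only needs the two interface methods; the type and its counter are hypothetical, and since the WAL invokes the hook from goroutines it must be safe for concurrent use:

package replhook

import (
	"sync"

	"github.com/jeremytregunna/kevo/pkg/common/clock"
	"github.com/jeremytregunna/kevo/pkg/wal"
)

// countingHook tallies what the WAL reports; a real hook would forward the
// entries (with their Lamport timestamps) to replicas instead.
type countingHook struct {
	mu      sync.Mutex
	entries int
}

func (h *countingHook) OnEntryWritten(e *wal.Entry, ts clock.Timestamp) error {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.entries++
	return nil
}

func (h *countingHook) OnBatchWritten(es []*wal.Entry, start clock.Timestamp) error {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.entries += len(es)
	return nil
}

// Installing it on an open engine (replacing the default manager hook):
//
//	if err := engine.SetReplicationHook(&countingHook{}); err != nil {
//		// handle error
//	}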
pkg/engine/replication_test.go (new file, 251 lines)
@@ -0,0 +1,251 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jeremytregunna/kevo/pkg/common/clock"
|
||||
"github.com/jeremytregunna/kevo/pkg/common/log"
|
||||
"github.com/jeremytregunna/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
// TestReplicationHooks tests that the replication hooks are properly called
|
||||
func TestReplicationHooks(t *testing.T) {
|
||||
// Set log level to avoid noise in tests
|
||||
log.SetLevel(log.LevelError)
|
||||
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "replication-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a new engine
|
||||
engine, err := NewEngine(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create engine: %v", err)
|
||||
}
|
||||
|
||||
// Make sure replication manager was created
|
||||
if engine.replicationMgr == nil {
|
||||
t.Fatal("Replication manager was not created")
|
||||
}
|
||||
|
||||
// Verify NodeID was assigned
|
||||
if engine.nodeID == (clock.NodeID{}) {
|
||||
t.Fatal("NodeID was not assigned")
|
||||
}
|
||||
|
||||
// Verify Lamport clock was created
|
||||
if engine.lamportClock == nil {
|
||||
t.Fatal("Lamport clock was not created")
|
||||
}
|
||||
|
||||
// Test adding and removing replicas
|
||||
replicaID := "test-replica"
|
||||
fakeNodeID := clock.NodeID{}
|
||||
engine.replicationMgr.AddReplica(replicaID, fakeNodeID)
|
||||
|
||||
replicas := engine.replicationMgr.GetReplicaIDs()
|
||||
if len(replicas) != 1 || replicas[0] != replicaID {
|
||||
t.Fatalf("Expected replica ID %s, got %v", replicaID, replicas)
|
||||
}
|
||||
|
||||
engine.replicationMgr.RemoveReplica(replicaID)
|
||||
|
||||
replicas = engine.replicationMgr.GetReplicaIDs()
|
||||
if len(replicas) != 0 {
|
||||
t.Fatalf("Expected no replicas, got %v", replicas)
|
||||
}
|
||||
|
||||
// Test setting leader/replica status
|
||||
if engine.replicationMgr.IsLeader() {
|
||||
t.Fatal("Replication manager should not be leader by default")
|
||||
}
|
||||
|
||||
engine.replicationMgr.SetLeader(true)
|
||||
|
||||
if !engine.replicationMgr.IsLeader() {
|
||||
t.Fatal("Replication manager should be leader")
|
||||
}
|
||||
|
||||
if engine.replicationMgr.IsReplica() {
|
||||
t.Fatal("Replication manager should not be replica when it's a leader")
|
||||
}
|
||||
|
||||
engine.replicationMgr.SetLeader(false)
|
||||
|
||||
if !engine.replicationMgr.IsReplica() {
|
||||
t.Fatal("Replication manager should be replica when it's not a leader")
|
||||
}
|
||||
|
||||
// Close the engine
|
||||
if err := engine.Close(); err != nil {
|
||||
t.Fatalf("Failed to close engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplicationIntegration tests that the engine works with replication hooks
|
||||
func TestReplicationIntegration(t *testing.T) {
|
||||
// Set log level to avoid noise in tests
|
||||
log.SetLevel(log.LevelError)
|
||||
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "replication-integration-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a new engine
|
||||
engine, err := NewEngine(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create engine: %v", err)
|
||||
}
|
||||
|
||||
// Verify the replication manager is initialized
|
||||
if engine.replicationMgr == nil {
|
||||
t.Fatal("Replication manager was not initialized")
|
||||
}
|
||||
|
||||
// Set as leader to ensure replication hooks are called
|
||||
engine.replicationMgr.SetLeader(true)
|
||||
|
||||
// Perform some operations
|
||||
testKey := []byte("test-key")
|
||||
testValue := []byte("test-value")
|
||||
|
||||
// Put operation
|
||||
if err := engine.Put(testKey, testValue); err != nil {
|
||||
t.Fatalf("Failed to put: %v", err)
|
||||
}
|
||||
|
||||
// Get operation
|
||||
value, err := engine.Get(testKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get: %v", err)
|
||||
}
|
||||
|
||||
// Verify value
|
||||
if string(value) != string(testValue) {
|
||||
t.Fatalf("Expected value %q, got %q", testValue, value)
|
||||
}
|
||||
|
||||
// Close the engine
|
||||
if err := engine.Close(); err != nil {
|
||||
t.Fatalf("Failed to close engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplicationInterface verifies that the replication hook interface works correctly
|
||||
func TestReplicationInterface(t *testing.T) {
|
||||
// Set log level to avoid noise in tests
|
||||
log.SetLevel(log.LevelError)
|
||||
|
||||
// Create a test hook
|
||||
callbackCounts := struct {
|
||||
singleEntries int
|
||||
batchEntries int
|
||||
}{}
|
||||
|
||||
testHook := &syncTestReplicationHook{
|
||||
onEntryCallback: func(entry *wal.Entry, ts clock.Timestamp) error {
|
||||
callbackCounts.singleEntries++
|
||||
return nil
|
||||
},
|
||||
onBatchCallback: func(entries []*wal.Entry, ts clock.Timestamp) error {
|
||||
callbackCounts.batchEntries += len(entries)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Create a test entry
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
// Create a test batch
|
||||
batch := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 2,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key2"),
|
||||
Value: []byte("value2"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 3,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key3"),
|
||||
Value: []byte("value3"),
|
||||
},
|
||||
}
|
||||
|
||||
// Create a timestamp
|
||||
nodeID := clock.NodeID{}
|
||||
ts := clock.Timestamp{
|
||||
Counter: 1,
|
||||
Node: nodeID,
|
||||
}
|
||||
|
||||
// Call the hook methods directly
|
||||
if err := testHook.OnEntryWritten(entry, ts); err != nil {
|
||||
t.Fatalf("OnEntryWritten failed: %v", err)
|
||||
}
|
||||
|
||||
if err := testHook.OnBatchWritten(batch, ts); err != nil {
|
||||
t.Fatalf("OnBatchWritten failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify callbacks were called
|
||||
if callbackCounts.singleEntries != 1 {
|
||||
t.Errorf("Expected 1 single entry callback, got %d", callbackCounts.singleEntries)
|
||||
}
|
||||
|
||||
if callbackCounts.batchEntries != 2 {
|
||||
t.Errorf("Expected 2 batch entry callbacks, got %d", callbackCounts.batchEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper: a synchronous mock replication hook for testing
|
||||
type syncTestReplicationHook struct {
|
||||
onEntryCallback func(*wal.Entry, clock.Timestamp) error
|
||||
onBatchCallback func([]*wal.Entry, clock.Timestamp) error
|
||||
}
|
||||
|
||||
func (h *syncTestReplicationHook) OnEntryWritten(entry *wal.Entry, ts clock.Timestamp) error {
|
||||
if h.onEntryCallback != nil {
|
||||
return h.onEntryCallback(entry, ts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *syncTestReplicationHook) OnBatchWritten(entries []*wal.Entry, ts clock.Timestamp) error {
|
||||
if h.onBatchCallback != nil {
|
||||
return h.onBatchCallback(entries, ts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test helper: an asynchronous mock replication hook for testing
|
||||
type testReplicationHook struct {
|
||||
onEntryCallback func(*wal.Entry, clock.Timestamp) error
|
||||
onBatchCallback func([]*wal.Entry, clock.Timestamp) error
|
||||
}
|
||||
|
||||
func (h *testReplicationHook) OnEntryWritten(entry *wal.Entry, ts clock.Timestamp) error {
|
||||
if h.onEntryCallback != nil {
|
||||
return h.onEntryCallback(entry, ts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *testReplicationHook) OnBatchWritten(entries []*wal.Entry, ts clock.Timestamp) error {
|
||||
if h.onBatchCallback != nil {
|
||||
return h.onBatchCallback(entries, ts)
|
||||
}
|
||||
return nil
|
||||
}
|
pkg/wal/wal.go (154 changed lines)
@@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jeremytregunna/kevo/pkg/common/clock"
|
||||
"github.com/jeremytregunna/kevo/pkg/config"
|
||||
)
|
||||
|
||||
@@ -54,6 +55,22 @@ type Entry struct {
|
||||
Type uint8 // OpTypePut, OpTypeDelete, etc.
|
||||
Key []byte
|
||||
Value []byte
|
||||
// Lamport timestamp for replication ordering
|
||||
Timestamp clock.Timestamp
|
||||
}
|
||||
|
||||
// GetTimestamp returns the Lamport timestamp of this entry
|
||||
func (e *Entry) GetTimestamp() clock.Timestamp {
|
||||
return e.Timestamp
|
||||
}
|
||||
|
||||
// ReplicationHook defines the interface for replication hooks
|
||||
type ReplicationHook interface {
|
||||
// OnEntryWritten is called when a single entry is written to the WAL
|
||||
OnEntryWritten(entry *Entry, timestamp clock.Timestamp) error
|
||||
|
||||
// OnBatchWritten is called when a batch of entries is written to the WAL
|
||||
OnBatchWritten(entries []*Entry, startTimestamp clock.Timestamp) error
|
||||
}
|
||||
|
||||
// Global variable to control whether to print recovery logs
|
||||
@@ -71,9 +88,54 @@ type WAL struct {
|
||||
batchByteSize int64
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
|
||||
// Replication fields
|
||||
logicalClock *clock.LamportClock // Logical clock for replication ordering
|
||||
replicationHook ReplicationHook // Hook for replication events
|
||||
replicationMu sync.Mutex // Separate mutex for replication operations
|
||||
}
|
||||
|
||||
// NewWAL creates a new write-ahead log
|
||||
// NewWALWithClock creates a new WAL with the provided logical clock
|
||||
// This is the recommended constructor for systems using replication
|
||||
func NewWALWithClock(cfg *config.Config, dir string, logicalClock *clock.LamportClock) (*WAL, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
}
|
||||
|
||||
if logicalClock == nil {
|
||||
return nil, errors.New("logicalClock cannot be nil")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create WAL directory: %w", err)
|
||||
}
|
||||
|
||||
// Create a new WAL file
|
||||
filename := fmt.Sprintf("%020d.wal", time.Now().UnixNano())
|
||||
path := filepath.Join(dir, filename)
|
||||
|
||||
file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WAL file: %w", err)
|
||||
}
|
||||
|
||||
wal := &WAL{
|
||||
cfg: cfg,
|
||||
dir: dir,
|
||||
file: file,
|
||||
writer: bufio.NewWriterSize(file, 64*1024), // 64KB buffer
|
||||
nextSequence: 1,
|
||||
lastSync: time.Now(),
|
||||
|
||||
// Use the provided logical clock
|
||||
logicalClock: logicalClock,
|
||||
}
|
||||
|
||||
return wal, nil
|
||||
}
|
||||
|
||||
// NewWAL creates a new write-ahead log with a temporary logical clock
|
||||
// For production systems with replication, use NewWALWithClock instead
|
||||
func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config cannot be nil")
|
||||
@@ -92,7 +154,7 @@ func NewWAL(cfg *config.Config, dir string) (*WAL, error) {
|
||||
return nil, fmt.Errorf("failed to create WAL file: %w", err)
|
||||
}
|
||||
|
||||
wal := &WAL{
|
||||
wal := &WAL{
|
||||
cfg: cfg,
|
||||
dir: dir,
|
||||
file: file,
|
||||
@@ -174,6 +236,34 @@ func ReuseWAL(cfg *config.Config, dir string, nextSeq uint64) (*WAL, error) {
|
||||
return wal, nil
|
||||
}
|
||||
|
||||
// SetReplicationHook sets the replication hook for the WAL
|
||||
func (w *WAL) SetReplicationHook(hook ReplicationHook) {
|
||||
w.replicationMu.Lock()
|
||||
defer w.replicationMu.Unlock()
|
||||
w.replicationHook = hook
|
||||
}
|
||||
|
||||
// SetLogicalClock sets the Lamport clock for this WAL
|
||||
// This should be called by the engine before using the WAL for replication
|
||||
func (w *WAL) SetLogicalClock(logicalClock *clock.LamportClock) error {
|
||||
if logicalClock == nil {
|
||||
return errors.New("logicalClock cannot be nil")
|
||||
}
|
||||
|
||||
w.replicationMu.Lock()
|
||||
defer w.replicationMu.Unlock()
|
||||
|
||||
w.logicalClock = logicalClock
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogicalClock returns the current logical clock
|
||||
func (w *WAL) GetLogicalClock() *clock.LamportClock {
|
||||
w.replicationMu.Lock()
|
||||
defer w.replicationMu.Unlock()
|
||||
return w.logicalClock
|
||||
}
|
||||
|
||||
// Append adds an entry to the WAL
|
||||
func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
|
||||
w.mu.Lock()
|
||||
@@ -217,6 +307,33 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Create the entry object
|
||||
entry := &Entry{
|
||||
SequenceNumber: seqNum,
|
||||
Type: entryType,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
|
||||
// If we have a logical clock, use it to timestamp the entry
|
||||
if w.logicalClock != nil {
|
||||
// Generate a new timestamp for this entry
|
||||
timestamp := w.logicalClock.Tick()
|
||||
entry.Timestamp = timestamp
|
||||
|
||||
// Trigger replication hook if set
|
||||
if w.replicationHook != nil {
|
||||
// Call in a goroutine to avoid blocking the write path
|
||||
// This is safe because the entry is fully written and synced at this point
|
||||
go func() {
|
||||
if err := w.replicationHook.OnEntryWritten(entry, timestamp); err != nil {
|
||||
// Just log the error, don't fail the operation
|
||||
fmt.Printf("Replication hook error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return seqNum, nil
|
||||
}
|
||||
|
||||
@@ -498,6 +615,39 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// If we have a logical clock, generate timestamps for the batch
|
||||
if w.logicalClock != nil {
|
||||
// Create a base timestamp for the batch
|
||||
batchTimestamp := w.logicalClock.Tick()
|
||||
|
||||
// Only proceed with creating entries if we have a replication hook
|
||||
if w.replicationHook != nil {
|
||||
// Assign timestamps to each entry in the batch
|
||||
entriesWithTimestamps := make([]*Entry, len(entries))
|
||||
for i, entry := range entries {
|
||||
// Clone the entry to avoid modifying the original
|
||||
entriesWithTimestamps[i] = &Entry{
|
||||
SequenceNumber: startSeqNum + uint64(i),
|
||||
Type: entry.Type,
|
||||
Key: entry.Key,
|
||||
Value: entry.Value,
|
||||
Timestamp: clock.Timestamp{
|
||||
Counter: batchTimestamp.Counter + uint64(i),
|
||||
Node: batchTimestamp.Node,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Call the replication hook in a goroutine
|
||||
go func() {
|
||||
if err := w.replicationHook.OnBatchWritten(entriesWithTimestamps, batchTimestamp); err != nil {
|
||||
// Just log the error, don't fail the operation
|
||||
fmt.Printf("Replication hook batch error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return startSeqNum, nil
|
||||
}
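One detail worth spelling out from the AppendBatch hunk above: the WAL draws a single Tick for the whole batch and stamps the i-th entry with that counter plus i, and OnBatchWritten receives the same base timestamp, so a consumer can re-derive each entry's timestamp. A minimal helper to illustrate the arithmetic (the package and function name are mine, not part of the diff):

package walutil

import "github.com/jeremytregunna/kevo/pkg/common/clock"

// batchEntryTimestamp reconstructs the timestamp AppendBatch assigned to the
// i-th entry of a batch, given the start timestamp passed to OnBatchWritten.
// For a start counter of 41 and three entries, the entries carry 41, 42, 43.
func batchEntryTimestamp(start clock.Timestamp, i int) clock.Timestamp {
	return clock.Timestamp{
		Counter: start.Counter + uint64(i),
		Node:    start.Node,
	}
}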