Compare commits

...

6 Commits

Author SHA1 Message Date
5963538bc5
feat: implement replication transport layer
This commit implements the replication transport layer as part of Phase 2 of the replication plan.
Key components include:

- Add protocol buffer definitions for replication services
- Implement WALReplicator extension for processor management
- Create replication service server implementation
- Add replication client and server transport implementations
- Implement storage snapshot interface for bootstrap operations
- Standardize package naming across replication components
2025-04-26 13:05:54 -06:00
ed991ae00d
feat: add replication transport interfaces and protocol schema
- Add replication-specific interfaces to pkg/transport
- Create ReplicationClient and ReplicationServer interfaces
- Add replication message types for WAL entries and bootstrap
- Create Protobuf schema for replication in proto/kevo/replication.proto
- Update transport registry to support replication components
2025-04-26 12:33:38 -06:00
33ddfeeb64
fix: use Lamport clocks consistently across WAL operations
Ensure proper Lamport clock integration across batch operations and sequence number handling:
- Add Lamport clock support to Batch.Write for consistent replication timestamps
- Fix potential sequence number inconsistencies in WAL operations
- Update WAL tests to properly verify sync mode behaviors
- Fix sequence numbering to handle Lamport clock timestamps appropriately
2025-04-26 12:24:50 -06:00
5cd1f5c5f8
feat: implement WAL applier for replication
- Add WALApplier component for applying entries on replica nodes
- Implement logical timestamp ordering with Lamport clocks
- Add support for handling out-of-order entry delivery
- Add error handling and recovery mechanisms
- Implement comprehensive testing for all applier functions
2025-04-26 12:02:53 -06:00
02febadf5d
feat: implement WAL replicator and entry serialization
- Add WAL replicator component with entry capture, buffering, and subscriptions
- Implement WAL entry serialization with checksumming
- Add batch serialization for network-efficient transfers
- Implement proper concurrency control with mutex protection
- Add utility functions for entry size estimation
- Create comprehensive test suite
2025-04-26 11:54:19 -06:00
5b2ecdd08c
fix: add the retry logic 2025-04-26 11:48:51 -06:00
30 changed files with 6713 additions and 83 deletions

View File

@@ -0,0 +1,86 @@
package storage
import (
"math/rand"
"time"
"github.com/KevoDB/kevo/pkg/wal"
)
// RetryConfig defines parameters for retry operations
type RetryConfig struct {
MaxRetries int // Maximum number of retries
InitialBackoff time.Duration // Initial backoff duration
MaxBackoff time.Duration // Maximum backoff duration
}
// DefaultRetryConfig returns default retry configuration
func DefaultRetryConfig() *RetryConfig {
return &RetryConfig{
MaxRetries: 3,
InitialBackoff: 5 * time.Millisecond,
MaxBackoff: 50 * time.Millisecond,
}
}
// RetryOnWALRotating retries the operation if it fails with ErrWALRotating
func (m *Manager) RetryOnWALRotating(operation func() error) error {
config := DefaultRetryConfig()
return m.RetryWithConfig(operation, config, isWALRotating)
}
// RetryWithSequence retries the operation if it fails with ErrWALRotating
// and returns the sequence number
func (m *Manager) RetryWithSequence(operation func() (uint64, error)) (uint64, error) {
config := DefaultRetryConfig()
var seq uint64
err := m.RetryWithConfig(func() error {
var opErr error
seq, opErr = operation()
return opErr
}, config, isWALRotating)
return seq, err
}
// RetryWithConfig retries an operation with the given configuration
func (m *Manager) RetryWithConfig(operation func() error, config *RetryConfig, isRetryable func(error) bool) error {
backoff := config.InitialBackoff
for i := 0; i <= config.MaxRetries; i++ {
// Attempt the operation
err := operation()
if err == nil {
return nil
}
// Check if we should retry
if !isRetryable(err) || i == config.MaxRetries {
return err
}
// Add some jitter to the backoff
jitter := time.Duration(rand.Int63n(int64(backoff / 10)))
backoff = backoff + jitter
// Wait before retrying
time.Sleep(backoff)
// Increase backoff for next attempt, but cap it
backoff = 2 * backoff
if backoff > config.MaxBackoff {
backoff = config.MaxBackoff
}
}
// Should never get here, but just in case
return nil
}
// isWALRotating checks if the error is due to WAL rotation or closure
func isWALRotating(err error) bool {
// Both ErrWALRotating and ErrWALClosed can occur during WAL rotation
// Since WAL rotation is a normal operation, we should retry in both cases
return err == wal.ErrWALRotating || err == wal.ErrWALClosed
}
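
For context, a minimal sketch of how these helpers might be used elsewhere in the storage package (illustrative only, not part of this changeset; the write closure is a hypothetical stand-in for the Manager's real WAL write call):

// putWithRetry shows the intended call pattern (sketch only).
// appendToWAL is a hypothetical stand-in for the Manager's real WAL append.
func putWithRetry(m *Manager, appendToWAL func() (uint64, error)) (uint64, error) {
    // RetryWithSequence re-runs the closure whenever it fails with
    // ErrWALRotating or ErrWALClosed, doubling the jittered backoff between
    // attempts (3 retries, 5ms initial, 50ms cap by default).
    return m.RetryWithSequence(appendToWAL)
}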

View File

@@ -0,0 +1,610 @@
package service
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/KevoDB/kevo/pkg/replication"
"github.com/KevoDB/kevo/pkg/transport"
"github.com/KevoDB/kevo/pkg/wal"
"github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// ReplicationServiceServer implements the gRPC ReplicationService
type ReplicationServiceServer struct {
kevo.UnimplementedReplicationServiceServer
// Replication components
replicator *replication.WALReplicator
applier *replication.WALApplier
serializer *replication.EntrySerializer
highestLSN uint64
replicas map[string]*transport.ReplicaInfo
replicasMutex sync.RWMutex
// For snapshot/bootstrap
storageSnapshot replication.StorageSnapshot
}
// NewReplicationService creates a new ReplicationService
func NewReplicationService(
replicator *replication.WALReplicator,
applier *replication.WALApplier,
serializer *replication.EntrySerializer,
storageSnapshot replication.StorageSnapshot,
) *ReplicationServiceServer {
return &ReplicationServiceServer{
replicator: replicator,
applier: applier,
serializer: serializer,
replicas: make(map[string]*transport.ReplicaInfo),
storageSnapshot: storageSnapshot,
}
}
// RegisterReplica handles registration of a new replica
func (s *ReplicationServiceServer) RegisterReplica(
ctx context.Context,
req *kevo.RegisterReplicaRequest,
) (*kevo.RegisterReplicaResponse, error) {
// Validate request
if req.ReplicaId == "" {
return nil, status.Error(codes.InvalidArgument, "replica_id is required")
}
if req.Address == "" {
return nil, status.Error(codes.InvalidArgument, "address is required")
}
// Convert role enum to string
role := transport.RoleReplica
switch req.Role {
case kevo.ReplicaRole_PRIMARY:
role = transport.RolePrimary
case kevo.ReplicaRole_REPLICA:
role = transport.RoleReplica
case kevo.ReplicaRole_READ_ONLY:
role = transport.RoleReadOnly
default:
return nil, status.Error(codes.InvalidArgument, "invalid role")
}
// Register the replica
s.replicasMutex.Lock()
defer s.replicasMutex.Unlock()
// If already registered, update address and role
if replica, exists := s.replicas[req.ReplicaId]; exists {
replica.Address = req.Address
replica.Role = role
replica.LastSeen = time.Now()
replica.Status = transport.StatusConnecting
} else {
// Create new replica info
s.replicas[req.ReplicaId] = &transport.ReplicaInfo{
ID: req.ReplicaId,
Address: req.Address,
Role: role,
Status: transport.StatusConnecting,
LastSeen: time.Now(),
}
}
// Determine if bootstrap is needed (for now always suggest bootstrap)
bootstrapRequired := true
// Return current highest LSN
currentLSN := s.replicator.GetHighestTimestamp()
return &kevo.RegisterReplicaResponse{
Success: true,
CurrentLsn: currentLSN,
BootstrapRequired: bootstrapRequired,
}, nil
}
// ReplicaHeartbeat handles heartbeat requests from replicas
func (s *ReplicationServiceServer) ReplicaHeartbeat(
ctx context.Context,
req *kevo.ReplicaHeartbeatRequest,
) (*kevo.ReplicaHeartbeatResponse, error) {
// Validate request
if req.ReplicaId == "" {
return nil, status.Error(codes.InvalidArgument, "replica_id is required")
}
// Check if replica is registered
s.replicasMutex.Lock()
defer s.replicasMutex.Unlock()
replica, exists := s.replicas[req.ReplicaId]
if !exists {
return nil, status.Error(codes.NotFound, "replica not registered")
}
// Update replica status
replica.LastSeen = time.Now()
// Convert status enum to string
switch req.Status {
case kevo.ReplicaStatus_CONNECTING:
replica.Status = transport.StatusConnecting
case kevo.ReplicaStatus_SYNCING:
replica.Status = transport.StatusSyncing
case kevo.ReplicaStatus_BOOTSTRAPPING:
replica.Status = transport.StatusBootstrapping
case kevo.ReplicaStatus_READY:
replica.Status = transport.StatusReady
case kevo.ReplicaStatus_DISCONNECTED:
replica.Status = transport.StatusDisconnected
case kevo.ReplicaStatus_ERROR:
replica.Status = transport.StatusError
replica.Error = fmt.Errorf("%s", req.ErrorMessage)
default:
return nil, status.Error(codes.InvalidArgument, "invalid status")
}
// Update replica LSN
replica.CurrentLSN = req.CurrentLsn
// Calculate replication lag
primaryLSN := s.replicator.GetHighestTimestamp()
var replicationLagMs int64 = 0
if primaryLSN > req.CurrentLsn {
// Simple lag calculation based on LSN difference
// In a real system, we'd use timestamps for better accuracy
replicationLagMs = int64(primaryLSN - req.CurrentLsn)
}
replica.ReplicationLag = time.Duration(replicationLagMs) * time.Millisecond
return &kevo.ReplicaHeartbeatResponse{
Success: true,
PrimaryLsn: primaryLSN,
ReplicationLagMs: replicationLagMs,
}, nil
}
// GetReplicaStatus retrieves the status of a specific replica
func (s *ReplicationServiceServer) GetReplicaStatus(
ctx context.Context,
req *kevo.GetReplicaStatusRequest,
) (*kevo.GetReplicaStatusResponse, error) {
// Validate request
if req.ReplicaId == "" {
return nil, status.Error(codes.InvalidArgument, "replica_id is required")
}
// Get replica info
s.replicasMutex.RLock()
defer s.replicasMutex.RUnlock()
replica, exists := s.replicas[req.ReplicaId]
if !exists {
return nil, status.Error(codes.NotFound, "replica not found")
}
// Convert replica info to proto message
pbReplica := convertReplicaInfoToProto(replica)
return &kevo.GetReplicaStatusResponse{
Replica: pbReplica,
}, nil
}
// ListReplicas retrieves the status of all replicas
func (s *ReplicationServiceServer) ListReplicas(
ctx context.Context,
req *kevo.ListReplicasRequest,
) (*kevo.ListReplicasResponse, error) {
s.replicasMutex.RLock()
defer s.replicasMutex.RUnlock()
// Convert all replicas to proto messages
pbReplicas := make([]*kevo.ReplicaInfo, 0, len(s.replicas))
for _, replica := range s.replicas {
pbReplicas = append(pbReplicas, convertReplicaInfoToProto(replica))
}
return &kevo.ListReplicasResponse{
Replicas: pbReplicas,
}, nil
}
// GetWALEntries handles requests for WAL entries
func (s *ReplicationServiceServer) GetWALEntries(
ctx context.Context,
req *kevo.GetWALEntriesRequest,
) (*kevo.GetWALEntriesResponse, error) {
// Validate request
if req.ReplicaId == "" {
return nil, status.Error(codes.InvalidArgument, "replica_id is required")
}
// Check if replica is registered
s.replicasMutex.RLock()
_, exists := s.replicas[req.ReplicaId]
s.replicasMutex.RUnlock()
if !exists {
return nil, status.Error(codes.NotFound, "replica not registered")
}
// Get entries from replicator
entries, err := s.replicator.GetEntriesAfter(replication.ReplicationPosition{Timestamp: req.FromLsn})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get entries: %v", err)
}
if len(entries) == 0 {
return &kevo.GetWALEntriesResponse{
Batch: &kevo.WALEntryBatch{},
HasMore: false,
}, nil
}
// Convert entries to proto messages
pbEntries := make([]*kevo.WALEntry, 0, len(entries))
for _, entry := range entries {
pbEntries = append(pbEntries, convertWALEntryToProto(entry))
}
// Create batch
pbBatch := &kevo.WALEntryBatch{
Entries: pbEntries,
FirstLsn: entries[0].SequenceNumber,
LastLsn: entries[len(entries)-1].SequenceNumber,
Count: uint32(len(entries)),
}
// Check if there are more entries
hasMore := s.replicator.GetHighestTimestamp() > entries[len(entries)-1].SequenceNumber
return &kevo.GetWALEntriesResponse{
Batch: pbBatch,
HasMore: hasMore,
}, nil
}
// StreamWALEntries handles streaming WAL entries to a replica
func (s *ReplicationServiceServer) StreamWALEntries(
req *kevo.StreamWALEntriesRequest,
stream kevo.ReplicationService_StreamWALEntriesServer,
) error {
// Validate request
if req.ReplicaId == "" {
return status.Error(codes.InvalidArgument, "replica_id is required")
}
// Check if replica is registered
s.replicasMutex.RLock()
_, exists := s.replicas[req.ReplicaId]
s.replicasMutex.RUnlock()
if !exists {
return status.Error(codes.NotFound, "replica not registered")
}
// Process entries in batches
fromLSN := req.FromLsn
batchSize := 100 // Can be configurable
notifierCh := make(chan struct{}, 1)
// Register for notifications of new entries
var processor *entryNotifier
if req.Continuous {
processor = &entryNotifier{
notifyCh: notifierCh,
fromLSN: fromLSN,
}
s.replicator.AddProcessor(processor)
defer s.replicator.RemoveProcessor(processor)
}
// Initial send of available entries
for {
// Get batch of entries
entries, err := s.replicator.GetEntriesAfter(replication.ReplicationPosition{Timestamp: fromLSN})
if err != nil {
return err
}
if len(entries) == 0 {
// No more entries, check if continuous streaming
if !req.Continuous {
break
}
// Wait for notification of new entries
select {
case <-notifierCh:
continue
case <-stream.Context().Done():
return stream.Context().Err()
}
}
// Create batch message
pbEntries := make([]*kevo.WALEntry, 0, len(entries))
for _, entry := range entries {
pbEntries = append(pbEntries, convertWALEntryToProto(entry))
}
pbBatch := &kevo.WALEntryBatch{
Entries: pbEntries,
FirstLsn: entries[0].SequenceNumber,
LastLsn: entries[len(entries)-1].SequenceNumber,
Count: uint32(len(entries)),
}
// Send batch
if err := stream.Send(pbBatch); err != nil {
return err
}
// Update fromLSN for next batch
fromLSN = entries[len(entries)-1].SequenceNumber + 1
// If not continuous, break after sending all available entries
if !req.Continuous && len(entries) < batchSize {
break
}
}
return nil
}
// ReportAppliedEntries handles reports of entries applied by replicas
func (s *ReplicationServiceServer) ReportAppliedEntries(
ctx context.Context,
req *kevo.ReportAppliedEntriesRequest,
) (*kevo.ReportAppliedEntriesResponse, error) {
// Validate request
if req.ReplicaId == "" {
return nil, status.Error(codes.InvalidArgument, "replica_id is required")
}
// Update replica LSN
s.replicasMutex.Lock()
replica, exists := s.replicas[req.ReplicaId]
if exists {
replica.CurrentLSN = req.AppliedLsn
}
s.replicasMutex.Unlock()
if !exists {
return nil, status.Error(codes.NotFound, "replica not registered")
}
return &kevo.ReportAppliedEntriesResponse{
Success: true,
PrimaryLsn: s.replicator.GetHighestTimestamp(),
}, nil
}
// RequestBootstrap handles bootstrap requests from replicas
func (s *ReplicationServiceServer) RequestBootstrap(
req *kevo.BootstrapRequest,
stream kevo.ReplicationService_RequestBootstrapServer,
) error {
// Validate request
if req.ReplicaId == "" {
return status.Error(codes.InvalidArgument, "replica_id is required")
}
// Check if replica is registered
s.replicasMutex.RLock()
replica, exists := s.replicas[req.ReplicaId]
s.replicasMutex.RUnlock()
if !exists {
return status.Error(codes.NotFound, "replica not registered")
}
// Update replica status
s.replicasMutex.Lock()
replica.Status = transport.StatusBootstrapping
s.replicasMutex.Unlock()
// Create snapshot of current data
snapshotLSN := s.replicator.GetHighestTimestamp()
iterator, err := s.storageSnapshot.CreateSnapshotIterator()
if err != nil {
s.replicasMutex.Lock()
replica.Status = transport.StatusError
replica.Error = err
s.replicasMutex.Unlock()
return status.Errorf(codes.Internal, "failed to create snapshot: %v", err)
}
defer iterator.Close()
// Stream key-value pairs in batches
batchSize := 100 // Can be configurable
totalCount := s.storageSnapshot.KeyCount()
sentCount := 0
batch := make([]*kevo.KeyValuePair, 0, batchSize)
for {
// Get next key-value pair
key, value, err := iterator.Next()
if err == io.EOF {
break
}
if err != nil {
s.replicasMutex.Lock()
replica.Status = transport.StatusError
replica.Error = err
s.replicasMutex.Unlock()
return status.Errorf(codes.Internal, "error reading snapshot: %v", err)
}
// Add to batch
batch = append(batch, &kevo.KeyValuePair{
Key: key,
Value: value,
})
// Send batch if full
if len(batch) >= batchSize {
progress := float32(sentCount) / float32(totalCount)
if err := stream.Send(&kevo.BootstrapBatch{
Pairs: batch,
Progress: progress,
IsLast: false,
SnapshotLsn: snapshotLSN,
}); err != nil {
return err
}
// Reset batch and update count
sentCount += len(batch)
batch = batch[:0]
}
}
// Send final batch
if len(batch) > 0 {
sentCount += len(batch)
progress := float32(sentCount) / float32(totalCount)
if err := stream.Send(&kevo.BootstrapBatch{
Pairs: batch,
Progress: progress,
IsLast: true,
SnapshotLsn: snapshotLSN,
}); err != nil {
return err
}
} else if sentCount > 0 {
// Send empty final batch to mark the end
if err := stream.Send(&kevo.BootstrapBatch{
Pairs: []*kevo.KeyValuePair{},
Progress: 1.0,
IsLast: true,
SnapshotLsn: snapshotLSN,
}); err != nil {
return err
}
}
// Update replica status
s.replicasMutex.Lock()
replica.Status = transport.StatusSyncing
replica.CurrentLSN = snapshotLSN
s.replicasMutex.Unlock()
return nil
}
// Helper to convert replica info to proto message
func convertReplicaInfoToProto(replica *transport.ReplicaInfo) *kevo.ReplicaInfo {
// Convert status to proto enum
var pbStatus kevo.ReplicaStatus
switch replica.Status {
case transport.StatusConnecting:
pbStatus = kevo.ReplicaStatus_CONNECTING
case transport.StatusSyncing:
pbStatus = kevo.ReplicaStatus_SYNCING
case transport.StatusBootstrapping:
pbStatus = kevo.ReplicaStatus_BOOTSTRAPPING
case transport.StatusReady:
pbStatus = kevo.ReplicaStatus_READY
case transport.StatusDisconnected:
pbStatus = kevo.ReplicaStatus_DISCONNECTED
case transport.StatusError:
pbStatus = kevo.ReplicaStatus_ERROR
default:
pbStatus = kevo.ReplicaStatus_DISCONNECTED
}
// Convert role to proto enum
var pbRole kevo.ReplicaRole
switch replica.Role {
case transport.RolePrimary:
pbRole = kevo.ReplicaRole_PRIMARY
case transport.RoleReplica:
pbRole = kevo.ReplicaRole_REPLICA
case transport.RoleReadOnly:
pbRole = kevo.ReplicaRole_READ_ONLY
default:
pbRole = kevo.ReplicaRole_REPLICA
}
// Create proto message
pbReplica := &kevo.ReplicaInfo{
ReplicaId: replica.ID,
Address: replica.Address,
Role: pbRole,
Status: pbStatus,
LastSeenMs: replica.LastSeen.UnixMilli(),
CurrentLsn: replica.CurrentLSN,
ReplicationLagMs: replica.ReplicationLag.Milliseconds(),
}
// Add error message if any
if replica.Error != nil {
pbReplica.ErrorMessage = replica.Error.Error()
}
return pbReplica
}
// Convert WAL entry to proto message
func convertWALEntryToProto(entry *wal.Entry) *kevo.WALEntry {
return &kevo.WALEntry{
SequenceNumber: entry.SequenceNumber,
Type: uint32(entry.Type),
Key: entry.Key,
Value: entry.Value,
// We'd normally calculate a checksum here
Checksum: nil,
}
}
// entryNotifier is a helper struct that implements replication.EntryProcessor
// to notify when new entries are available
type entryNotifier struct {
notifyCh chan struct{}
fromLSN uint64
}
func (n *entryNotifier) ProcessEntry(entry *wal.Entry) error {
if entry.SequenceNumber >= n.fromLSN {
select {
case n.notifyCh <- struct{}{}:
default:
// Channel already has a notification, no need to send another
}
}
return nil
}
func (n *entryNotifier) ProcessBatch(entries []*wal.Entry) error {
if len(entries) > 0 && entries[len(entries)-1].SequenceNumber >= n.fromLSN {
select {
case n.notifyCh <- struct{}{}:
default:
// Channel already has a notification, no need to send another
}
}
return nil
}
// Define the interface for storage snapshot operations
// This would normally be implemented by the storage engine
type StorageSnapshot interface {
// CreateSnapshotIterator creates an iterator for a storage snapshot
CreateSnapshotIterator() (SnapshotIterator, error)
// KeyCount returns the approximate number of keys in storage
KeyCount() int64
}
// SnapshotIterator provides iteration over key-value pairs in storage
type SnapshotIterator interface {
// Next returns the next key-value pair
Next() (key []byte, value []byte, err error)
// Close closes the iterator
Close() error
}
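
For reference, a minimal in-memory implementation of the snapshot interfaces above might look like the following (illustrative sketch only; the real implementation lives in the storage engine, not in this changeset):

// memPair and memSnapshot form a toy StorageSnapshot backed by a slice.
type memPair struct {
    key, value []byte
}

type memSnapshot struct {
    pairs []memPair
}

func (s *memSnapshot) CreateSnapshotIterator() (SnapshotIterator, error) {
    return &memSnapshotIterator{pairs: s.pairs}, nil
}

func (s *memSnapshot) KeyCount() int64 {
    return int64(len(s.pairs))
}

// memSnapshotIterator walks the slice and reports io.EOF at the end,
// matching how RequestBootstrap terminates its streaming loop.
type memSnapshotIterator struct {
    pairs []memPair
    idx   int
}

func (it *memSnapshotIterator) Next() ([]byte, []byte, error) {
    if it.idx >= len(it.pairs) {
        return nil, nil, io.EOF
    }
    p := it.pairs[it.idx]
    it.idx++
    return p.key, p.value, nil
}

func (it *memSnapshotIterator) Close() error {
    return nil
}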

View File

@@ -0,0 +1,494 @@
package transport
import (
"context"
"io"
"sync"
"time"
"github.com/KevoDB/kevo/pkg/replication"
"github.com/KevoDB/kevo/pkg/transport"
"github.com/KevoDB/kevo/pkg/wal"
"github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
)
// ReplicationGRPCClient implements the ReplicationClient interface using gRPC
type ReplicationGRPCClient struct {
conn *grpc.ClientConn
client kevo.ReplicationServiceClient
endpoint string
options transport.TransportOptions
replicaID string
status transport.TransportStatus
applier *replication.WALApplier
serializer *replication.EntrySerializer
highestAppliedLSN uint64
currentLSN uint64
mu sync.RWMutex
}
// NewReplicationGRPCClient creates a new ReplicationGRPCClient
func NewReplicationGRPCClient(
endpoint string,
options transport.TransportOptions,
replicaID string,
applier *replication.WALApplier,
serializer *replication.EntrySerializer,
) (*ReplicationGRPCClient, error) {
return &ReplicationGRPCClient{
endpoint: endpoint,
options: options,
replicaID: replicaID,
applier: applier,
serializer: serializer,
status: transport.TransportStatus{
Connected: false,
LastConnected: time.Time{},
LastError: nil,
BytesSent: 0,
BytesReceived: 0,
RTT: 0,
},
}, nil
}
// Connect establishes a connection to the server
func (c *ReplicationGRPCClient) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
// Set up connection options
dialOptions := []grpc.DialOption{
grpc.WithBlock(),
}
// Add TLS if configured - TODO: Add TLS support once TLS helpers are implemented
if c.options.TLSEnabled {
// We'll need to implement TLS credentials loading
// For now, we'll skip TLS
c.options.TLSEnabled = false
} else {
dialOptions = append(dialOptions, grpc.WithInsecure())
}
// Set timeout for connection
dialCtx, cancel := context.WithTimeout(ctx, c.options.Timeout)
defer cancel()
// Establish connection
conn, err := grpc.DialContext(dialCtx, c.endpoint, dialOptions...)
if err != nil {
c.status.LastError = err
return err
}
c.conn = conn
c.client = kevo.NewReplicationServiceClient(conn)
c.status.Connected = true
c.status.LastConnected = time.Now()
return nil
}
// Close closes the connection
func (c *ReplicationGRPCClient) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
err := c.conn.Close()
c.conn = nil
c.client = nil
c.status.Connected = false
if err != nil {
c.status.LastError = err
return err
}
}
return nil
}
// IsConnected returns whether the client is connected
func (c *ReplicationGRPCClient) IsConnected() bool {
c.mu.RLock()
defer c.mu.RUnlock()
if c.conn == nil {
return false
}
state := c.conn.GetState()
return state == connectivity.Ready || state == connectivity.Idle
}
// Status returns the current status of the connection
func (c *ReplicationGRPCClient) Status() transport.TransportStatus {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
// Send sends a request and waits for a response
func (c *ReplicationGRPCClient) Send(ctx context.Context, request transport.Request) (transport.Response, error) {
// Implementation depends on specific replication messages
// This is a placeholder that would be completed for each message type
return nil, transport.ErrInvalidRequest
}
// Stream opens a bidirectional stream
func (c *ReplicationGRPCClient) Stream(ctx context.Context) (transport.Stream, error) {
// Not implemented for replication client
return nil, transport.ErrInvalidRequest
}
// RegisterAsReplica registers this client as a replica with the primary
func (c *ReplicationGRPCClient) RegisterAsReplica(ctx context.Context, replicaID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return transport.ErrNotConnected
}
// Create registration request
req := &kevo.RegisterReplicaRequest{
ReplicaId: replicaID,
Address: c.endpoint, // Use client endpoint as the replica address
Role: kevo.ReplicaRole_REPLICA,
}
// Call the service
resp, err := c.client.RegisterReplica(ctx, req)
if err != nil {
c.status.LastError = err
return err
}
// Update client info based on response
c.replicaID = replicaID
c.currentLSN = resp.CurrentLsn
return nil
}
// SendHeartbeat sends a heartbeat to the primary
func (c *ReplicationGRPCClient) SendHeartbeat(ctx context.Context, info *transport.ReplicaInfo) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return transport.ErrNotConnected
}
// Convert status to proto enum
var pbStatus kevo.ReplicaStatus
switch info.Status {
case transport.StatusConnecting:
pbStatus = kevo.ReplicaStatus_CONNECTING
case transport.StatusSyncing:
pbStatus = kevo.ReplicaStatus_SYNCING
case transport.StatusBootstrapping:
pbStatus = kevo.ReplicaStatus_BOOTSTRAPPING
case transport.StatusReady:
pbStatus = kevo.ReplicaStatus_READY
case transport.StatusDisconnected:
pbStatus = kevo.ReplicaStatus_DISCONNECTED
case transport.StatusError:
pbStatus = kevo.ReplicaStatus_ERROR
default:
pbStatus = kevo.ReplicaStatus_DISCONNECTED
}
// Create heartbeat request
req := &kevo.ReplicaHeartbeatRequest{
ReplicaId: c.replicaID,
Status: pbStatus,
CurrentLsn: c.highestAppliedLSN,
ErrorMessage: "",
}
// Add error message if any
if info.Error != nil {
req.ErrorMessage = info.Error.Error()
}
// Call the service
resp, err := c.client.ReplicaHeartbeat(ctx, req)
if err != nil {
c.status.LastError = err
return err
}
// Update client info based on response
c.currentLSN = resp.PrimaryLsn
return nil
}
// RequestWALEntries requests WAL entries from the primary starting from a specific LSN
func (c *ReplicationGRPCClient) RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return nil, transport.ErrNotConnected
}
// Create request
req := &kevo.GetWALEntriesRequest{
ReplicaId: c.replicaID,
FromLsn: fromLSN,
MaxEntries: 1000, // Configurable
}
// Call the service
resp, err := c.client.GetWALEntries(ctx, req)
if err != nil {
c.status.LastError = err
return nil, err
}
// Convert proto entries to WAL entries
entries := make([]*wal.Entry, 0, len(resp.Batch.Entries))
for _, pbEntry := range resp.Batch.Entries {
entry := &wal.Entry{
SequenceNumber: pbEntry.SequenceNumber,
Type: uint8(pbEntry.Type),
Key: pbEntry.Key,
Value: pbEntry.Value,
}
entries = append(entries, entry)
}
return entries, nil
}
// RequestBootstrap requests a snapshot for bootstrap purposes
func (c *ReplicationGRPCClient) RequestBootstrap(ctx context.Context) (transport.BootstrapIterator, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return nil, transport.ErrNotConnected
}
// Create request
req := &kevo.BootstrapRequest{
ReplicaId: c.replicaID,
}
// Call the service
stream, err := c.client.RequestBootstrap(ctx, req)
if err != nil {
c.status.LastError = err
return nil, err
}
// Create and return bootstrap iterator
return &GRPCBootstrapIterator{
stream: stream,
totalPairs: 0,
seenPairs: 0,
progress: 0.0,
mu: &sync.Mutex{},
applier: c.applier,
}, nil
}
// StartReplicationStream starts a stream for continuous replication
func (c *ReplicationGRPCClient) StartReplicationStream(ctx context.Context) error {
if !c.IsConnected() {
return transport.ErrNotConnected
}
// Get current highest applied LSN
c.mu.Lock()
fromLSN := c.highestAppliedLSN
c.mu.Unlock()
// Create stream request
req := &kevo.StreamWALEntriesRequest{
ReplicaId: c.replicaID,
FromLsn: fromLSN,
Continuous: true,
}
// Start streaming
stream, err := c.client.StreamWALEntries(ctx, req)
if err != nil {
return err
}
// Process stream in a goroutine
go c.processWALStream(ctx, stream)
return nil
}
// processWALStream handles the incoming WAL entry stream
func (c *ReplicationGRPCClient) processWALStream(ctx context.Context, stream kevo.ReplicationService_StreamWALEntriesClient) {
for {
// Check if context is cancelled
select {
case <-ctx.Done():
return
default:
// Continue processing
}
// Receive next batch
batch, err := stream.Recv()
if err == io.EOF {
// Stream completed normally
return
}
if err != nil {
// Stream error
c.mu.Lock()
c.status.LastError = err
c.mu.Unlock()
return
}
// Process entries in batch
entries := make([]*wal.Entry, 0, len(batch.Entries))
for _, pbEntry := range batch.Entries {
entry := &wal.Entry{
SequenceNumber: pbEntry.SequenceNumber,
Type: uint8(pbEntry.Type),
Key: pbEntry.Key,
Value: pbEntry.Value,
}
entries = append(entries, entry)
}
// Apply entries
if len(entries) > 0 {
_, err = c.applier.ApplyBatch(entries)
if err != nil {
c.mu.Lock()
c.status.LastError = err
c.mu.Unlock()
return
}
// Update highest applied LSN
c.mu.Lock()
c.highestAppliedLSN = batch.LastLsn
c.mu.Unlock()
// Report applied entries
go c.reportAppliedEntries(context.Background(), batch.LastLsn)
}
}
}
// reportAppliedEntries reports the highest applied LSN to the primary
func (c *ReplicationGRPCClient) reportAppliedEntries(ctx context.Context, appliedLSN uint64) {
c.mu.RLock()
client := c.client
replicaID := c.replicaID
c.mu.RUnlock()
if client == nil {
return
}
// Create request
req := &kevo.ReportAppliedEntriesRequest{
ReplicaId: replicaID,
AppliedLsn: appliedLSN,
}
// Call the service
_, err := client.ReportAppliedEntries(ctx, req)
if err != nil {
// Just log error, don't return it
c.mu.Lock()
c.status.LastError = err
c.mu.Unlock()
}
}
// GRPCBootstrapIterator implements the BootstrapIterator interface for gRPC
type GRPCBootstrapIterator struct {
stream kevo.ReplicationService_RequestBootstrapClient
currentBatch *kevo.BootstrapBatch
batchIndex int
totalPairs int
seenPairs int
progress float64
mu *sync.Mutex
applier *replication.WALApplier
}
// Next returns the next key-value pair
func (it *GRPCBootstrapIterator) Next() ([]byte, []byte, error) {
it.mu.Lock()
defer it.mu.Unlock()
// If we have a current batch and there are more pairs
if it.currentBatch != nil && it.batchIndex < len(it.currentBatch.Pairs) {
pair := it.currentBatch.Pairs[it.batchIndex]
it.batchIndex++
it.seenPairs++
return pair.Key, pair.Value, nil
}
// Need to get a new batch
batch, err := it.stream.Recv()
if err == io.EOF {
return nil, nil, io.EOF
}
if err != nil {
return nil, nil, err
}
// Update progress
it.currentBatch = batch
it.batchIndex = 0
it.progress = float64(batch.Progress)
// If batch is empty and it's the last one
if len(batch.Pairs) == 0 && batch.IsLast {
// Store the snapshot LSN for later use
if it.applier != nil {
it.applier.ResetHighestApplied(batch.SnapshotLsn)
}
return nil, nil, io.EOF
}
// If batch is empty but not the last one, try again
if len(batch.Pairs) == 0 {
return it.Next()
}
// Return the first pair from the new batch
pair := batch.Pairs[it.batchIndex]
it.batchIndex++
it.seenPairs++
return pair.Key, pair.Value, nil
}
// Close closes the iterator
func (it *GRPCBootstrapIterator) Close() error {
it.mu.Lock()
defer it.mu.Unlock()
// Store the snapshot LSN if we have a current batch and it's the last one
if it.currentBatch != nil && it.currentBatch.IsLast && it.applier != nil {
it.applier.ResetHighestApplied(it.currentBatch.SnapshotLsn)
}
return nil
}
// Progress returns the progress of the bootstrap operation (0.0-1.0)
func (it *GRPCBootstrapIterator) Progress() float64 {
it.mu.Lock()
defer it.mu.Unlock()
return it.progress
}
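
Putting the client together, a replica-side wiring sketch might look roughly like this (illustrative only, not part of this changeset; the EntrySerializer zero value and the 5-second timeout are assumptions):

// startReplica shows the intended call sequence for a replica node.
// storage must implement replication.StorageInterface (Put/Delete).
func startReplica(ctx context.Context, endpoint, replicaID string, storage replication.StorageInterface) (*ReplicationGRPCClient, error) {
    applier := replication.NewWALApplier(storage)
    serializer := &replication.EntrySerializer{} // assumption: zero value is usable

    client, err := NewReplicationGRPCClient(endpoint, transport.TransportOptions{
        Timeout: 5 * time.Second,
    }, replicaID, applier, serializer)
    if err != nil {
        return nil, err
    }
    if err := client.Connect(ctx); err != nil {
        return nil, err
    }
    // Register with the primary, then begin continuous WAL streaming;
    // incoming batches are applied via the WALApplier in a background goroutine.
    if err := client.RegisterAsReplica(ctx, replicaID); err != nil {
        return nil, err
    }
    if err := client.StartReplicationStream(ctx); err != nil {
        return nil, err
    }
    return client, nil
}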

View File

@@ -0,0 +1,200 @@
package transport
import (
"context"
"fmt"
"sync"
"github.com/KevoDB/kevo/pkg/grpc/service"
"github.com/KevoDB/kevo/pkg/replication"
"github.com/KevoDB/kevo/pkg/transport"
)
// ReplicationGRPCServer implements the ReplicationServer interface using gRPC
type ReplicationGRPCServer struct {
transportManager *GRPCTransportManager
replicationService *service.ReplicationServiceServer
options transport.TransportOptions
replicas map[string]*transport.ReplicaInfo
mu sync.RWMutex
}
// NewReplicationGRPCServer creates a new ReplicationGRPCServer
func NewReplicationGRPCServer(
transportManager *GRPCTransportManager,
replicator *replication.WALReplicator,
applier *replication.WALApplier,
serializer *replication.EntrySerializer,
storageSnapshot replication.StorageSnapshot,
options transport.TransportOptions,
) (*ReplicationGRPCServer, error) {
// Create replication service
replicationService := service.NewReplicationService(
replicator,
applier,
serializer,
storageSnapshot,
)
return &ReplicationGRPCServer{
transportManager: transportManager,
replicationService: replicationService,
options: options,
replicas: make(map[string]*transport.ReplicaInfo),
}, nil
}
// Start starts the server and returns immediately
func (s *ReplicationGRPCServer) Start() error {
// Register the replication service with the transport manager
if err := s.transportManager.RegisterService(s.replicationService); err != nil {
return fmt.Errorf("failed to register replication service: %w", err)
}
// Start the transport manager if it's not already started
if err := s.transportManager.Start(); err != nil {
return fmt.Errorf("failed to start transport manager: %w", err)
}
return nil
}
// Serve starts the server and blocks until it's stopped
func (s *ReplicationGRPCServer) Serve() error {
if err := s.Start(); err != nil {
return err
}
// This will block until the context is cancelled
<-context.Background().Done()
return nil
}
// Stop stops the server gracefully
func (s *ReplicationGRPCServer) Stop(ctx context.Context) error {
return s.transportManager.Stop(ctx)
}
// SetRequestHandler sets the handler for incoming requests
// Not used for the replication server as it uses a dedicated service
func (s *ReplicationGRPCServer) SetRequestHandler(handler transport.RequestHandler) {
// No-op for replication server
}
// RegisterReplica registers a new replica
func (s *ReplicationGRPCServer) RegisterReplica(replicaInfo *transport.ReplicaInfo) error {
s.mu.Lock()
defer s.mu.Unlock()
s.replicas[replicaInfo.ID] = replicaInfo
return nil
}
// UpdateReplicaStatus updates the status of a replica
func (s *ReplicationGRPCServer) UpdateReplicaStatus(replicaID string, status transport.ReplicaStatus, lsn uint64) error {
s.mu.Lock()
defer s.mu.Unlock()
replica, exists := s.replicas[replicaID]
if !exists {
return fmt.Errorf("replica not found: %s", replicaID)
}
replica.Status = status
replica.CurrentLSN = lsn
return nil
}
// GetReplicaInfo returns information about a specific replica
func (s *ReplicationGRPCServer) GetReplicaInfo(replicaID string) (*transport.ReplicaInfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
replica, exists := s.replicas[replicaID]
if !exists {
return nil, fmt.Errorf("replica not found: %s", replicaID)
}
return replica, nil
}
// ListReplicas returns information about all connected replicas
func (s *ReplicationGRPCServer) ListReplicas() ([]*transport.ReplicaInfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]*transport.ReplicaInfo, 0, len(s.replicas))
for _, replica := range s.replicas {
result = append(result, replica)
}
return result, nil
}
// StreamWALEntriesToReplica streams WAL entries to a specific replica
func (s *ReplicationGRPCServer) StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error {
// This is handled by the gRPC service directly
return nil
}
// RegisterReplicationTransport registers the gRPC replication transport with the registry
func init() {
// Register replication server factory
transport.RegisterReplicationServerTransport("grpc", func(address string, options transport.TransportOptions) (transport.ReplicationServer, error) {
// Create gRPC transport manager
grpcOptions := &GRPCTransportOptions{
ListenAddr: address,
ConnectionTimeout: options.Timeout,
DialTimeout: options.Timeout,
}
// Add TLS configuration if enabled
if options.TLSEnabled {
tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
grpcOptions.TLSConfig = tlsConfig
}
// Create transport manager
manager, err := NewGRPCTransportManager(grpcOptions)
if err != nil {
return nil, fmt.Errorf("failed to create gRPC transport manager: %w", err)
}
// For registration, we return a placeholder that will be properly initialized
// by the caller with the required components
return &ReplicationGRPCServer{
transportManager: manager,
options: options,
replicas: make(map[string]*transport.ReplicaInfo),
}, nil
})
// Register replication client factory
transport.RegisterReplicationClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.ReplicationClient, error) {
// For registration, we return a placeholder that will be properly initialized
// by the caller with the required components
return &ReplicationGRPCClient{
endpoint: endpoint,
options: options,
}, nil
})
}
// WithReplicator adds a replicator to the replication server
func (s *ReplicationGRPCServer) WithReplicator(
replicator *replication.WALReplicator,
applier *replication.WALApplier,
serializer *replication.EntrySerializer,
storageSnapshot replication.StorageSnapshot,
) *ReplicationGRPCServer {
s.replicationService = service.NewReplicationService(
replicator,
applier,
serializer,
storageSnapshot,
)
return s
}
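
For completeness, the primary-side wiring might look roughly like this (illustrative sketch only, not part of this changeset; zero-value transport options are assumed to be acceptable):

// startPrimary builds a transport manager, wires the replication service,
// and starts serving. The snapshot argument is whatever implements
// replication.StorageSnapshot in the storage engine.
func startPrimary(
    listenAddr string,
    replicator *replication.WALReplicator,
    applier *replication.WALApplier,
    serializer *replication.EntrySerializer,
    snapshot replication.StorageSnapshot,
) (*ReplicationGRPCServer, error) {
    manager, err := NewGRPCTransportManager(&GRPCTransportOptions{
        ListenAddr: listenAddr,
    })
    if err != nil {
        return nil, err
    }
    server, err := NewReplicationGRPCServer(manager, replicator, applier, serializer, snapshot, transport.TransportOptions{})
    if err != nil {
        return nil, err
    }
    // Start registers the ReplicationService with the manager and begins listening.
    if err := server.Start(); err != nil {
        return nil, err
    }
    return server, nil
}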

268 pkg/replication/applier.go Normal file
View File

@@ -0,0 +1,268 @@
package replication
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/KevoDB/kevo/pkg/wal"
)
var (
// ErrApplierClosed indicates the applier has been closed
ErrApplierClosed = errors.New("applier is closed")
// ErrOutOfOrderEntry indicates an entry was received out of order
ErrOutOfOrderEntry = errors.New("out of order entry")
// ErrInvalidEntryType indicates an unknown or invalid entry type
ErrInvalidEntryType = errors.New("invalid entry type")
// ErrStorageError indicates a storage-related error occurred
ErrStorageError = errors.New("storage error")
)
// StorageInterface defines the minimum storage requirements for the WALApplier
type StorageInterface interface {
Put(key, value []byte) error
Delete(key []byte) error
}
// WALApplier applies WAL entries from a primary node to a replica's storage
type WALApplier struct {
// Storage is the local storage where entries will be applied
storage StorageInterface
// highestApplied is the highest Lamport timestamp of entries that have been applied
highestApplied uint64
// pendingEntries holds entries that were received out of order and are waiting for earlier entries
pendingEntries map[uint64]*wal.Entry
// appliedCount tracks the number of entries successfully applied
appliedCount uint64
// skippedCount tracks the number of entries skipped (already applied)
skippedCount uint64
// errorCount tracks the number of application errors
errorCount uint64
// closed indicates whether the applier is closed
closed int32
// concurrency management
mu sync.RWMutex
}
// NewWALApplier creates a new WAL applier
func NewWALApplier(storage StorageInterface) *WALApplier {
return &WALApplier{
storage: storage,
highestApplied: 0,
pendingEntries: make(map[uint64]*wal.Entry),
}
}
// Apply applies a single WAL entry to the local storage
// Returns whether the entry was applied and any error
func (a *WALApplier) Apply(entry *wal.Entry) (bool, error) {
if atomic.LoadInt32(&a.closed) == 1 {
return false, ErrApplierClosed
}
a.mu.Lock()
defer a.mu.Unlock()
// Check if this entry has already been applied
if entry.SequenceNumber <= a.highestApplied {
atomic.AddUint64(&a.skippedCount, 1)
return false, nil
}
// Check if this is the next entry we're expecting
if entry.SequenceNumber != a.highestApplied+1 {
// Entry is out of order, store it for later application
a.pendingEntries[entry.SequenceNumber] = entry
return false, nil
}
// Apply this entry and any subsequent pending entries
return a.applyEntryAndPending(entry)
}
// ApplyBatch applies a batch of WAL entries to the local storage
// Returns the number of entries applied and any error
func (a *WALApplier) ApplyBatch(entries []*wal.Entry) (int, error) {
if atomic.LoadInt32(&a.closed) == 1 {
return 0, ErrApplierClosed
}
if len(entries) == 0 {
return 0, nil
}
a.mu.Lock()
defer a.mu.Unlock()
// Sort entries by sequence number
sortEntriesByTimestamp(entries)
// Track how many entries we actually apply
applied := 0
// Process each entry
for _, entry := range entries {
// Skip already applied entries
if entry.SequenceNumber <= a.highestApplied {
atomic.AddUint64(&a.skippedCount, 1)
continue
}
// Check if this is the next entry we're expecting
if entry.SequenceNumber == a.highestApplied+1 {
// Apply this entry and any subsequent pending entries
wasApplied, err := a.applyEntryAndPending(entry)
if err != nil {
return applied, err
}
if wasApplied {
applied++
}
} else {
// Entry is out of order, store it for later application
a.pendingEntries[entry.SequenceNumber] = entry
}
}
return applied, nil
}
// GetHighestApplied returns the highest applied Lamport timestamp
func (a *WALApplier) GetHighestApplied() uint64 {
a.mu.RLock()
defer a.mu.RUnlock()
return a.highestApplied
}
// GetStats returns statistics about the applier
func (a *WALApplier) GetStats() map[string]uint64 {
a.mu.RLock()
pendingCount := len(a.pendingEntries)
highestApplied := a.highestApplied
a.mu.RUnlock()
return map[string]uint64{
"appliedCount": atomic.LoadUint64(&a.appliedCount),
"skippedCount": atomic.LoadUint64(&a.skippedCount),
"errorCount": atomic.LoadUint64(&a.errorCount),
"pendingCount": uint64(pendingCount),
"highestApplied": highestApplied,
}
}
// Close closes the applier
func (a *WALApplier) Close() error {
if !atomic.CompareAndSwapInt32(&a.closed, 0, 1) {
return nil // Already closed
}
a.mu.Lock()
defer a.mu.Unlock()
// Clear any pending entries
a.pendingEntries = make(map[uint64]*wal.Entry)
return nil
}
// applyEntryAndPending applies the given entry and any pending entries that become applicable
// Caller must hold the lock
func (a *WALApplier) applyEntryAndPending(entry *wal.Entry) (bool, error) {
// Apply the current entry
if err := a.applyEntryToStorage(entry); err != nil {
atomic.AddUint64(&a.errorCount, 1)
return false, err
}
// Update highest applied timestamp
a.highestApplied = entry.SequenceNumber
atomic.AddUint64(&a.appliedCount, 1)
// Check for pending entries that can now be applied
nextSeq := a.highestApplied + 1
for {
nextEntry, exists := a.pendingEntries[nextSeq]
if !exists {
break
}
// Apply this pending entry
if err := a.applyEntryToStorage(nextEntry); err != nil {
atomic.AddUint64(&a.errorCount, 1)
return true, err
}
// Update highest applied and remove from pending
a.highestApplied = nextSeq
delete(a.pendingEntries, nextSeq)
atomic.AddUint64(&a.appliedCount, 1)
// Look for the next one
nextSeq++
}
return true, nil
}
// applyEntryToStorage applies a single WAL entry to the storage engine
// Caller must hold the lock
func (a *WALApplier) applyEntryToStorage(entry *wal.Entry) error {
switch entry.Type {
case wal.OpTypePut:
if err := a.storage.Put(entry.Key, entry.Value); err != nil {
return fmt.Errorf("%w: %v", ErrStorageError, err)
}
case wal.OpTypeDelete:
if err := a.storage.Delete(entry.Key); err != nil {
return fmt.Errorf("%w: %v", ErrStorageError, err)
}
case wal.OpTypeBatch:
// In the WAL the batch entry itself doesn't contain the operations
// Actual batch operations should be received as separate entries
// with sequential sequence numbers
return nil
default:
return fmt.Errorf("%w: type %d", ErrInvalidEntryType, entry.Type)
}
return nil
}
// PendingEntryCount returns the number of pending entries
func (a *WALApplier) PendingEntryCount() int {
a.mu.RLock()
defer a.mu.RUnlock()
return len(a.pendingEntries)
}
// ResetHighestApplied allows manually setting a new highest applied value
// This should only be used during initialization or recovery
func (a *WALApplier) ResetHighestApplied(value uint64) {
a.mu.Lock()
defer a.mu.Unlock()
a.highestApplied = value
}
// HasEntry checks if a specific entry timestamp has been applied or is pending
func (a *WALApplier) HasEntry(timestamp uint64) bool {
a.mu.RLock()
defer a.mu.RUnlock()
if timestamp <= a.highestApplied {
return true
}
_, exists := a.pendingEntries[timestamp]
return exists
}

View File

@@ -0,0 +1,526 @@
package replication
import (
"errors"
"sync"
"testing"
"github.com/KevoDB/kevo/pkg/common/iterator"
"github.com/KevoDB/kevo/pkg/wal"
)
// MockStorage implements a simple mock storage for testing
type MockStorage struct {
mu sync.Mutex
data map[string][]byte
putFail bool
deleteFail bool
putCount int
deleteCount int
lastPutKey []byte
lastPutValue []byte
lastDeleteKey []byte
}
func NewMockStorage() *MockStorage {
return &MockStorage{
data: make(map[string][]byte),
}
}
func (m *MockStorage) Put(key, value []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.putFail {
return errors.New("simulated put failure")
}
m.putCount++
m.lastPutKey = append([]byte{}, key...)
m.lastPutValue = append([]byte{}, value...)
m.data[string(key)] = append([]byte{}, value...)
return nil
}
func (m *MockStorage) Get(key []byte) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
value, ok := m.data[string(key)]
if !ok {
return nil, errors.New("key not found")
}
return value, nil
}
func (m *MockStorage) Delete(key []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.deleteFail {
return errors.New("simulated delete failure")
}
m.deleteCount++
m.lastDeleteKey = append([]byte{}, key...)
delete(m.data, string(key))
return nil
}
// Stub implementations for the rest of the interface
func (m *MockStorage) Close() error { return nil }
func (m *MockStorage) IsDeleted(key []byte) (bool, error) { return false, nil }
func (m *MockStorage) GetIterator() (iterator.Iterator, error) { return nil, nil }
func (m *MockStorage) GetRangeIterator(startKey, endKey []byte) (iterator.Iterator, error) { return nil, nil }
func (m *MockStorage) ApplyBatch(entries []*wal.Entry) error { return nil }
func (m *MockStorage) FlushMemTables() error { return nil }
func (m *MockStorage) GetMemTableSize() uint64 { return 0 }
func (m *MockStorage) IsFlushNeeded() bool { return false }
func (m *MockStorage) GetSSTables() []string { return nil }
func (m *MockStorage) ReloadSSTables() error { return nil }
func (m *MockStorage) RotateWAL() error { return nil }
func (m *MockStorage) GetStorageStats() map[string]interface{} { return nil }
func TestWALApplierBasic(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Create test entries
entries := []*wal.Entry{
{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
},
{
SequenceNumber: 3,
Type: wal.OpTypeDelete,
Key: []byte("key1"),
},
}
// Apply entries one by one
for i, entry := range entries {
applied, err := applier.Apply(entry)
if err != nil {
t.Fatalf("Error applying entry %d: %v", i, err)
}
if !applied {
t.Errorf("Entry %d should have been applied", i)
}
}
// Check state
if got := applier.GetHighestApplied(); got != 3 {
t.Errorf("Expected highest applied 3, got %d", got)
}
// Check storage state
if value, _ := storage.Get([]byte("key2")); string(value) != "value2" {
t.Errorf("Expected key2=value2 in storage, got %q", value)
}
// key1 should be deleted
if _, err := storage.Get([]byte("key1")); err == nil {
t.Errorf("Expected key1 to be deleted")
}
// Check stats
stats := applier.GetStats()
if stats["appliedCount"] != 3 {
t.Errorf("Expected appliedCount=3, got %d", stats["appliedCount"])
}
if stats["pendingCount"] != 0 {
t.Errorf("Expected pendingCount=0, got %d", stats["pendingCount"])
}
}
func TestWALApplierOutOfOrder(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Apply entries out of order
entries := []*wal.Entry{
{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
},
{
SequenceNumber: 3,
Type: wal.OpTypePut,
Key: []byte("key3"),
Value: []byte("value3"),
},
{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
}
// Apply entry with sequence 2 - should be stored as pending
applied, err := applier.Apply(entries[0])
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if applied {
t.Errorf("Entry with seq 2 should not have been applied yet")
}
// Apply entry with sequence 3 - should be stored as pending
applied, err = applier.Apply(entries[1])
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if applied {
t.Errorf("Entry with seq 3 should not have been applied yet")
}
// Check pending count
if pending := applier.PendingEntryCount(); pending != 2 {
t.Errorf("Expected 2 pending entries, got %d", pending)
}
// Now apply entry with sequence 1 - should trigger all entries to be applied
applied, err = applier.Apply(entries[2])
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry with seq 1 should have been applied")
}
// Check state - all entries should be applied now
if got := applier.GetHighestApplied(); got != 3 {
t.Errorf("Expected highest applied 3, got %d", got)
}
// Pending count should be 0
if pending := applier.PendingEntryCount(); pending != 0 {
t.Errorf("Expected 0 pending entries, got %d", pending)
}
// Check storage contains all values
values := []struct {
key string
value string
}{
{"key1", "value1"},
{"key2", "value2"},
{"key3", "value3"},
}
for _, v := range values {
if val, err := storage.Get([]byte(v.key)); err != nil || string(val) != v.value {
t.Errorf("Expected %s=%s in storage, got %s, err=%v", v.key, v.value, val, err)
}
}
}
func TestWALApplierBatch(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Create a batch of entries
batch := []*wal.Entry{
{
SequenceNumber: 3,
Type: wal.OpTypePut,
Key: []byte("key3"),
Value: []byte("value3"),
},
{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
},
}
// Apply batch - entries should be sorted by sequence number
applied, err := applier.ApplyBatch(batch)
if err != nil {
t.Fatalf("Error applying batch: %v", err)
}
// All 3 entries should be applied
if applied != 3 {
t.Errorf("Expected 3 entries applied, got %d", applied)
}
// Check highest applied
if got := applier.GetHighestApplied(); got != 3 {
t.Errorf("Expected highest applied 3, got %d", got)
}
// Check all values in storage
values := []struct {
key string
value string
}{
{"key1", "value1"},
{"key2", "value2"},
{"key3", "value3"},
}
for _, v := range values {
if val, err := storage.Get([]byte(v.key)); err != nil || string(val) != v.value {
t.Errorf("Expected %s=%s in storage, got %s, err=%v", v.key, v.value, val, err)
}
}
}
func TestWALApplierAlreadyApplied(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Apply an entry
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
applied, err := applier.Apply(entry)
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry should have been applied")
}
// Try to apply the same entry again
applied, err = applier.Apply(entry)
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if applied {
t.Errorf("Entry should not have been applied a second time")
}
// Check stats
stats := applier.GetStats()
if stats["appliedCount"] != 1 {
t.Errorf("Expected appliedCount=1, got %d", stats["appliedCount"])
}
if stats["skippedCount"] != 1 {
t.Errorf("Expected skippedCount=1, got %d", stats["skippedCount"])
}
}
func TestWALApplierError(t *testing.T) {
storage := NewMockStorage()
storage.putFail = true
applier := NewWALApplier(storage)
defer applier.Close()
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
// Apply should return an error
_, err := applier.Apply(entry)
if err == nil {
t.Errorf("Expected error from Apply, got nil")
}
// Check error count
stats := applier.GetStats()
if stats["errorCount"] != 1 {
t.Errorf("Expected errorCount=1, got %d", stats["errorCount"])
}
// Fix storage and try again
storage.putFail = false
// Apply should succeed
applied, err := applier.Apply(entry)
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry should have been applied")
}
}
func TestWALApplierInvalidType(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
entry := &wal.Entry{
SequenceNumber: 1,
Type: 99, // Invalid type
Key: []byte("key1"),
Value: []byte("value1"),
}
// Apply should return an error
_, err := applier.Apply(entry)
if err == nil || !errors.Is(err, ErrInvalidEntryType) {
t.Errorf("Expected invalid entry type error, got %v", err)
}
}
func TestWALApplierClose(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
// Apply an entry
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
applied, err := applier.Apply(entry)
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry should have been applied")
}
// Close the applier
if err := applier.Close(); err != nil {
t.Fatalf("Error closing applier: %v", err)
}
// Try to apply another entry
_, err = applier.Apply(&wal.Entry{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
})
if err == nil || !errors.Is(err, ErrApplierClosed) {
t.Errorf("Expected applier closed error, got %v", err)
}
}
func TestWALApplierResetHighest(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Manually set the highest applied to 10
applier.ResetHighestApplied(10)
// Check value
if got := applier.GetHighestApplied(); got != 10 {
t.Errorf("Expected highest applied 10, got %d", got)
}
// Try to apply an entry with sequence 10
applied, err := applier.Apply(&wal.Entry{
SequenceNumber: 10,
Type: wal.OpTypePut,
Key: []byte("key10"),
Value: []byte("value10"),
})
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if applied {
t.Errorf("Entry with seq 10 should have been skipped")
}
// Apply an entry with sequence 11
applied, err = applier.Apply(&wal.Entry{
SequenceNumber: 11,
Type: wal.OpTypePut,
Key: []byte("key11"),
Value: []byte("value11"),
})
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry with seq 11 should have been applied")
}
// Check new highest
if got := applier.GetHighestApplied(); got != 11 {
t.Errorf("Expected highest applied 11, got %d", got)
}
}
func TestWALApplierHasEntry(t *testing.T) {
storage := NewMockStorage()
applier := NewWALApplier(storage)
defer applier.Close()
// Apply an entry with sequence 1
applied, err := applier.Apply(&wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
})
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
if !applied {
t.Errorf("Entry should have been applied")
}
// Add a pending entry with sequence 3
_, err = applier.Apply(&wal.Entry{
SequenceNumber: 3,
Type: wal.OpTypePut,
Key: []byte("key3"),
Value: []byte("value3"),
})
if err != nil {
t.Fatalf("Error applying entry: %v", err)
}
// Check has entry
testCases := []struct {
timestamp uint64
expected bool
}{
{0, true},
{1, true},
{2, false},
{3, true},
{4, false},
}
for _, tc := range testCases {
if got := applier.HasEntry(tc.timestamp); got != tc.expected {
t.Errorf("HasEntry(%d) = %v, want %v", tc.timestamp, got, tc.expected)
}
}
}

View File

@@ -42,4 +42,4 @@ func (c *LamportClock) Current() uint64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.counter
}
}

View File

@@ -7,12 +7,12 @@ import (
func TestLamportClockTick(t *testing.T) {
clock := NewLamportClock()
// Initial tick should return 1
if ts := clock.Tick(); ts != 1 {
t.Errorf("First tick should return 1, got %d", ts)
}
// Second tick should return 2
if ts := clock.Tick(); ts != 2 {
t.Errorf("Second tick should return 2, got %d", ts)
@@ -21,25 +21,25 @@ func TestLamportClockTick(t *testing.T) {
func TestLamportClockUpdate(t *testing.T) {
clock := NewLamportClock()
// Update with lower value should increment
ts := clock.Update(0)
if ts != 1 {
t.Errorf("Update with lower value should return 1, got %d", ts)
}
// Update with same value should increment
ts = clock.Update(1)
if ts != 2 {
t.Errorf("Update with same value should return 2, got %d", ts)
}
// Update with higher value should use that value and increment
ts = clock.Update(10)
if ts != 11 {
t.Errorf("Update with higher value should return 11, got %d", ts)
}
// Subsequent tick should continue from updated value
ts = clock.Tick()
if ts != 12 {
@@ -49,18 +49,18 @@ func TestLamportClockUpdate(t *testing.T) {
func TestLamportClockCurrent(t *testing.T) {
clock := NewLamportClock()
// Initial current should be 0
if ts := clock.Current(); ts != 0 {
t.Errorf("Initial current should be 0, got %d", ts)
}
// After tick, current should reflect new value
clock.Tick()
if ts := clock.Current(); ts != 1 {
t.Errorf("Current after tick should be 1, got %d", ts)
}
// Current should not increment the clock
if ts := clock.Current(); ts != 1 {
t.Errorf("Multiple calls to Current should return same value, got %d", ts)
@ -71,7 +71,7 @@ func TestLamportClockConcurrency(t *testing.T) {
clock := NewLamportClock()
iterations := 1000
var wg sync.WaitGroup
// Run multiple goroutines calling Tick concurrently
wg.Add(iterations)
for i := 0; i < iterations; i++ {
@ -81,10 +81,10 @@ func TestLamportClockConcurrency(t *testing.T) {
}()
}
wg.Wait()
// After iterations concurrent ticks, value should be iterations
if ts := clock.Current(); ts != uint64(iterations) {
t.Errorf("After %d concurrent ticks, expected value %d, got %d",
t.Errorf("After %d concurrent ticks, expected value %d, got %d",
iterations, iterations, ts)
}
}

View File

@ -0,0 +1,373 @@
package replication
import (
"errors"
"sync"
"sync/atomic"
"github.com/KevoDB/kevo/pkg/wal"
)
var (
// ErrReplicatorClosed indicates the replicator has been closed and no longer accepts entries
ErrReplicatorClosed = errors.New("replicator is closed")
// ErrReplicatorFull indicates the replicator's entry buffer is full
ErrReplicatorFull = errors.New("replicator entry buffer is full")
// ErrInvalidPosition indicates an invalid replication position was provided
ErrInvalidPosition = errors.New("invalid replication position")
)
// EntryProcessor is an interface for components that process WAL entries for replication
type EntryProcessor interface {
// ProcessEntry processes a single WAL entry
ProcessEntry(entry *wal.Entry) error
// ProcessBatch processes a batch of WAL entries
ProcessBatch(entries []*wal.Entry) error
}
// ReplicationPosition represents a position in the replication stream
type ReplicationPosition struct {
// Timestamp is the Lamport timestamp of the position
Timestamp uint64
}
// WALReplicator captures WAL entries and makes them available for replication
type WALReplicator struct {
// Entries is a map of timestamp -> entry for all captured entries
entries map[uint64]*wal.Entry
// Batches is a map of batch start timestamp -> batch entries
batches map[uint64][]*wal.Entry
// EntryChannel is a channel of captured entries for subscribers
entryChannel chan *wal.Entry
// BatchChannel is a channel of captured batches for subscribers
batchChannel chan []*wal.Entry
// Highest timestamp seen so far
highestTimestamp uint64
// MaxBufferedEntries is the maximum number of entries to buffer
maxBufferedEntries int
// Concurrency control
mu sync.RWMutex
// Closed indicates if the replicator is closed
closed int32
// EntryProcessors are components that process entries as they're captured
processors []EntryProcessor
}
// NewWALReplicator creates a new WAL replicator
func NewWALReplicator(maxBufferedEntries int) *WALReplicator {
if maxBufferedEntries <= 0 {
maxBufferedEntries = 10000 // Default to 10,000 entries
}
return &WALReplicator{
entries: make(map[uint64]*wal.Entry),
batches: make(map[uint64][]*wal.Entry),
entryChannel: make(chan *wal.Entry, 1000),
batchChannel: make(chan []*wal.Entry, 100),
maxBufferedEntries: maxBufferedEntries,
processors: make([]EntryProcessor, 0),
}
}
// OnEntryWritten implements the wal.ReplicationHook interface
func (r *WALReplicator) OnEntryWritten(entry *wal.Entry) {
if atomic.LoadInt32(&r.closed) == 1 {
return
}
r.mu.Lock()
// Update highest timestamp
if entry.SequenceNumber > r.highestTimestamp {
r.highestTimestamp = entry.SequenceNumber
}
// Store the entry (make a copy to avoid potential mutation)
entryCopy := &wal.Entry{
SequenceNumber: entry.SequenceNumber,
Type: entry.Type,
Key: append([]byte{}, entry.Key...),
}
if entry.Value != nil {
entryCopy.Value = append([]byte{}, entry.Value...)
}
r.entries[entryCopy.SequenceNumber] = entryCopy
// Cleanup old entries if we exceed the buffer size
if len(r.entries) > r.maxBufferedEntries {
r.cleanupOldestEntries(r.maxBufferedEntries / 10) // Remove ~10% of entries
}
r.mu.Unlock()
// Send to channel (non-blocking)
select {
case r.entryChannel <- entryCopy:
// Successfully sent
default:
// Channel full, skip sending but entry is still stored
}
// Process the entry
r.processEntry(entryCopy)
}
// OnBatchWritten implements the wal.ReplicationHook interface
func (r *WALReplicator) OnBatchWritten(entries []*wal.Entry) {
if atomic.LoadInt32(&r.closed) == 1 || len(entries) == 0 {
return
}
r.mu.Lock()
// Make copies to avoid potential mutation
entriesCopy := make([]*wal.Entry, len(entries))
batchTimestamp := entries[0].SequenceNumber
for i, entry := range entries {
entriesCopy[i] = &wal.Entry{
SequenceNumber: entry.SequenceNumber,
Type: entry.Type,
Key: append([]byte{}, entry.Key...),
}
if entry.Value != nil {
entriesCopy[i].Value = append([]byte{}, entry.Value...)
}
// Store individual entry
r.entries[entriesCopy[i].SequenceNumber] = entriesCopy[i]
// Update highest timestamp
if entry.SequenceNumber > r.highestTimestamp {
r.highestTimestamp = entry.SequenceNumber
}
}
// Store the batch
r.batches[batchTimestamp] = entriesCopy
// Cleanup old entries if we exceed the buffer size
if len(r.entries) > r.maxBufferedEntries {
r.cleanupOldestEntries(r.maxBufferedEntries / 10)
}
// Cleanup old batches if we have too many
if len(r.batches) > r.maxBufferedEntries/10 {
r.cleanupOldestBatches(r.maxBufferedEntries / 100)
}
r.mu.Unlock()
// Send to batch channel (non-blocking)
select {
case r.batchChannel <- entriesCopy:
// Successfully sent
default:
// Channel full, skip sending but entries are still stored
}
// Process the batch
r.processBatch(entriesCopy)
}
// GetHighestTimestamp returns the highest timestamp seen so far
func (r *WALReplicator) GetHighestTimestamp() uint64 {
r.mu.RLock()
defer r.mu.RUnlock()
return r.highestTimestamp
}
// GetEntriesAfter returns all entries with timestamps greater than the given position
func (r *WALReplicator) GetEntriesAfter(position ReplicationPosition) ([]*wal.Entry, error) {
if atomic.LoadInt32(&r.closed) == 1 {
return nil, ErrReplicatorClosed
}
r.mu.RLock()
defer r.mu.RUnlock()
// Create a result slice with appropriate capacity
result := make([]*wal.Entry, 0, min(100, len(r.entries)))
// Find all entries with timestamps greater than the position
for timestamp, entry := range r.entries {
if timestamp > position.Timestamp {
result = append(result, entry)
}
}
// Sort the entries by timestamp
sortEntriesByTimestamp(result)
return result, nil
}
// GetEntryCount returns the number of entries currently stored
func (r *WALReplicator) GetEntryCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.entries)
}
// GetBatchCount returns the number of batches currently stored
func (r *WALReplicator) GetBatchCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.batches)
}
// SubscribeToEntries returns a channel that receives entries as they're captured
func (r *WALReplicator) SubscribeToEntries() <-chan *wal.Entry {
return r.entryChannel
}
// SubscribeToBatches returns a channel that receives batches as they're captured
func (r *WALReplicator) SubscribeToBatches() <-chan []*wal.Entry {
return r.batchChannel
}
// AddProcessor adds an EntryProcessor to receive entries as they're captured
func (r *WALReplicator) AddProcessor(processor EntryProcessor) {
if atomic.LoadInt32(&r.closed) == 1 {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.processors = append(r.processors, processor)
}
// Close closes the replicator and its channels
func (r *WALReplicator) Close() error {
// Set closed flag
if !atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
return nil // Already closed
}
// Close channels
close(r.entryChannel)
close(r.batchChannel)
// Clear entries and batches
r.mu.Lock()
defer r.mu.Unlock()
r.entries = make(map[uint64]*wal.Entry)
r.batches = make(map[uint64][]*wal.Entry)
r.processors = nil
return nil
}
// cleanupOldestEntries removes the oldest entries from the buffer
func (r *WALReplicator) cleanupOldestEntries(count int) {
// Find the oldest timestamps
oldestTimestamps := findOldestTimestamps(r.entries, count)
// Remove the oldest entries
for _, ts := range oldestTimestamps {
delete(r.entries, ts)
}
}
// cleanupOldestBatches removes the oldest batches from the buffer
func (r *WALReplicator) cleanupOldestBatches(count int) {
// Find the oldest timestamps
oldestTimestamps := findOldestTimestamps(r.batches, count)
// Remove the oldest batches
for _, ts := range oldestTimestamps {
delete(r.batches, ts)
}
}
// processEntry sends the entry to all registered processors
func (r *WALReplicator) processEntry(entry *wal.Entry) {
r.mu.RLock()
processors := r.processors
r.mu.RUnlock()
for _, processor := range processors {
_ = processor.ProcessEntry(entry) // Ignore errors for now
}
}
// processBatch sends the batch to all registered processors
func (r *WALReplicator) processBatch(entries []*wal.Entry) {
r.mu.RLock()
processors := r.processors
r.mu.RUnlock()
for _, processor := range processors {
_ = processor.ProcessBatch(entries) // Ignore errors for now
}
}
// findOldestTimestamps finds the n oldest timestamps in a map
func findOldestTimestamps[T any](m map[uint64]T, n int) []uint64 {
if len(m) <= n {
// If we don't have enough entries, return all timestamps
result := make([]uint64, 0, len(m))
for ts := range m {
result = append(result, ts)
}
return result
}
// Find the n smallest timestamps
result := make([]uint64, 0, n)
for ts := range m {
if len(result) < n {
// Add to result if we don't have enough yet
result = append(result, ts)
} else {
// Find the largest timestamp in our result
largestIdx := 0
for i, t := range result {
if t > result[largestIdx] {
largestIdx = i
}
}
// Replace the largest with this one if it's smaller
if ts < result[largestIdx] {
result[largestIdx] = ts
}
}
}
return result
}
// sortEntriesByTimestamp sorts a slice of entries by their timestamps
func sortEntriesByTimestamp(entries []*wal.Entry) {
// Simple insertion sort for small slices
for i := 1; i < len(entries); i++ {
j := i
for j > 0 && entries[j-1].SequenceNumber > entries[j].SequenceNumber {
entries[j], entries[j-1] = entries[j-1], entries[j]
j--
}
}
}
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}
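
A brief usage sketch for the replicator above, assuming the WAL has been configured elsewhere to call it as its wal.ReplicationHook; here OnEntryWritten is invoked directly for illustration.

func exampleReplicatorUsage() {
	replicator := NewWALReplicator(10000)
	defer replicator.Close()

	// Subscribers receive copies of captured entries on a buffered channel.
	entryCh := replicator.SubscribeToEntries()
	go func() {
		for entry := range entryCh {
			_ = entry // e.g. forward to connected replicas
		}
	}()

	// Normally the WAL invokes this via the replication hook.
	replicator.OnEntryWritten(&wal.Entry{
		SequenceNumber: 1,
		Type:           wal.OpTypePut,
		Key:            []byte("k"),
		Value:          []byte("v"),
	})
}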

View File

@ -0,0 +1,46 @@
package replication
// No imports needed
// processorIndex finds the index of a processor in the processors slice
// Returns -1 if not found
func (r *WALReplicator) processorIndex(target EntryProcessor) int {
r.mu.RLock()
defer r.mu.RUnlock()
for i, p := range r.processors {
if p == target {
return i
}
}
return -1
}
// RemoveProcessor removes an EntryProcessor from the replicator
func (r *WALReplicator) RemoveProcessor(processor EntryProcessor) {
if processor == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
// Find the processor in the slice
idx := -1
for i, p := range r.processors {
if p == processor {
idx = i
break
}
}
// If found, remove it
if idx >= 0 {
// Remove the element by replacing it with the last element and truncating
lastIdx := len(r.processors) - 1
if idx < lastIdx {
r.processors[idx] = r.processors[lastIdx]
}
r.processors = r.processors[:lastIdx]
}
}

View File

@ -0,0 +1,401 @@
package replication
import (
"sync"
"testing"
"time"
"github.com/KevoDB/kevo/pkg/wal"
)
// MockEntryProcessor implements the EntryProcessor interface for testing
type MockEntryProcessor struct {
mu sync.Mutex
processedEntries []*wal.Entry
processedBatches [][]*wal.Entry
entriesProcessed int
batchesProcessed int
failProcessEntry bool
failProcessBatch bool
}
func (m *MockEntryProcessor) ProcessEntry(entry *wal.Entry) error {
m.mu.Lock()
defer m.mu.Unlock()
m.processedEntries = append(m.processedEntries, entry)
m.entriesProcessed++
if m.failProcessEntry {
return ErrReplicatorClosed // Just use an existing error
}
return nil
}
func (m *MockEntryProcessor) ProcessBatch(entries []*wal.Entry) error {
m.mu.Lock()
defer m.mu.Unlock()
m.processedBatches = append(m.processedBatches, entries)
m.batchesProcessed++
if m.failProcessBatch {
return ErrReplicatorClosed
}
return nil
}
func (m *MockEntryProcessor) GetStats() (int, int) {
m.mu.Lock()
defer m.mu.Unlock()
return m.entriesProcessed, m.batchesProcessed
}
func TestWALReplicatorBasic(t *testing.T) {
replicator := NewWALReplicator(1000)
defer replicator.Close()
// Create some test entries
entry1 := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
entry2 := &wal.Entry{
SequenceNumber: 2,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
}
// Process some entries
replicator.OnEntryWritten(entry1)
replicator.OnEntryWritten(entry2)
// Check entry count
if count := replicator.GetEntryCount(); count != 2 {
t.Errorf("Expected 2 entries, got %d", count)
}
// Check highest timestamp
if ts := replicator.GetHighestTimestamp(); ts != 2 {
t.Errorf("Expected highest timestamp 2, got %d", ts)
}
// Get entries after timestamp 0
entries, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 0})
if err != nil {
t.Fatalf("Error getting entries: %v", err)
}
if len(entries) != 2 {
t.Fatalf("Expected 2 entries after timestamp 0, got %d", len(entries))
}
// Check entries are sorted by timestamp
if entries[0].SequenceNumber != 1 || entries[1].SequenceNumber != 2 {
t.Errorf("Entries not sorted by timestamp")
}
// Get entries after timestamp 1
entries, err = replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 1})
if err != nil {
t.Fatalf("Error getting entries: %v", err)
}
if len(entries) != 1 {
t.Fatalf("Expected 1 entry after timestamp 1, got %d", len(entries))
}
if entries[0].SequenceNumber != 2 {
t.Errorf("Expected entry with timestamp 2, got %d", entries[0].SequenceNumber)
}
}
func TestWALReplicatorBatches(t *testing.T) {
replicator := NewWALReplicator(1000)
defer replicator.Close()
// Create a batch of entries
entries := []*wal.Entry{
{
SequenceNumber: 10,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
{
SequenceNumber: 11,
Type: wal.OpTypePut,
Key: []byte("key2"),
Value: []byte("value2"),
},
}
// Process the batch
replicator.OnBatchWritten(entries)
// Check entry count
if count := replicator.GetEntryCount(); count != 2 {
t.Errorf("Expected 2 entries, got %d", count)
}
// Check batch count
if count := replicator.GetBatchCount(); count != 1 {
t.Errorf("Expected 1 batch, got %d", count)
}
// Check highest timestamp
if ts := replicator.GetHighestTimestamp(); ts != 11 {
t.Errorf("Expected highest timestamp 11, got %d", ts)
}
// Get entries after timestamp 9
result, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 9})
if err != nil {
t.Fatalf("Error getting entries: %v", err)
}
if len(result) != 2 {
t.Fatalf("Expected 2 entries after timestamp 9, got %d", len(result))
}
}
func TestWALReplicatorProcessors(t *testing.T) {
replicator := NewWALReplicator(1000)
defer replicator.Close()
// Create a processor
processor := &MockEntryProcessor{}
// Add the processor
replicator.AddProcessor(processor)
// Create an entry and a batch
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
batch := []*wal.Entry{
{
SequenceNumber: 10,
Type: wal.OpTypePut,
Key: []byte("key10"),
Value: []byte("value10"),
},
{
SequenceNumber: 11,
Type: wal.OpTypePut,
Key: []byte("key11"),
Value: []byte("value11"),
},
}
// Process the entry and batch
replicator.OnEntryWritten(entry)
replicator.OnBatchWritten(batch)
// Check processor stats
entriesProcessed, batchesProcessed := processor.GetStats()
if entriesProcessed != 1 {
t.Errorf("Expected 1 entry processed, got %d", entriesProcessed)
}
if batchesProcessed != 1 {
t.Errorf("Expected 1 batch processed, got %d", batchesProcessed)
}
}
func TestWALReplicatorSubscribe(t *testing.T) {
replicator := NewWALReplicator(1000)
defer replicator.Close()
// Subscribe to entries and batches
entryChannel := replicator.SubscribeToEntries()
batchChannel := replicator.SubscribeToBatches()
// Create an entry and a batch
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
batch := []*wal.Entry{
{
SequenceNumber: 10,
Type: wal.OpTypePut,
Key: []byte("key10"),
Value: []byte("value10"),
},
{
SequenceNumber: 11,
Type: wal.OpTypePut,
Key: []byte("key11"),
Value: []byte("value11"),
},
}
// Create channels to receive the results
entryReceived := make(chan *wal.Entry, 1)
batchReceived := make(chan []*wal.Entry, 1)
// Start goroutines to receive from the channels
go func() {
select {
case e := <-entryChannel:
entryReceived <- e
case <-time.After(time.Second):
close(entryReceived)
}
}()
go func() {
select {
case b := <-batchChannel:
batchReceived <- b
case <-time.After(time.Second):
close(batchReceived)
}
}()
// Process the entry and batch
replicator.OnEntryWritten(entry)
replicator.OnBatchWritten(batch)
// Check that we received the entry
select {
case receivedEntry := <-entryReceived:
if receivedEntry.SequenceNumber != 1 {
t.Errorf("Expected entry with timestamp 1, got %d", receivedEntry.SequenceNumber)
}
case <-time.After(time.Second):
t.Errorf("Timeout waiting for entry")
}
// Check that we received the batch
select {
case receivedBatch := <-batchReceived:
if len(receivedBatch) != 2 {
t.Errorf("Expected batch with 2 entries, got %d", len(receivedBatch))
}
case <-time.After(time.Second):
t.Errorf("Timeout waiting for batch")
}
}
func TestWALReplicatorCleanup(t *testing.T) {
// Create a replicator with a small buffer
replicator := NewWALReplicator(10)
defer replicator.Close()
// Add more entries than the buffer can hold
for i := 0; i < 20; i++ {
entry := &wal.Entry{
SequenceNumber: uint64(i),
Type: wal.OpTypePut,
Key: []byte("key"),
Value: []byte("value"),
}
replicator.OnEntryWritten(entry)
}
// Check that cleanup kept the stored entry count below the number of writes
count := replicator.GetEntryCount()
if count >= 20 {
t.Errorf("Expected fewer than 20 entries after cleanup, got %d", count)
}
// The most recent entries should still be there
entries, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 15})
if err != nil {
t.Fatalf("Error getting entries: %v", err)
}
if len(entries) == 0 {
t.Errorf("Expected some entries after timestamp 15")
}
}
func TestWALReplicatorClose(t *testing.T) {
replicator := NewWALReplicator(1000)
// Add some entries
entry := &wal.Entry{
SequenceNumber: 1,
Type: wal.OpTypePut,
Key: []byte("key"),
Value: []byte("value"),
}
replicator.OnEntryWritten(entry)
// Close the replicator
if err := replicator.Close(); err != nil {
t.Fatalf("Error closing replicator: %v", err)
}
// Check that we can't add more entries
replicator.OnEntryWritten(entry)
// Entry count should be 0: Close cleared the buffer and writes after Close are ignored
if count := replicator.GetEntryCount(); count != 0 {
t.Errorf("Expected 0 entries after close, got %d", count)
}
// Try to get entries (should return an error)
_, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 0})
if err != ErrReplicatorClosed {
t.Errorf("Expected ErrReplicatorClosed, got %v", err)
}
}
func TestFindOldestTimestamps(t *testing.T) {
// Create a map with some timestamps
m := map[uint64]string{
1: "one",
2: "two",
3: "three",
4: "four",
5: "five",
}
// Find the 2 oldest timestamps
result := findOldestTimestamps(m, 2)
// Check the result length
if len(result) != 2 {
t.Fatalf("Expected 2 timestamps, got %d", len(result))
}
// Check that the result contains the 2 smallest timestamps
for _, ts := range result {
if ts != 1 && ts != 2 {
t.Errorf("Expected timestamp 1 or 2, got %d", ts)
}
}
}
func TestSortEntriesByTimestamp(t *testing.T) {
// Create some entries with unsorted timestamps
entries := []*wal.Entry{
{SequenceNumber: 3},
{SequenceNumber: 1},
{SequenceNumber: 2},
}
// Sort the entries
sortEntriesByTimestamp(entries)
// Check that the entries are sorted
for i := 0; i < len(entries)-1; i++ {
if entries[i].SequenceNumber > entries[i+1].SequenceNumber {
t.Errorf("Entries not sorted at index %d: %d > %d",
i, entries[i].SequenceNumber, entries[i+1].SequenceNumber)
}
}
}

View File

@ -0,0 +1,358 @@
package replication
import (
"encoding/binary"
"errors"
"hash/crc32"
"github.com/KevoDB/kevo/pkg/wal"
)
var (
// ErrInvalidChecksum indicates a checksum validation failure during deserialization
ErrInvalidChecksum = errors.New("invalid checksum")
// ErrInvalidFormat indicates an invalid format of serialized data
ErrInvalidFormat = errors.New("invalid entry format")
// ErrBufferTooSmall indicates the provided buffer is too small for serialization
ErrBufferTooSmall = errors.New("buffer too small")
)
const (
// Entry serialization constants
entryHeaderSize = 17 // checksum(4) + timestamp(8) + type(1) + keylen(4)
// Additional 4 bytes for value length when not a delete operation
)
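// For reference, the on-wire layout produced by SerializeEntry below is:
//
//	checksum  uint32, little-endian, computed over everything that follows
//	timestamp uint64, the entry's Lamport sequence number
//	type      byte (wal.OpTypePut, wal.OpTypeDelete, ...)
//	keylen    uint32
//	key       keylen bytes
//	vallen    uint32 (omitted entirely when Value is nil, e.g. deletes)
//	value     vallen bytes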
// EntrySerializer handles serialization and deserialization of WAL entries
type EntrySerializer struct {
// ChecksumEnabled controls whether checksums are calculated/verified
ChecksumEnabled bool
}
// NewEntrySerializer creates a new entry serializer
func NewEntrySerializer() *EntrySerializer {
return &EntrySerializer{
ChecksumEnabled: true,
}
}
// SerializeEntry converts a WAL entry to a byte slice
func (s *EntrySerializer) SerializeEntry(entry *wal.Entry) []byte {
// Calculate total size needed
totalSize := entryHeaderSize + len(entry.Key)
if entry.Value != nil {
totalSize += 4 + len(entry.Value) // vallen(4) + value
}
// Allocate buffer
data := make([]byte, totalSize)
offset := 4 // Skip first 4 bytes for checksum
// Write timestamp
binary.LittleEndian.PutUint64(data[offset:offset+8], entry.SequenceNumber)
offset += 8
// Write entry type
data[offset] = entry.Type
offset++
// Write key length and key
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(entry.Key)))
offset += 4
copy(data[offset:], entry.Key)
offset += len(entry.Key)
// Write value length and value (if present)
if entry.Value != nil {
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(entry.Value)))
offset += 4
copy(data[offset:], entry.Value)
}
// Calculate and store checksum if enabled
if s.ChecksumEnabled {
checksum := crc32.ChecksumIEEE(data[4:])
binary.LittleEndian.PutUint32(data[0:4], checksum)
}
return data
}
// SerializeEntryToBuffer serializes a WAL entry to an existing buffer
// Returns the number of bytes written or an error if the buffer is too small
func (s *EntrySerializer) SerializeEntryToBuffer(entry *wal.Entry, buffer []byte) (int, error) {
// Calculate total size needed
totalSize := entryHeaderSize + len(entry.Key)
if entry.Value != nil {
totalSize += 4 + len(entry.Value) // vallen(4) + value
}
// Check if buffer is large enough
if len(buffer) < totalSize {
return 0, ErrBufferTooSmall
}
// Write to buffer
offset := 4 // Skip first 4 bytes for checksum
// Write timestamp
binary.LittleEndian.PutUint64(buffer[offset:offset+8], entry.SequenceNumber)
offset += 8
// Write entry type
buffer[offset] = entry.Type
offset++
// Write key length and key
binary.LittleEndian.PutUint32(buffer[offset:offset+4], uint32(len(entry.Key)))
offset += 4
copy(buffer[offset:], entry.Key)
offset += len(entry.Key)
// Write value length and value (if present)
if entry.Value != nil {
binary.LittleEndian.PutUint32(buffer[offset:offset+4], uint32(len(entry.Value)))
offset += 4
copy(buffer[offset:], entry.Value)
offset += len(entry.Value)
}
// Calculate and store checksum if enabled
if s.ChecksumEnabled {
checksum := crc32.ChecksumIEEE(buffer[4:offset])
binary.LittleEndian.PutUint32(buffer[0:4], checksum)
}
return offset, nil
}
// DeserializeEntry converts a byte slice back to a WAL entry
func (s *EntrySerializer) DeserializeEntry(data []byte) (*wal.Entry, error) {
// Validate minimum size
if len(data) < entryHeaderSize {
return nil, ErrInvalidFormat
}
// Verify checksum if enabled
if s.ChecksumEnabled {
storedChecksum := binary.LittleEndian.Uint32(data[0:4])
calculatedChecksum := crc32.ChecksumIEEE(data[4:])
if storedChecksum != calculatedChecksum {
return nil, ErrInvalidChecksum
}
}
offset := 4 // Skip checksum
// Read timestamp
timestamp := binary.LittleEndian.Uint64(data[offset : offset+8])
offset += 8
// Read entry type
entryType := data[offset]
offset++
// Read key length and key
keyLen := binary.LittleEndian.Uint32(data[offset : offset+4])
offset += 4
// Validate key length
if offset+int(keyLen) > len(data) {
return nil, ErrInvalidFormat
}
key := make([]byte, keyLen)
copy(key, data[offset:offset+int(keyLen)])
offset += int(keyLen)
// Read value length and value if present
var value []byte
if offset < len(data) {
// Only read value if there's more data
if offset+4 > len(data) {
return nil, ErrInvalidFormat
}
valueLen := binary.LittleEndian.Uint32(data[offset : offset+4])
offset += 4
// Validate value length
if offset+int(valueLen) > len(data) {
return nil, ErrInvalidFormat
}
value = make([]byte, valueLen)
copy(value, data[offset:offset+int(valueLen)])
}
// Create and return the entry
return &wal.Entry{
SequenceNumber: timestamp,
Type: entryType,
Key: key,
Value: value,
}, nil
}
// BatchSerializer handles serialization of WAL entry batches
type BatchSerializer struct {
entrySerializer *EntrySerializer
}
// NewBatchSerializer creates a new batch serializer
func NewBatchSerializer() *BatchSerializer {
return &BatchSerializer{
entrySerializer: NewEntrySerializer(),
}
}
// SerializeBatch converts a batch of WAL entries to a byte slice
func (s *BatchSerializer) SerializeBatch(entries []*wal.Entry) []byte {
if len(entries) == 0 {
// Empty batch - just return header with count 0
result := make([]byte, 12) // checksum(4) + count(4) + timestamp(4)
binary.LittleEndian.PutUint32(result[4:8], 0)
// Calculate and store checksum
checksum := crc32.ChecksumIEEE(result[4:])
binary.LittleEndian.PutUint32(result[0:4], checksum)
return result
}
// First pass: calculate total size needed
var totalSize int = 12 // header: checksum(4) + count(4) + base timestamp(4)
for _, entry := range entries {
// For each entry: size(4) + serialized entry data
entrySize := entryHeaderSize + len(entry.Key)
if entry.Value != nil {
entrySize += 4 + len(entry.Value)
}
totalSize += 4 + entrySize
}
// Allocate buffer
result := make([]byte, totalSize)
offset := 4 // Skip checksum for now
// Write entry count
binary.LittleEndian.PutUint32(result[offset:offset+4], uint32(len(entries)))
offset += 4
// Write base timestamp from the first entry (truncated to uint32; informational only and skipped on deserialization)
binary.LittleEndian.PutUint32(result[offset:offset+4], uint32(entries[0].SequenceNumber))
offset += 4
// Write each entry
for _, entry := range entries {
// Reserve space for entry size
sizeOffset := offset
offset += 4
// Serialize entry directly into the buffer
entrySize, err := s.entrySerializer.SerializeEntryToBuffer(entry, result[offset:])
if err != nil {
// This shouldn't happen since we pre-calculated the size;
// panic to surface the bug immediately if it ever does
panic("buffer too small for entry serialization")
}
offset += entrySize
// Write the actual entry size
binary.LittleEndian.PutUint32(result[sizeOffset:sizeOffset+4], uint32(entrySize))
}
// Calculate and store checksum
checksum := crc32.ChecksumIEEE(result[4:offset])
binary.LittleEndian.PutUint32(result[0:4], checksum)
return result
}
// DeserializeBatch converts a byte slice back to a batch of WAL entries
func (s *BatchSerializer) DeserializeBatch(data []byte) ([]*wal.Entry, error) {
// Validate minimum size for batch header
if len(data) < 12 {
return nil, ErrInvalidFormat
}
// Verify checksum
storedChecksum := binary.LittleEndian.Uint32(data[0:4])
calculatedChecksum := crc32.ChecksumIEEE(data[4:])
if storedChecksum != calculatedChecksum {
return nil, ErrInvalidChecksum
}
offset := 4 // Skip checksum
// Read entry count
count := binary.LittleEndian.Uint32(data[offset:offset+4])
offset += 4
// Read base timestamp (we don't use this currently, but read past it)
offset += 4 // Skip base timestamp
// Early return for empty batch
if count == 0 {
return []*wal.Entry{}, nil
}
// Deserialize each entry
entries := make([]*wal.Entry, count)
for i := uint32(0); i < count; i++ {
// Validate we have enough data for entry size
if offset+4 > len(data) {
return nil, ErrInvalidFormat
}
// Read entry size
entrySize := binary.LittleEndian.Uint32(data[offset:offset+4])
offset += 4
// Validate entry size
if offset+int(entrySize) > len(data) {
return nil, ErrInvalidFormat
}
// Deserialize entry
entry, err := s.entrySerializer.DeserializeEntry(data[offset:offset+int(entrySize)])
if err != nil {
return nil, err
}
entries[i] = entry
offset += int(entrySize)
}
return entries, nil
}
// EstimateEntrySize estimates the serialized size of a WAL entry without actually serializing it
func EstimateEntrySize(entry *wal.Entry) int {
size := entryHeaderSize + len(entry.Key)
if entry.Value != nil {
size += 4 + len(entry.Value)
}
return size
}
// EstimateBatchSize estimates the serialized size of a batch of WAL entries
func EstimateBatchSize(entries []*wal.Entry) int {
if len(entries) == 0 {
return 12 // Empty batch header
}
size := 12 // Batch header: checksum(4) + count(4) + base timestamp(4)
for _, entry := range entries {
entrySize := EstimateEntrySize(entry)
size += 4 + entrySize // size field(4) + entry data
}
return size
}
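
A short, self-contained round-trip sketch for the serializers above; the keys and values are placeholders.

func exampleSerializationRoundTrip() error {
	s := NewEntrySerializer()
	entry := &wal.Entry{
		SequenceNumber: 42,
		Type:           wal.OpTypePut,
		Key:            []byte("user:1"),
		Value:          []byte("alice"),
	}

	// Single-entry round trip; corruption surfaces as ErrInvalidChecksum.
	data := s.SerializeEntry(entry)
	if _, err := s.DeserializeEntry(data); err != nil {
		return err
	}

	// Batch round trip: a 12-byte header followed by size-prefixed entries.
	bs := NewBatchSerializer()
	batch := bs.SerializeBatch([]*wal.Entry{entry})
	_, err := bs.DeserializeBatch(batch)
	return err
}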

View File

@ -0,0 +1,420 @@
package replication
import (
"bytes"
"encoding/binary"
"hash/crc32"
"testing"
"github.com/KevoDB/kevo/pkg/wal"
)
func TestEntrySerializer(t *testing.T) {
// Create a serializer
serializer := NewEntrySerializer()
// Test different entry types
testCases := []struct {
name string
entry *wal.Entry
}{
{
name: "Put operation",
entry: &wal.Entry{
SequenceNumber: 123,
Type: wal.OpTypePut,
Key: []byte("test-key"),
Value: []byte("test-value"),
},
},
{
name: "Delete operation",
entry: &wal.Entry{
SequenceNumber: 456,
Type: wal.OpTypeDelete,
Key: []byte("deleted-key"),
Value: nil,
},
},
{
name: "Large entry",
entry: &wal.Entry{
SequenceNumber: 789,
Type: wal.OpTypePut,
Key: bytes.Repeat([]byte("K"), 1000),
Value: bytes.Repeat([]byte("V"), 1000),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Serialize the entry
data := serializer.SerializeEntry(tc.entry)
// Deserialize back
result, err := serializer.DeserializeEntry(data)
if err != nil {
t.Fatalf("Error deserializing entry: %v", err)
}
// Compare entries
if result.SequenceNumber != tc.entry.SequenceNumber {
t.Errorf("Expected sequence number %d, got %d",
tc.entry.SequenceNumber, result.SequenceNumber)
}
if result.Type != tc.entry.Type {
t.Errorf("Expected type %d, got %d", tc.entry.Type, result.Type)
}
if !bytes.Equal(result.Key, tc.entry.Key) {
t.Errorf("Expected key %q, got %q", tc.entry.Key, result.Key)
}
if !bytes.Equal(result.Value, tc.entry.Value) {
t.Errorf("Expected value %q, got %q", tc.entry.Value, result.Value)
}
})
}
}
func TestEntrySerializerChecksum(t *testing.T) {
// Create a serializer with checksums enabled
serializer := NewEntrySerializer()
serializer.ChecksumEnabled = true
// Create a test entry
entry := &wal.Entry{
SequenceNumber: 123,
Type: wal.OpTypePut,
Key: []byte("test-key"),
Value: []byte("test-value"),
}
// Serialize the entry
data := serializer.SerializeEntry(entry)
// Corrupt the data
data[10]++
// Try to deserialize - should fail with checksum error
_, err := serializer.DeserializeEntry(data)
if err != ErrInvalidChecksum {
t.Errorf("Expected checksum error, got %v", err)
}
// Now disable checksum verification and try again
serializer.ChecksumEnabled = false
result, err := serializer.DeserializeEntry(data)
if err != nil {
t.Errorf("Expected no error with checksums disabled, got %v", err)
}
if result == nil {
t.Fatal("Expected entry to be returned with checksums disabled")
}
}
func TestEntrySerializerInvalidFormat(t *testing.T) {
serializer := NewEntrySerializer()
// Test with empty data
_, err := serializer.DeserializeEntry([]byte{})
if err != ErrInvalidFormat {
t.Errorf("Expected format error for empty data, got %v", err)
}
// Test with insufficient data
_, err = serializer.DeserializeEntry(make([]byte, 10))
if err != ErrInvalidFormat {
t.Errorf("Expected format error for insufficient data, got %v", err)
}
// Test with invalid key length
data := make([]byte, entryHeaderSize+4)
offset := 4
binary.LittleEndian.PutUint64(data[offset:offset+8], 123) // timestamp
offset += 8
data[offset] = wal.OpTypePut // type
offset++
binary.LittleEndian.PutUint32(data[offset:offset+4], 1000) // key length (too large)
// Calculate a valid checksum for this data
checksum := crc32.ChecksumIEEE(data[4:])
binary.LittleEndian.PutUint32(data[0:4], checksum)
_, err = serializer.DeserializeEntry(data)
if err != ErrInvalidFormat {
t.Errorf("Expected format error for invalid key length, got %v", err)
}
}
func TestBatchSerializer(t *testing.T) {
// Create batch serializer
serializer := NewBatchSerializer()
// Test batch with multiple entries
entries := []*wal.Entry{
{
SequenceNumber: 101,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
{
SequenceNumber: 102,
Type: wal.OpTypeDelete,
Key: []byte("key2"),
Value: nil,
},
{
SequenceNumber: 103,
Type: wal.OpTypePut,
Key: []byte("key3"),
Value: []byte("value3"),
},
}
// Serialize batch
data := serializer.SerializeBatch(entries)
// Deserialize batch
result, err := serializer.DeserializeBatch(data)
if err != nil {
t.Fatalf("Error deserializing batch: %v", err)
}
// Verify batch
if len(result) != len(entries) {
t.Fatalf("Expected %d entries, got %d", len(entries), len(result))
}
for i, entry := range entries {
if result[i].SequenceNumber != entry.SequenceNumber {
t.Errorf("Entry %d: Expected sequence number %d, got %d",
i, entry.SequenceNumber, result[i].SequenceNumber)
}
if result[i].Type != entry.Type {
t.Errorf("Entry %d: Expected type %d, got %d",
i, entry.Type, result[i].Type)
}
if !bytes.Equal(result[i].Key, entry.Key) {
t.Errorf("Entry %d: Expected key %q, got %q",
i, entry.Key, result[i].Key)
}
if !bytes.Equal(result[i].Value, entry.Value) {
t.Errorf("Entry %d: Expected value %q, got %q",
i, entry.Value, result[i].Value)
}
}
}
func TestEmptyBatchSerialization(t *testing.T) {
// Create batch serializer
serializer := NewBatchSerializer()
// Test empty batch
entries := []*wal.Entry{}
// Serialize batch
data := serializer.SerializeBatch(entries)
// Deserialize batch
result, err := serializer.DeserializeBatch(data)
if err != nil {
t.Fatalf("Error deserializing empty batch: %v", err)
}
// Verify result is empty
if len(result) != 0 {
t.Errorf("Expected empty batch, got %d entries", len(result))
}
}
func TestBatchSerializerChecksum(t *testing.T) {
// Create batch serializer
serializer := NewBatchSerializer()
// Test batch with single entry
entries := []*wal.Entry{
{
SequenceNumber: 101,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
}
// Serialize batch
data := serializer.SerializeBatch(entries)
// Corrupt data
data[8]++
// Attempt to deserialize - should fail
_, err := serializer.DeserializeBatch(data)
if err != ErrInvalidChecksum {
t.Errorf("Expected checksum error for corrupted batch, got %v", err)
}
}
func TestBatchSerializerInvalidFormat(t *testing.T) {
serializer := NewBatchSerializer()
// Test with empty data
_, err := serializer.DeserializeBatch([]byte{})
if err != ErrInvalidFormat {
t.Errorf("Expected format error for empty data, got %v", err)
}
// Test with insufficient data
_, err = serializer.DeserializeBatch(make([]byte, 10))
if err != ErrInvalidFormat {
t.Errorf("Expected format error for insufficient data, got %v", err)
}
}
func TestEstimateEntrySize(t *testing.T) {
// Test entries with different sizes
testCases := []struct {
name string
entry *wal.Entry
expected int
}{
{
name: "Basic put entry",
entry: &wal.Entry{
SequenceNumber: 101,
Type: wal.OpTypePut,
Key: []byte("key"),
Value: []byte("value"),
},
expected: entryHeaderSize + 3 + 4 + 5, // header + key bytes + value length field + value bytes
},
{
name: "Delete entry (no value)",
entry: &wal.Entry{
SequenceNumber: 102,
Type: wal.OpTypeDelete,
Key: []byte("delete-key"),
Value: nil,
},
expected: entryHeaderSize + 10, // header + key bytes (no value length field for deletes)
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
size := EstimateEntrySize(tc.entry)
if size != tc.expected {
t.Errorf("Expected size %d, got %d", tc.expected, size)
}
// Verify estimate matches actual size
serializer := NewEntrySerializer()
data := serializer.SerializeEntry(tc.entry)
if len(data) != size {
t.Errorf("Estimated size %d doesn't match actual size %d",
size, len(data))
}
})
}
}
func TestEstimateBatchSize(t *testing.T) {
// Test batches with different contents
testCases := []struct {
name string
entries []*wal.Entry
expected int
}{
{
name: "Empty batch",
entries: []*wal.Entry{},
expected: 12, // Just batch header
},
{
name: "Batch with one entry",
entries: []*wal.Entry{
{
SequenceNumber: 101,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
},
},
expected: 12 + 4 + entryHeaderSize + 4 + 4 + 6, // batch header + entry size field + entry header + key + value size + value
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
size := EstimateBatchSize(tc.entries)
if size != tc.expected {
t.Errorf("Expected size %d, got %d", tc.expected, size)
}
// Verify estimate matches actual size
serializer := NewBatchSerializer()
data := serializer.SerializeBatch(tc.entries)
if len(data) != size {
t.Errorf("Estimated size %d doesn't match actual size %d",
size, len(data))
}
})
}
}
func TestSerializeToBuffer(t *testing.T) {
serializer := NewEntrySerializer()
// Create a test entry
entry := &wal.Entry{
SequenceNumber: 101,
Type: wal.OpTypePut,
Key: []byte("key1"),
Value: []byte("value1"),
}
// Estimate the size
estimatedSize := EstimateEntrySize(entry)
// Create a buffer of the estimated size
buffer := make([]byte, estimatedSize)
// Serialize to buffer
n, err := serializer.SerializeEntryToBuffer(entry, buffer)
if err != nil {
t.Fatalf("Error serializing to buffer: %v", err)
}
// Check bytes written
if n != estimatedSize {
t.Errorf("Expected %d bytes written, got %d", estimatedSize, n)
}
// Verify by deserializing
result, err := serializer.DeserializeEntry(buffer)
if err != nil {
t.Fatalf("Error deserializing from buffer: %v", err)
}
// Check result
if result.SequenceNumber != entry.SequenceNumber {
t.Errorf("Expected sequence number %d, got %d",
entry.SequenceNumber, result.SequenceNumber)
}
if !bytes.Equal(result.Key, entry.Key) {
t.Errorf("Expected key %q, got %q", entry.Key, result.Key)
}
if !bytes.Equal(result.Value, entry.Value) {
t.Errorf("Expected value %q, got %q", entry.Value, result.Value)
}
// Test with too small buffer
smallBuffer := make([]byte, estimatedSize - 1)
_, err = serializer.SerializeEntryToBuffer(entry, smallBuffer)
if err != ErrBufferTooSmall {
t.Errorf("Expected buffer too small error, got %v", err)
}
}

View File

@ -0,0 +1,87 @@
package replication
import (
"io"
)
// StorageSnapshot provides an interface for taking a snapshot of the storage
// for replication bootstrap purposes.
type StorageSnapshot interface {
// CreateSnapshotIterator creates an iterator for a storage snapshot
CreateSnapshotIterator() (SnapshotIterator, error)
// KeyCount returns the approximate number of keys in storage
KeyCount() int64
}
// SnapshotIterator provides iteration over key-value pairs in storage
type SnapshotIterator interface {
// Next returns the next key-value pair
// Returns io.EOF when there are no more items
Next() (key []byte, value []byte, err error)
// Close closes the iterator
Close() error
}
// StorageSnapshotProvider is implemented by storage engines that support snapshots
type StorageSnapshotProvider interface {
// CreateSnapshot creates a snapshot of the current storage state
CreateSnapshot() (StorageSnapshot, error)
}
// MemoryStorageSnapshot is a simple in-memory implementation of StorageSnapshot
// Useful for testing or small datasets
type MemoryStorageSnapshot struct {
Pairs []KeyValuePair
position int
}
// KeyValuePair represents a key-value pair in storage
type KeyValuePair struct {
Key []byte
Value []byte
}
// CreateSnapshotIterator creates an iterator for a memory storage snapshot
func (m *MemoryStorageSnapshot) CreateSnapshotIterator() (SnapshotIterator, error) {
return &MemorySnapshotIterator{
snapshot: m,
position: 0,
}, nil
}
// KeyCount returns the number of keys in the snapshot
func (m *MemoryStorageSnapshot) KeyCount() int64 {
return int64(len(m.Pairs))
}
// MemorySnapshotIterator is an iterator for MemoryStorageSnapshot
type MemorySnapshotIterator struct {
snapshot *MemoryStorageSnapshot
position int
}
// Next returns the next key-value pair
func (it *MemorySnapshotIterator) Next() ([]byte, []byte, error) {
if it.position >= len(it.snapshot.Pairs) {
return nil, nil, io.EOF
}
pair := it.snapshot.Pairs[it.position]
it.position++
return pair.Key, pair.Value, nil
}
// Close closes the iterator
func (it *MemorySnapshotIterator) Close() error {
return nil
}
// NewMemoryStorageSnapshot creates a new in-memory storage snapshot
func NewMemoryStorageSnapshot(pairs []KeyValuePair) *MemoryStorageSnapshot {
return &MemoryStorageSnapshot{
Pairs: pairs,
}
}
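
A small sketch of draining a snapshot during replica bootstrap using the in-memory implementation above; the key-value pairs are placeholders.

func exampleSnapshotDrain() error {
	snap := NewMemoryStorageSnapshot([]KeyValuePair{
		{Key: []byte("a"), Value: []byte("1")},
		{Key: []byte("b"), Value: []byte("2")},
	})

	it, err := snap.CreateSnapshotIterator()
	if err != nil {
		return err
	}
	defer it.Close()

	for {
		key, value, err := it.Next()
		if err == io.EOF {
			return nil // snapshot fully consumed
		}
		if err != nil {
			return err
		}
		_, _ = key, value // apply the pair to local storage on the replica
	}
}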

View File

@ -138,12 +138,24 @@ type Registry interface {
// RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory)
// RegisterReplicationClient adds a new replication client implementation
RegisterReplicationClient(name string, factory ReplicationClientFactory)
// RegisterReplicationServer adds a new replication server implementation
RegisterReplicationServer(name string, factory ReplicationServerFactory)
// CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
// CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error)
// CreateReplicationClient instantiates a replication client by name
CreateReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error)
// CreateReplicationServer instantiates a replication server by name
CreateReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error)
// ListTransports returns all available transport names
ListTransports() []string
}
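
As a rough sketch of how a transport implementation might plug into the extended Registry; the "grpc" name and the empty factory body are placeholders, not part of this change.

func exampleRegisterReplicationTransport(reg Registry) (ReplicationClient, error) {
	// A transport package would typically do this from an init function.
	reg.RegisterReplicationClient("grpc", func(endpoint string, options TransportOptions) (ReplicationClient, error) {
		// A real factory would dial the endpoint and return a connected client.
		return nil, nil
	})

	return reg.CreateReplicationClient("grpc", "localhost:50051", TransportOptions{})
}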

View File

@ -7,9 +7,11 @@ import (
// registry implements the Registry interface
type registry struct {
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
replicationClientFactories map[string]ReplicationClientFactory
replicationServerFactories map[string]ReplicationServerFactory
}
// NewRegistry creates a new transport registry
@ -17,6 +19,8 @@ func NewRegistry() Registry {
return &registry{
clientFactories: make(map[string]ClientFactory),
serverFactories: make(map[string]ServerFactory),
replicationClientFactories: make(map[string]ReplicationClientFactory),
replicationServerFactories: make(map[string]ReplicationServerFactory),
}
}
@ -63,6 +67,46 @@ func (r *registry) CreateServer(name, address string, options TransportOptions)
return factory(address, options)
}
// RegisterReplicationClient adds a new replication client implementation to the registry
func (r *registry) RegisterReplicationClient(name string, factory ReplicationClientFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.replicationClientFactories[name] = factory
}
// RegisterReplicationServer adds a new replication server implementation to the registry
func (r *registry) RegisterReplicationServer(name string, factory ReplicationServerFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.replicationServerFactories[name] = factory
}
// CreateReplicationClient instantiates a replication client by name
func (r *registry) CreateReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error) {
r.mu.RLock()
factory, exists := r.replicationClientFactories[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("replication client %q not registered", name)
}
return factory(endpoint, options)
}
// CreateReplicationServer instantiates a replication server by name
func (r *registry) CreateReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error) {
r.mu.RLock()
factory, exists := r.replicationServerFactories[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("replication server %q not registered", name)
}
return factory(address, options)
}
// ListTransports returns all available transport names
func (r *registry) ListTransports() []string {
r.mu.RLock()
@ -76,6 +120,12 @@ func (r *registry) ListTransports() []string {
for name := range r.serverFactories {
names[name] = struct{}{}
}
for name := range r.replicationClientFactories {
names[name] = struct{}{}
}
for name := range r.replicationServerFactories {
names[name] = struct{}{}
}
// Convert to slice
result := make([]string, 0, len(names))
@ -98,6 +148,16 @@ func RegisterServerTransport(name string, factory ServerFactory) {
DefaultRegistry.RegisterServer(name, factory)
}
// RegisterReplicationClientTransport registers a replication client with the default registry
func RegisterReplicationClientTransport(name string, factory ReplicationClientFactory) {
DefaultRegistry.RegisterReplicationClient(name, factory)
}
// RegisterReplicationServerTransport registers a replication server with the default registry
func RegisterReplicationServerTransport(name string, factory ReplicationServerFactory) {
DefaultRegistry.RegisterReplicationServer(name, factory)
}
// GetClient creates a client using the default registry
func GetClient(name, endpoint string, options TransportOptions) (Client, error) {
return DefaultRegistry.CreateClient(name, endpoint, options)
@ -108,6 +168,16 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
return DefaultRegistry.CreateServer(name, address, options)
}
// GetReplicationClient creates a replication client using the default registry
func GetReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error) {
return DefaultRegistry.CreateReplicationClient(name, endpoint, options)
}
// GetReplicationServer creates a replication server using the default registry
func GetReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error) {
return DefaultRegistry.CreateReplicationServer(name, address, options)
}
// AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string {
return DefaultRegistry.ListTransports()

View File

@ -0,0 +1,183 @@
package transport
import (
"context"
"time"
"github.com/KevoDB/kevo/pkg/wal"
)
// Standard constants for replication message types
const (
// Request types
TypeReplicaRegister = "replica_register"
TypeReplicaHeartbeat = "replica_heartbeat"
TypeReplicaWALSync = "replica_wal_sync"
TypeReplicaBootstrap = "replica_bootstrap"
TypeReplicaStatus = "replica_status"
// Response types
TypeReplicaACK = "replica_ack"
TypeReplicaWALEntries = "replica_wal_entries"
TypeReplicaBootstrapData = "replica_bootstrap_data"
TypeReplicaStatusData = "replica_status_data"
)
// ReplicaRole defines the role of a node in replication
type ReplicaRole string
// Replica roles
const (
RolePrimary ReplicaRole = "primary"
RoleReplica ReplicaRole = "replica"
RoleReadOnly ReplicaRole = "readonly"
)
// ReplicaStatus defines the current status of a replica
type ReplicaStatus string
// Replica statuses
const (
StatusConnecting ReplicaStatus = "connecting"
StatusSyncing ReplicaStatus = "syncing"
StatusBootstrapping ReplicaStatus = "bootstrapping"
StatusReady ReplicaStatus = "ready"
StatusDisconnected ReplicaStatus = "disconnected"
StatusError ReplicaStatus = "error"
)
// ReplicaInfo contains information about a replica
type ReplicaInfo struct {
ID string
Address string
Role ReplicaRole
Status ReplicaStatus
LastSeen time.Time
CurrentLSN uint64 // Lamport Sequence Number
ReplicationLag time.Duration
Error error
}
// ReplicationStreamDirection defines the direction of a replication stream
type ReplicationStreamDirection int
const (
DirectionPrimaryToReplica ReplicationStreamDirection = iota
DirectionReplicaToPrimary
DirectionBidirectional
)
// ReplicationConnection provides methods specific to replication connections
type ReplicationConnection interface {
Connection
// GetReplicaInfo returns information about the remote replica
GetReplicaInfo() (*ReplicaInfo, error)
// SendWALEntries sends WAL entries to the replica
SendWALEntries(ctx context.Context, entries []*wal.Entry) error
// ReceiveWALEntries receives WAL entries from the replica
ReceiveWALEntries(ctx context.Context) ([]*wal.Entry, error)
// StartReplicationStream starts a stream for WAL entries
StartReplicationStream(ctx context.Context, direction ReplicationStreamDirection) (ReplicationStream, error)
}
// ReplicationStream provides a bidirectional stream of WAL entries
type ReplicationStream interface {
// SendEntries sends WAL entries through the stream
SendEntries(entries []*wal.Entry) error
// ReceiveEntries receives WAL entries from the stream
ReceiveEntries() ([]*wal.Entry, error)
// Close closes the replication stream
Close() error
// SetHighWatermark updates the highest applied Lamport sequence number
SetHighWatermark(lsn uint64) error
// GetHighWatermark returns the highest applied Lamport sequence number
GetHighWatermark() (uint64, error)
}
// ReplicationClient extends the Client interface with replication-specific methods
type ReplicationClient interface {
Client
// RegisterAsReplica registers this client as a replica with the primary
RegisterAsReplica(ctx context.Context, replicaID string) error
// SendHeartbeat sends a heartbeat to the primary
SendHeartbeat(ctx context.Context, status *ReplicaInfo) error
// RequestWALEntries requests WAL entries from the primary starting from a specific LSN
RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error)
// RequestBootstrap requests a snapshot for bootstrap purposes
RequestBootstrap(ctx context.Context) (BootstrapIterator, error)
}
// ReplicationServer extends the Server interface with replication-specific methods
type ReplicationServer interface {
Server
// RegisterReplica registers a new replica
RegisterReplica(replicaInfo *ReplicaInfo) error
// UpdateReplicaStatus updates the status of a replica
UpdateReplicaStatus(replicaID string, status ReplicaStatus, lsn uint64) error
// GetReplicaInfo returns information about a specific replica
GetReplicaInfo(replicaID string) (*ReplicaInfo, error)
// ListReplicas returns information about all connected replicas
ListReplicas() ([]*ReplicaInfo, error)
// StreamWALEntriesToReplica streams WAL entries to a specific replica
StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error
}
// BootstrapIterator provides an iterator over key-value pairs for bootstrapping a replica
type BootstrapIterator interface {
// Next returns the next key-value pair
Next() (key []byte, value []byte, err error)
// Close closes the iterator
Close() error
// Progress returns the progress of the bootstrap operation (0.0-1.0)
Progress() float64
}
// ReplicationRequestHandler processes replication-specific requests
type ReplicationRequestHandler interface {
// HandleReplicaRegister handles replica registration requests
HandleReplicaRegister(ctx context.Context, replicaID string, address string) error
// HandleReplicaHeartbeat handles heartbeat requests
HandleReplicaHeartbeat(ctx context.Context, status *ReplicaInfo) error
// HandleWALRequest handles requests for WAL entries
HandleWALRequest(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error)
// HandleBootstrapRequest handles bootstrap requests
HandleBootstrapRequest(ctx context.Context) (BootstrapIterator, error)
}
// ReplicationClientFactory creates a new replication client
type ReplicationClientFactory func(endpoint string, options TransportOptions) (ReplicationClient, error)
// ReplicationServerFactory creates a new replication server
type ReplicationServerFactory func(address string, options TransportOptions) (ReplicationServer, error)
// RegisterReplicationClient registers a replication client implementation
func RegisterReplicationClient(name string, factory ReplicationClientFactory) {
// This would be implemented to register with the transport registry
}
// RegisterReplicationServer registers a replication server implementation
func RegisterReplicationServer(name string, factory ReplicationServerFactory) {
// This would be implemented to register with the transport registry
}
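
A rough sketch of a replica-side sync loop against the interfaces above; the replica ID and the one-second polling interval are illustrative, and a real replica would prefer the streaming interfaces over polling.

func exampleReplicaSyncLoop(ctx context.Context, client ReplicationClient, applied uint64) error {
	if err := client.Connect(ctx); err != nil {
		return err
	}
	defer client.Close()

	if err := client.RegisterAsReplica(ctx, "replica-1"); err != nil {
		return err
	}

	for {
		entries, err := client.RequestWALEntries(ctx, applied)
		if err != nil {
			return err
		}
		for _, e := range entries {
			// Apply each entry locally (e.g. via the replication WAL applier),
			// then advance the local high watermark.
			if e.SequenceNumber > applied {
				applied = e.SequenceNumber
			}
		}

		if err := client.SendHeartbeat(ctx, &ReplicaInfo{
			ID:         "replica-1",
			Role:       RoleReplica,
			Status:     StatusReady,
			LastSeen:   time.Now(),
			CurrentLSN: applied,
		}); err != nil {
			return err
		}

		time.Sleep(time.Second) // crude poll interval for the sketch
	}
}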

View File

@ -0,0 +1,401 @@
package transport
import (
"context"
"testing"
"time"
"github.com/KevoDB/kevo/pkg/wal"
)
// MockReplicationClient implements ReplicationClient for testing
type MockReplicationClient struct {
connected bool
registeredAsReplica bool
heartbeatSent bool
walEntriesRequested bool
bootstrapRequested bool
replicaID string
walEntries []*wal.Entry
bootstrapIterator BootstrapIterator
status TransportStatus
}
func NewMockReplicationClient() *MockReplicationClient {
return &MockReplicationClient{
connected: false,
registeredAsReplica: false,
heartbeatSent: false,
walEntriesRequested: false,
bootstrapRequested: false,
status: TransportStatus{
Connected: false,
LastConnected: time.Time{},
LastError: nil,
BytesSent: 0,
BytesReceived: 0,
RTT: 0,
},
}
}
func (c *MockReplicationClient) Connect(ctx context.Context) error {
c.connected = true
c.status.Connected = true
c.status.LastConnected = time.Now()
return nil
}
func (c *MockReplicationClient) Close() error {
c.connected = false
c.status.Connected = false
return nil
}
func (c *MockReplicationClient) IsConnected() bool {
return c.connected
}
func (c *MockReplicationClient) Status() TransportStatus {
return c.status
}
func (c *MockReplicationClient) Send(ctx context.Context, request Request) (Response, error) {
return nil, ErrInvalidRequest
}
func (c *MockReplicationClient) Stream(ctx context.Context) (Stream, error) {
return nil, ErrInvalidRequest
}
func (c *MockReplicationClient) RegisterAsReplica(ctx context.Context, replicaID string) error {
c.registeredAsReplica = true
c.replicaID = replicaID
return nil
}
func (c *MockReplicationClient) SendHeartbeat(ctx context.Context, status *ReplicaInfo) error {
c.heartbeatSent = true
return nil
}
func (c *MockReplicationClient) RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error) {
c.walEntriesRequested = true
return c.walEntries, nil
}
func (c *MockReplicationClient) RequestBootstrap(ctx context.Context) (BootstrapIterator, error) {
c.bootstrapRequested = true
return c.bootstrapIterator, nil
}
// MockBootstrapIterator implements BootstrapIterator for testing
type MockBootstrapIterator struct {
pairs []struct{ key, value []byte }
position int
progress float64
closed bool
}
func NewMockBootstrapIterator() *MockBootstrapIterator {
return &MockBootstrapIterator{
pairs: []struct{ key, value []byte }{
{[]byte("key1"), []byte("value1")},
{[]byte("key2"), []byte("value2")},
{[]byte("key3"), []byte("value3")},
},
position: 0,
progress: 0.0,
closed: false,
}
}
func (it *MockBootstrapIterator) Next() ([]byte, []byte, error) {
if it.position >= len(it.pairs) {
return nil, nil, nil
}
pair := it.pairs[it.position]
it.position++
it.progress = float64(it.position) / float64(len(it.pairs))
return pair.key, pair.value, nil
}
func (it *MockBootstrapIterator) Close() error {
it.closed = true
return nil
}
func (it *MockBootstrapIterator) Progress() float64 {
return it.progress
}
// Tests
func TestReplicationClientInterface(t *testing.T) {
// Create a mock client
client := NewMockReplicationClient()
// Test Connect
ctx := context.Background()
err := client.Connect(ctx)
if err != nil {
t.Errorf("Connect failed: %v", err)
}
// Test IsConnected
if !client.IsConnected() {
t.Errorf("Expected client to be connected")
}
// Test Status
status := client.Status()
if !status.Connected {
t.Errorf("Expected status.Connected to be true")
}
// Test RegisterAsReplica
err = client.RegisterAsReplica(ctx, "replica1")
if err != nil {
t.Errorf("RegisterAsReplica failed: %v", err)
}
if !client.registeredAsReplica {
t.Errorf("Expected client to be registered as replica")
}
if client.replicaID != "replica1" {
t.Errorf("Expected replicaID to be 'replica1', got '%s'", client.replicaID)
}
// Test SendHeartbeat
replicaInfo := &ReplicaInfo{
ID: "replica1",
Address: "localhost:50051",
Role: RoleReplica,
Status: StatusReady,
LastSeen: time.Now(),
CurrentLSN: 100,
ReplicationLag: 0,
}
err = client.SendHeartbeat(ctx, replicaInfo)
if err != nil {
t.Errorf("SendHeartbeat failed: %v", err)
}
if !client.heartbeatSent {
t.Errorf("Expected heartbeat to be sent")
}
// Test RequestWALEntries
client.walEntries = []*wal.Entry{
{SequenceNumber: 101, Type: 1, Key: []byte("key1"), Value: []byte("value1")},
{SequenceNumber: 102, Type: 1, Key: []byte("key2"), Value: []byte("value2")},
}
entries, err := client.RequestWALEntries(ctx, 100)
if err != nil {
t.Errorf("RequestWALEntries failed: %v", err)
}
if !client.walEntriesRequested {
t.Errorf("Expected WAL entries to be requested")
}
if len(entries) != 2 {
t.Errorf("Expected 2 entries, got %d", len(entries))
}
// Test RequestBootstrap
client.bootstrapIterator = NewMockBootstrapIterator()
iterator, err := client.RequestBootstrap(ctx)
if err != nil {
t.Errorf("RequestBootstrap failed: %v", err)
}
if !client.bootstrapRequested {
t.Errorf("Expected bootstrap to be requested")
}
// Test iterator
key, value, err := iterator.Next()
if err != nil {
t.Errorf("Iterator.Next failed: %v", err)
}
if string(key) != "key1" || string(value) != "value1" {
t.Errorf("Expected key1/value1, got %s/%s", string(key), string(value))
}
progress := iterator.Progress()
if progress != 1.0/3.0 {
t.Errorf("Expected progress to be 1/3, got %f", progress)
}
// Test Close
err = client.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
if client.IsConnected() {
t.Errorf("Expected client to be disconnected")
}
// Test iterator Close
err = iterator.Close()
if err != nil {
t.Errorf("Iterator.Close failed: %v", err)
}
mockIter := iterator.(*MockBootstrapIterator)
if !mockIter.closed {
t.Errorf("Expected iterator to be closed")
}
}
// MockReplicationServer implements ReplicationServer for testing
type MockReplicationServer struct {
started bool
stopped bool
replicas map[string]*ReplicaInfo
streamingReplicas map[string]bool
}
func NewMockReplicationServer() *MockReplicationServer {
return &MockReplicationServer{
started: false,
stopped: false,
replicas: make(map[string]*ReplicaInfo),
streamingReplicas: make(map[string]bool),
}
}
func (s *MockReplicationServer) Start() error {
s.started = true
return nil
}
func (s *MockReplicationServer) Serve() error {
s.started = true
return nil
}
func (s *MockReplicationServer) Stop(ctx context.Context) error {
s.stopped = true
return nil
}
func (s *MockReplicationServer) SetRequestHandler(handler RequestHandler) {
// No-op for testing
}
func (s *MockReplicationServer) RegisterReplica(replicaInfo *ReplicaInfo) error {
s.replicas[replicaInfo.ID] = replicaInfo
return nil
}
func (s *MockReplicationServer) UpdateReplicaStatus(replicaID string, status ReplicaStatus, lsn uint64) error {
replica, exists := s.replicas[replicaID]
if !exists {
return ErrInvalidRequest
}
replica.Status = status
replica.CurrentLSN = lsn
return nil
}
func (s *MockReplicationServer) GetReplicaInfo(replicaID string) (*ReplicaInfo, error) {
replica, exists := s.replicas[replicaID]
if !exists {
return nil, ErrInvalidRequest
}
return replica, nil
}
func (s *MockReplicationServer) ListReplicas() ([]*ReplicaInfo, error) {
result := make([]*ReplicaInfo, 0, len(s.replicas))
for _, replica := range s.replicas {
result = append(result, replica)
}
return result, nil
}
func (s *MockReplicationServer) StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error {
_, exists := s.replicas[replicaID]
if !exists {
return ErrInvalidRequest
}
s.streamingReplicas[replicaID] = true
return nil
}
func TestReplicationServerInterface(t *testing.T) {
// Create a mock server
server := NewMockReplicationServer()
// Test Start
err := server.Start()
if err != nil {
t.Errorf("Start failed: %v", err)
}
if !server.started {
t.Errorf("Expected server to be started")
}
// Test RegisterReplica
replica1 := &ReplicaInfo{
ID: "replica1",
Address: "localhost:50051",
Role: RoleReplica,
Status: StatusConnecting,
LastSeen: time.Now(),
CurrentLSN: 0,
ReplicationLag: 0,
}
err = server.RegisterReplica(replica1)
if err != nil {
t.Errorf("RegisterReplica failed: %v", err)
}
// Test UpdateReplicaStatus
err = server.UpdateReplicaStatus("replica1", StatusReady, 100)
if err != nil {
t.Errorf("UpdateReplicaStatus failed: %v", err)
}
// Test GetReplicaInfo
replica, err := server.GetReplicaInfo("replica1")
if err != nil {
t.Errorf("GetReplicaInfo failed: %v", err)
}
if replica.Status != StatusReady {
t.Errorf("Expected status to be StatusReady, got %v", replica.Status)
}
if replica.CurrentLSN != 100 {
t.Errorf("Expected LSN to be 100, got %d", replica.CurrentLSN)
}
// Test ListReplicas
replicas, err := server.ListReplicas()
if err != nil {
t.Errorf("ListReplicas failed: %v", err)
}
if len(replicas) != 1 {
t.Errorf("Expected 1 replica, got %d", len(replicas))
}
// Test StreamWALEntriesToReplica
ctx := context.Background()
err = server.StreamWALEntriesToReplica(ctx, "replica1", 0)
if err != nil {
t.Errorf("StreamWALEntriesToReplica failed: %v", err)
}
if !server.streamingReplicas["replica1"] {
t.Errorf("Expected replica1 to be streaming")
}
// Test Stop
err = server.Stop(ctx)
if err != nil {
t.Errorf("Stop failed: %v", err)
}
if !server.stopped {
t.Errorf("Expected server to be stopped")
}
}

View File

@ -138,13 +138,25 @@ func (b *Batch) Write(w *WAL) error {
return ErrWALClosed
}
// Set the sequence number
b.Seq = w.nextSequence
// Set the sequence number - use Lamport clock if available
var seqNum uint64
if w.clock != nil {
// Generate Lamport timestamp for the batch
seqNum = w.clock.Tick()
// Keep the nextSequence in sync with the highest timestamp
if seqNum >= w.nextSequence {
w.nextSequence = seqNum + 1
}
} else {
// Use traditional sequence number
seqNum = w.nextSequence
// Increment sequence for future operations
w.nextSequence += uint64(len(b.Operations))
}
b.Seq = seqNum
binary.LittleEndian.PutUint64(data[4:12], b.Seq)
// Increment sequence for future operations
w.nextSequence += uint64(len(b.Operations))
// Write as a batch entry
if err := w.writeRecord(uint8(RecordTypeFull), OpTypeBatch, b.Seq, data, nil); err != nil {
return err

View File

@ -47,7 +47,11 @@ func TestBatchEncoding(t *testing.T) {
defer os.RemoveAll(dir)
cfg := createTestConfig()
wal, err := NewWAL(cfg, dir)
// Create a mock Lamport clock for the test
clock := &MockLamportClock{counter: 0}
wal, err := NewWALWithReplication(cfg, dir, clock, nil)
if err != nil {
t.Fatalf("Failed to create WAL: %v", err)
}

View File

@ -5,12 +5,12 @@ package wal
type LamportClock interface {
// Tick increments the clock and returns the new timestamp value
Tick() uint64
// Update updates the clock based on a received timestamp,
// ensuring the local clock is at least as large as the received timestamp,
// then increments and returns the new value
Update(received uint64) uint64
// Current returns the current timestamp without incrementing the clock
Current() uint64
}
@ -23,4 +23,4 @@ type ReplicationHook interface {
// OnBatchWritten is called when a batch of WAL entries is written
OnBatchWritten(entries []*Entry)
}
}
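The interface above only specifies behaviour; this hunk does not add a concrete clock. A minimal mutex-guarded sketch that satisfies LamportClock, assuming it lives in the wal package alongside the interface and that sync is imported (the name SimpleLamportClock is illustrative):

// SimpleLamportClock is a thread-safe in-memory Lamport clock.
type SimpleLamportClock struct {
	mu      sync.Mutex
	counter uint64
}

// Tick increments the local clock and returns the new timestamp.
func (c *SimpleLamportClock) Tick() uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.counter++
	return c.counter
}

// Update advances the clock to at least the received timestamp, then ticks.
func (c *SimpleLamportClock) Update(received uint64) uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	if received > c.counter {
		c.counter = received
	}
	c.counter++
	return c.counter
}

// Current returns the current timestamp without advancing the clock.
func (c *SimpleLamportClock) Current() uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.counter
}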

View File

@ -53,7 +53,7 @@ type MockReplicationHook struct {
func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
m.mu.Lock()
defer m.mu.Unlock()
// Make a deep copy of the entry to ensure tests are not affected by later modifications
entryCopy := &Entry{
SequenceNumber: entry.SequenceNumber,
@ -63,7 +63,7 @@ func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
if entry.Value != nil {
entryCopy.Value = append([]byte{}, entry.Value...)
}
m.entries = append(m.entries, entryCopy)
m.entriesReceived++
}
@ -71,7 +71,7 @@ func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
func (m *MockReplicationHook) OnBatchWritten(entries []*Entry) {
m.mu.Lock()
defer m.mu.Unlock()
// Make a deep copy of all entries
entriesCopy := make([]*Entry, len(entries))
for i, entry := range entries {
@ -84,7 +84,7 @@ func (m *MockReplicationHook) OnBatchWritten(entries []*Entry) {
entriesCopy[i].Value = append([]byte{}, entry.Value...)
}
}
m.batchEntries = append(m.batchEntries, entriesCopy)
m.batchesReceived++
}
@ -117,10 +117,10 @@ func TestWALReplicationHook(t *testing.T) {
// Create a mock replication hook
hook := &MockReplicationHook{}
// Create a Lamport clock
clock := NewMockLamportClock()
// Create a WAL with the replication hook
cfg := config.NewDefaultConfig(dir)
wal, err := NewWALWithReplication(cfg, dir, clock, hook)
@ -132,18 +132,18 @@ func TestWALReplicationHook(t *testing.T) {
// Test single entry writes
key1 := []byte("key1")
value1 := []byte("value1")
seq1, err := wal.Append(OpTypePut, key1, value1)
if err != nil {
t.Fatalf("Failed to append to WAL: %v", err)
}
// Test that the hook received the entry
entries := hook.GetEntries()
if len(entries) != 1 {
t.Fatalf("Expected 1 entry, got %d", len(entries))
}
entry := entries[0]
if entry.SequenceNumber != seq1 {
t.Errorf("Expected sequence number %d, got %d", seq1, entry.SequenceNumber)
@ -163,28 +163,28 @@ func TestWALReplicationHook(t *testing.T) {
value2 := []byte("value2")
key3 := []byte("key3")
value3 := []byte("value3")
batchEntries := []*Entry{
{Type: OpTypePut, Key: key2, Value: value2},
{Type: OpTypePut, Key: key3, Value: value3},
}
batchSeq, err := wal.AppendBatch(batchEntries)
if err != nil {
t.Fatalf("Failed to append batch to WAL: %v", err)
}
// Test that the hook received the batch
batches := hook.GetBatchEntries()
if len(batches) != 1 {
t.Fatalf("Expected 1 batch, got %d", len(batches))
}
batch := batches[0]
if len(batch) != 2 {
t.Fatalf("Expected 2 entries in batch, got %d", len(batch))
}
// Check first entry in batch
if batch[0].SequenceNumber != batchSeq {
t.Errorf("Expected sequence number %d, got %d", batchSeq, batch[0].SequenceNumber)
@ -198,7 +198,7 @@ func TestWALReplicationHook(t *testing.T) {
if string(batch[0].Value) != string(value2) {
t.Errorf("Expected value %q, got %q", value2, batch[0].Value)
}
// Check second entry in batch
if batch[1].SequenceNumber != batchSeq+1 {
t.Errorf("Expected sequence number %d, got %d", batchSeq+1, batch[1].SequenceNumber)
@ -212,7 +212,7 @@ func TestWALReplicationHook(t *testing.T) {
if string(batch[1].Value) != string(value3) {
t.Errorf("Expected value %q, got %q", value3, batch[1].Value)
}
// Check call counts
entriesReceived, batchesReceived := hook.GetStats()
if entriesReceived != 1 {
@ -233,7 +233,7 @@ func TestWALWithLamportClock(t *testing.T) {
// Create a Lamport clock
clock := NewMockLamportClock()
// Create a WAL with the Lamport clock but no hook
cfg := config.NewDefaultConfig(dir)
wal, err := NewWALWithReplication(cfg, dir, clock, nil)
@ -246,7 +246,7 @@ func TestWALWithLamportClock(t *testing.T) {
for i := 0; i < 5; i++ {
clock.Tick()
}
// Current clock value should be 5
if clock.Current() != 5 {
t.Fatalf("Expected clock value 5, got %d", clock.Current())
@ -255,38 +255,38 @@ func TestWALWithLamportClock(t *testing.T) {
// Test that the WAL uses the Lamport clock for sequence numbers
key1 := []byte("key1")
value1 := []byte("value1")
seq1, err := wal.Append(OpTypePut, key1, value1)
if err != nil {
t.Fatalf("Failed to append to WAL: %v", err)
}
// Sequence number should be 6 (previous 5 + 1 for this operation)
if seq1 != 6 {
t.Errorf("Expected sequence number 6, got %d", seq1)
}
// Clock should have incremented
if clock.Current() != 6 {
t.Errorf("Expected clock value 6, got %d", clock.Current())
}
// Test with a batch
entries := []*Entry{
{Type: OpTypePut, Key: []byte("key2"), Value: []byte("value2")},
{Type: OpTypePut, Key: []byte("key3"), Value: []byte("value3")},
}
batchSeq, err := wal.AppendBatch(entries)
if err != nil {
t.Fatalf("Failed to append batch to WAL: %v", err)
}
// Batch sequence should be 7
if batchSeq != 7 {
t.Errorf("Expected batch sequence number 7, got %d", batchSeq)
}
// Clock should have incremented again
if clock.Current() != 7 {
t.Errorf("Expected clock value 7, got %d", clock.Current())
@ -312,45 +312,45 @@ func TestWALHookAfterCreation(t *testing.T) {
// Write an entry before adding a hook
key1 := []byte("key1")
value1 := []byte("value1")
_, err = wal.Append(OpTypePut, key1, value1)
if err != nil {
t.Fatalf("Failed to append to WAL: %v", err)
}
// Create and add a hook after the fact
hook := &MockReplicationHook{}
wal.SetReplicationHook(hook)
// Create and add a Lamport clock after the fact
clock := NewMockLamportClock()
wal.SetLamportClock(clock)
// Write another entry, this should trigger the hook
key2 := []byte("key2")
value2 := []byte("value2")
seq2, err := wal.Append(OpTypePut, key2, value2)
if err != nil {
t.Fatalf("Failed to append to WAL: %v", err)
}
// Verify hook received the entry
entries := hook.GetEntries()
if len(entries) != 1 {
t.Fatalf("Expected 1 entry in hook, got %d", len(entries))
}
if entries[0].SequenceNumber != seq2 {
t.Errorf("Expected sequence number %d, got %d", seq2, entries[0].SequenceNumber)
}
if string(entries[0].Key) != string(key2) {
t.Errorf("Expected key %q, got %q", key2, entries[0].Key)
}
// Verify the clock was used
if seq2 != 1 { // First tick of the clock
t.Errorf("Expected sequence from clock to be 1, got %d", seq2)
}
}
}

View File

@ -81,9 +81,9 @@ type WAL struct {
status int32 // Using atomic int32 for status flags
closed int32 // Atomic flag indicating if WAL is closed
mu sync.Mutex
// Replication support
clock LamportClock // Lamport clock for logical timestamps
clock LamportClock // Lamport clock for logical timestamps
replicationHook ReplicationHook // Hook for replication events
}
@ -226,11 +226,15 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
if w.clock != nil {
// Generate Lamport timestamp (reusing SequenceNumber field)
seqNum = w.clock.Tick()
// Keep the nextSequence in sync with the highest used sequence number
if seqNum >= w.nextSequence {
w.nextSequence = seqNum + 1
}
} else {
// Use traditional sequence number
seqNum = w.nextSequence
w.nextSequence++
}
w.nextSequence = seqNum + 1
// Encode the entry
// Format: type(1) + seq(8) + keylen(4) + key + vallen(4) + val
@ -257,11 +261,11 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
if err := w.maybeSync(); err != nil {
return 0, err
}
// Notify replication hook if available
if w.replicationHook != nil {
entry := &Entry{
SequenceNumber: seqNum, // This now represents the Lamport timestamp
SequenceNumber: seqNum, // This now represents the Lamport timestamp
Type: entryType,
Key: key,
Value: value,
@ -511,6 +515,10 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
if w.clock != nil {
// Generate Lamport timestamp for the batch
startSeqNum = w.clock.Tick()
// Keep the nextSequence in sync with the highest timestamp
if startSeqNum >= w.nextSequence {
w.nextSequence = startSeqNum + 1
}
} else {
// Use traditional sequence number
startSeqNum = w.nextSequence
@ -543,7 +551,7 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
for i, entry := range entries {
// Assign sequential sequence numbers to each entry
seqNum := startSeqNum + uint64(i)
// Save sequence number in the entry for replication
entry.SequenceNumber = seqNum
entriesForReplication[i] = entry
@ -569,7 +577,7 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
if err := w.maybeSync(); err != nil {
return 0, err
}
// Notify replication hook if available
if w.replicationHook != nil {
w.replicationHook.OnBatchWritten(entriesForReplication)
@ -599,7 +607,6 @@ func (w *WAL) Close() error {
if err := w.file.Close(); err != nil {
return fmt.Errorf("failed to close WAL file: %w", err)
}
atomic.StoreInt32(&w.status, WALStatusClosed)
return nil
}
@ -629,7 +636,7 @@ func (w *WAL) UpdateNextSequence(nextSeq uint64) {
func (w *WAL) SetReplicationHook(hook ReplicationHook) {
w.mu.Lock()
defer w.mu.Unlock()
w.replicationHook = hook
}
@ -637,7 +644,7 @@ func (w *WAL) SetReplicationHook(hook ReplicationHook) {
func (w *WAL) SetLamportClock(clock LamportClock) {
w.mu.Lock()
defer w.mu.Unlock()
w.clock = clock
}
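Taken together, the changes in this file allow replication to be wired either at construction time via NewWALWithReplication or onto an already-open WAL via SetLamportClock and SetReplicationHook. A rough usage sketch from outside the package, assuming the wal and config packages are imported from the repository's pkg tree and that clock and hook satisfy the interfaces above:

func openReplicatedWAL(dir string, clock wal.LamportClock, hook wal.ReplicationHook) (*wal.WAL, error) {
	cfg := config.NewDefaultConfig(dir)

	// Wire replication at construction time; SetLamportClock and
	// SetReplicationHook offer the same wiring for a WAL that is already open.
	w, err := wal.NewWALWithReplication(cfg, dir, clock, hook)
	if err != nil {
		return nil, err
	}

	// Appends now carry Lamport timestamps in SequenceNumber and notify the hook.
	if _, err := w.Append(wal.OpTypePut, []byte("key"), []byte("value")); err != nil {
		w.Close()
		return nil, err
	}
	return w, nil
}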

View File

@ -12,7 +12,10 @@ import (
)
func createTestConfig() *config.Config {
return config.NewDefaultConfig("/tmp/gostorage_test")
cfg := config.NewDefaultConfig("/tmp/gostorage_test")
// Force immediate sync for tests
cfg.WALSyncMode = config.SyncImmediate
return cfg
}
func createTempDir(t *testing.T) string {
@ -370,12 +373,13 @@ func TestWALRecovery(t *testing.T) {
func TestWALSyncModes(t *testing.T) {
testCases := []struct {
name string
syncMode config.SyncMode
name string
syncMode config.SyncMode
expectedEntries int // Expected number of entries after crash (without explicit sync)
}{
{"SyncNone", config.SyncNone},
{"SyncBatch", config.SyncBatch},
{"SyncImmediate", config.SyncImmediate},
{"SyncNone", config.SyncNone, 0}, // No entries should be recovered without explicit sync
{"SyncBatch", config.SyncBatch, 0}, // No entries should be recovered if batch threshold not reached
{"SyncImmediate", config.SyncImmediate, 10}, // All entries should be recovered
}
for _, tc := range testCases {
@ -386,6 +390,10 @@ func TestWALSyncModes(t *testing.T) {
// Create config with specific sync mode
cfg := createTestConfig()
cfg.WALSyncMode = tc.syncMode
// Set a high sync threshold for batch mode to ensure it doesn't auto-sync
if tc.syncMode == config.SyncBatch {
cfg.WALSyncBytes = 100 * 1024 * 1024 // 100MB, high enough to not trigger
}
wal, err := NewWAL(cfg, dir)
if err != nil {
@ -403,6 +411,8 @@ func TestWALSyncModes(t *testing.T) {
}
}
// Skip explicit sync to simulate a crash
// Close the WAL
if err := wal.Close(); err != nil {
t.Fatalf("Failed to close WAL: %v", err)
@ -421,8 +431,54 @@ func TestWALSyncModes(t *testing.T) {
t.Fatalf("Failed to replay WAL: %v", err)
}
if count != 10 {
t.Errorf("Expected 10 entries, got %d", count)
// Check that the number of recovered entries matches expectations for this sync mode
if count != tc.expectedEntries {
t.Errorf("Expected %d entries for %s mode, got %d", tc.expectedEntries, tc.name, count)
}
// Now test with explicit sync - all entries should be recoverable
wal, err = NewWAL(cfg, dir)
if err != nil {
t.Fatalf("Failed to create WAL: %v", err)
}
// Write some more entries
for i := 0; i < 10; i++ {
key := []byte(fmt.Sprintf("explicit_key%d", i))
value := []byte(fmt.Sprintf("explicit_value%d", i))
_, err := wal.Append(OpTypePut, key, value)
if err != nil {
t.Fatalf("Failed to append entry: %v", err)
}
}
// Explicitly sync
if err := wal.Sync(); err != nil {
t.Fatalf("Failed to sync WAL: %v", err)
}
// Close the WAL
if err := wal.Close(); err != nil {
t.Fatalf("Failed to close WAL: %v", err)
}
// Verify entries by replaying
explicitCount := 0
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut && bytes.HasPrefix(entry.Key, []byte("explicit_")) {
explicitCount++
}
return nil
})
if err != nil {
t.Fatalf("Failed to replay WAL after explicit sync: %v", err)
}
// After explicit sync, all 10 new entries should be recovered regardless of mode
if explicitCount != 10 {
t.Errorf("Expected 10 entries after explicit sync, got %d", explicitCount)
}
})
}
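The table-driven cases above pin down the durability contract: SyncImmediate persists every append, while SyncNone and SyncBatch (below its byte threshold) recover nothing unless Sync is called explicitly. A small same-package sketch of picking a mode and forcing durability; the 1 MiB threshold is an arbitrary illustrative value:

func writeDurably(dir string) error {
	cfg := config.NewDefaultConfig(dir)
	cfg.WALSyncMode = config.SyncBatch // fsync once WALSyncBytes has been written
	cfg.WALSyncBytes = 1 * 1024 * 1024 // illustrative 1 MiB threshold

	w, err := NewWAL(cfg, dir)
	if err != nil {
		return err
	}
	defer w.Close()

	if _, err := w.Append(OpTypePut, []byte("k"), []byte("v")); err != nil {
		return err
	}
	// Under SyncNone or SyncBatch, call Sync before relying on durability.
	return w.Sync()
}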

proto/kevo/replication.pb.go Normal file (1464 lines)

File diff suppressed because it is too large

View File

@ -0,0 +1,152 @@
syntax = "proto3";
package kevo;
option go_package = "github.com/KevoDB/kevo/proto/kevo";
service ReplicationService {
// Replica Registration and Status
rpc RegisterReplica(RegisterReplicaRequest) returns (RegisterReplicaResponse);
rpc ReplicaHeartbeat(ReplicaHeartbeatRequest) returns (ReplicaHeartbeatResponse);
rpc GetReplicaStatus(GetReplicaStatusRequest) returns (GetReplicaStatusResponse);
rpc ListReplicas(ListReplicasRequest) returns (ListReplicasResponse);
// WAL Replication
rpc GetWALEntries(GetWALEntriesRequest) returns (GetWALEntriesResponse);
rpc StreamWALEntries(StreamWALEntriesRequest) returns (stream WALEntryBatch);
rpc ReportAppliedEntries(ReportAppliedEntriesRequest) returns (ReportAppliedEntriesResponse);
// Bootstrap Operations
rpc RequestBootstrap(BootstrapRequest) returns (stream BootstrapBatch);
}
// Replication status enum
enum ReplicaRole {
PRIMARY = 0;
REPLICA = 1;
READ_ONLY = 2;
}
enum ReplicaStatus {
CONNECTING = 0;
SYNCING = 1;
BOOTSTRAPPING = 2;
READY = 3;
DISCONNECTED = 4;
ERROR = 5;
}
// Replica Registration messages
message RegisterReplicaRequest {
string replica_id = 1;
string address = 2;
ReplicaRole role = 3;
}
message RegisterReplicaResponse {
bool success = 1;
string error_message = 2;
uint64 current_lsn = 3; // Current Lamport Sequence Number on primary
bool bootstrap_required = 4;
}
// Heartbeat messages
message ReplicaHeartbeatRequest {
string replica_id = 1;
ReplicaStatus status = 2;
uint64 current_lsn = 3; // Current Lamport Sequence Number on replica
string error_message = 4; // If status is ERROR
}
message ReplicaHeartbeatResponse {
bool success = 1;
uint64 primary_lsn = 2; // Current Lamport Sequence Number on primary
int64 replication_lag_ms = 3; // Estimated lag in milliseconds
}
// Status messages
message GetReplicaStatusRequest {
string replica_id = 1;
}
message ReplicaInfo {
string replica_id = 1;
string address = 2;
ReplicaRole role = 3;
ReplicaStatus status = 4;
int64 last_seen_ms = 5; // Timestamp of last heartbeat in milliseconds since epoch
uint64 current_lsn = 6; // Current Lamport Sequence Number
int64 replication_lag_ms = 7; // Estimated lag in milliseconds
string error_message = 8; // If status is ERROR
}
message GetReplicaStatusResponse {
ReplicaInfo replica = 1;
}
message ListReplicasRequest {}
message ListReplicasResponse {
repeated ReplicaInfo replicas = 1;
}
// WAL Replication messages
message WALEntry {
uint64 sequence_number = 1; // Lamport Sequence Number
uint32 type = 2; // Entry type (put, delete, etc.)
bytes key = 3;
bytes value = 4; // Optional, depending on type
bytes checksum = 5; // Checksum for data integrity
}
message WALEntryBatch {
repeated WALEntry entries = 1;
uint64 first_lsn = 2; // LSN of the first entry in the batch
uint64 last_lsn = 3; // LSN of the last entry in the batch
uint32 count = 4; // Number of entries in the batch
bytes checksum = 5; // Checksum of the entire batch
}
message GetWALEntriesRequest {
string replica_id = 1;
uint64 from_lsn = 2; // Request entries starting from this LSN
uint32 max_entries = 3; // Maximum number of entries to return (0 for no limit)
}
message GetWALEntriesResponse {
WALEntryBatch batch = 1;
bool has_more = 2; // True if there are more entries available
}
message StreamWALEntriesRequest {
string replica_id = 1;
uint64 from_lsn = 2; // Request entries starting from this LSN
bool continuous = 3; // If true, keep streaming as new entries arrive
}
message ReportAppliedEntriesRequest {
string replica_id = 1;
uint64 applied_lsn = 2; // Highest LSN successfully applied on the replica
}
message ReportAppliedEntriesResponse {
bool success = 1;
uint64 primary_lsn = 2; // Current LSN on primary
}
// Bootstrap messages
message BootstrapRequest {
string replica_id = 1;
}
message KeyValuePair {
bytes key = 1;
bytes value = 2;
}
message BootstrapBatch {
repeated KeyValuePair pairs = 1;
float progress = 2; // Progress from 0.0 to 1.0
bool is_last = 3; // True if this is the last batch
uint64 snapshot_lsn = 4; // LSN at which this snapshot was taken
}
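To make the message flow concrete, here is a rough client-side sketch of a replica registering and then streaming WAL entries with the generated API. It assumes the generated package is imported as kevo per the go_package option above, standard protoc-gen-go field naming (ReplicaId, FromLsn, and so on), and an insecure connection for brevity; the addresses and the lack of retry handling are placeholders:

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	kevo "github.com/KevoDB/kevo/proto/kevo"
)

func runReplicaSync(ctx context.Context, primaryAddr, replicaAddr string) error {
	conn, err := grpc.NewClient(primaryAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return err
	}
	defer conn.Close()

	client := kevo.NewReplicationServiceClient(conn)

	// Register this node as a replica and learn the primary's current LSN.
	reg, err := client.RegisterReplica(ctx, &kevo.RegisterReplicaRequest{
		ReplicaId: "replica1",
		Address:   replicaAddr,
		Role:      kevo.ReplicaRole_REPLICA,
	})
	if err != nil {
		return err
	}

	// Continuously stream WAL entry batches starting from that LSN.
	stream, err := client.StreamWALEntries(ctx, &kevo.StreamWALEntriesRequest{
		ReplicaId:  "replica1",
		FromLsn:    reg.GetCurrentLsn(),
		Continuous: true,
	})
	if err != nil {
		return err
	}
	for {
		batch, err := stream.Recv()
		if err != nil {
			return err // io.EOF once a non-continuous stream completes
		}
		_ = batch.GetEntries() // hand the entries to the replica's WAL applier
	}
}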

View File

@ -0,0 +1,400 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v3.20.3
// source: proto/kevo/replication.proto
package kevo
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
ReplicationService_RegisterReplica_FullMethodName = "/kevo.ReplicationService/RegisterReplica"
ReplicationService_ReplicaHeartbeat_FullMethodName = "/kevo.ReplicationService/ReplicaHeartbeat"
ReplicationService_GetReplicaStatus_FullMethodName = "/kevo.ReplicationService/GetReplicaStatus"
ReplicationService_ListReplicas_FullMethodName = "/kevo.ReplicationService/ListReplicas"
ReplicationService_GetWALEntries_FullMethodName = "/kevo.ReplicationService/GetWALEntries"
ReplicationService_StreamWALEntries_FullMethodName = "/kevo.ReplicationService/StreamWALEntries"
ReplicationService_ReportAppliedEntries_FullMethodName = "/kevo.ReplicationService/ReportAppliedEntries"
ReplicationService_RequestBootstrap_FullMethodName = "/kevo.ReplicationService/RequestBootstrap"
)
// ReplicationServiceClient is the client API for ReplicationService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ReplicationServiceClient interface {
// Replica Registration and Status
RegisterReplica(ctx context.Context, in *RegisterReplicaRequest, opts ...grpc.CallOption) (*RegisterReplicaResponse, error)
ReplicaHeartbeat(ctx context.Context, in *ReplicaHeartbeatRequest, opts ...grpc.CallOption) (*ReplicaHeartbeatResponse, error)
GetReplicaStatus(ctx context.Context, in *GetReplicaStatusRequest, opts ...grpc.CallOption) (*GetReplicaStatusResponse, error)
ListReplicas(ctx context.Context, in *ListReplicasRequest, opts ...grpc.CallOption) (*ListReplicasResponse, error)
// WAL Replication
GetWALEntries(ctx context.Context, in *GetWALEntriesRequest, opts ...grpc.CallOption) (*GetWALEntriesResponse, error)
StreamWALEntries(ctx context.Context, in *StreamWALEntriesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WALEntryBatch], error)
ReportAppliedEntries(ctx context.Context, in *ReportAppliedEntriesRequest, opts ...grpc.CallOption) (*ReportAppliedEntriesResponse, error)
// Bootstrap Operations
RequestBootstrap(ctx context.Context, in *BootstrapRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[BootstrapBatch], error)
}
type replicationServiceClient struct {
cc grpc.ClientConnInterface
}
func NewReplicationServiceClient(cc grpc.ClientConnInterface) ReplicationServiceClient {
return &replicationServiceClient{cc}
}
func (c *replicationServiceClient) RegisterReplica(ctx context.Context, in *RegisterReplicaRequest, opts ...grpc.CallOption) (*RegisterReplicaResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RegisterReplicaResponse)
err := c.cc.Invoke(ctx, ReplicationService_RegisterReplica_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) ReplicaHeartbeat(ctx context.Context, in *ReplicaHeartbeatRequest, opts ...grpc.CallOption) (*ReplicaHeartbeatResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ReplicaHeartbeatResponse)
err := c.cc.Invoke(ctx, ReplicationService_ReplicaHeartbeat_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) GetReplicaStatus(ctx context.Context, in *GetReplicaStatusRequest, opts ...grpc.CallOption) (*GetReplicaStatusResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetReplicaStatusResponse)
err := c.cc.Invoke(ctx, ReplicationService_GetReplicaStatus_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) ListReplicas(ctx context.Context, in *ListReplicasRequest, opts ...grpc.CallOption) (*ListReplicasResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListReplicasResponse)
err := c.cc.Invoke(ctx, ReplicationService_ListReplicas_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) GetWALEntries(ctx context.Context, in *GetWALEntriesRequest, opts ...grpc.CallOption) (*GetWALEntriesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetWALEntriesResponse)
err := c.cc.Invoke(ctx, ReplicationService_GetWALEntries_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) StreamWALEntries(ctx context.Context, in *StreamWALEntriesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WALEntryBatch], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &ReplicationService_ServiceDesc.Streams[0], ReplicationService_StreamWALEntries_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[StreamWALEntriesRequest, WALEntryBatch]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_StreamWALEntriesClient = grpc.ServerStreamingClient[WALEntryBatch]
func (c *replicationServiceClient) ReportAppliedEntries(ctx context.Context, in *ReportAppliedEntriesRequest, opts ...grpc.CallOption) (*ReportAppliedEntriesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ReportAppliedEntriesResponse)
err := c.cc.Invoke(ctx, ReplicationService_ReportAppliedEntries_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *replicationServiceClient) RequestBootstrap(ctx context.Context, in *BootstrapRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[BootstrapBatch], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &ReplicationService_ServiceDesc.Streams[1], ReplicationService_RequestBootstrap_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[BootstrapRequest, BootstrapBatch]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_RequestBootstrapClient = grpc.ServerStreamingClient[BootstrapBatch]
// ReplicationServiceServer is the server API for ReplicationService service.
// All implementations must embed UnimplementedReplicationServiceServer
// for forward compatibility.
type ReplicationServiceServer interface {
// Replica Registration and Status
RegisterReplica(context.Context, *RegisterReplicaRequest) (*RegisterReplicaResponse, error)
ReplicaHeartbeat(context.Context, *ReplicaHeartbeatRequest) (*ReplicaHeartbeatResponse, error)
GetReplicaStatus(context.Context, *GetReplicaStatusRequest) (*GetReplicaStatusResponse, error)
ListReplicas(context.Context, *ListReplicasRequest) (*ListReplicasResponse, error)
// WAL Replication
GetWALEntries(context.Context, *GetWALEntriesRequest) (*GetWALEntriesResponse, error)
StreamWALEntries(*StreamWALEntriesRequest, grpc.ServerStreamingServer[WALEntryBatch]) error
ReportAppliedEntries(context.Context, *ReportAppliedEntriesRequest) (*ReportAppliedEntriesResponse, error)
// Bootstrap Operations
RequestBootstrap(*BootstrapRequest, grpc.ServerStreamingServer[BootstrapBatch]) error
mustEmbedUnimplementedReplicationServiceServer()
}
// UnimplementedReplicationServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedReplicationServiceServer struct{}
func (UnimplementedReplicationServiceServer) RegisterReplica(context.Context, *RegisterReplicaRequest) (*RegisterReplicaResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RegisterReplica not implemented")
}
func (UnimplementedReplicationServiceServer) ReplicaHeartbeat(context.Context, *ReplicaHeartbeatRequest) (*ReplicaHeartbeatResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ReplicaHeartbeat not implemented")
}
func (UnimplementedReplicationServiceServer) GetReplicaStatus(context.Context, *GetReplicaStatusRequest) (*GetReplicaStatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetReplicaStatus not implemented")
}
func (UnimplementedReplicationServiceServer) ListReplicas(context.Context, *ListReplicasRequest) (*ListReplicasResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListReplicas not implemented")
}
func (UnimplementedReplicationServiceServer) GetWALEntries(context.Context, *GetWALEntriesRequest) (*GetWALEntriesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetWALEntries not implemented")
}
func (UnimplementedReplicationServiceServer) StreamWALEntries(*StreamWALEntriesRequest, grpc.ServerStreamingServer[WALEntryBatch]) error {
return status.Errorf(codes.Unimplemented, "method StreamWALEntries not implemented")
}
func (UnimplementedReplicationServiceServer) ReportAppliedEntries(context.Context, *ReportAppliedEntriesRequest) (*ReportAppliedEntriesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ReportAppliedEntries not implemented")
}
func (UnimplementedReplicationServiceServer) RequestBootstrap(*BootstrapRequest, grpc.ServerStreamingServer[BootstrapBatch]) error {
return status.Errorf(codes.Unimplemented, "method RequestBootstrap not implemented")
}
func (UnimplementedReplicationServiceServer) mustEmbedUnimplementedReplicationServiceServer() {}
func (UnimplementedReplicationServiceServer) testEmbeddedByValue() {}
// UnsafeReplicationServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ReplicationServiceServer will
// result in compilation errors.
type UnsafeReplicationServiceServer interface {
mustEmbedUnimplementedReplicationServiceServer()
}
func RegisterReplicationServiceServer(s grpc.ServiceRegistrar, srv ReplicationServiceServer) {
// If the following call panics, it indicates UnimplementedReplicationServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&ReplicationService_ServiceDesc, srv)
}
func _ReplicationService_RegisterReplica_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RegisterReplicaRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).RegisterReplica(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_RegisterReplica_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).RegisterReplica(ctx, req.(*RegisterReplicaRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_ReplicaHeartbeat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ReplicaHeartbeatRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).ReplicaHeartbeat(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_ReplicaHeartbeat_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).ReplicaHeartbeat(ctx, req.(*ReplicaHeartbeatRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_GetReplicaStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetReplicaStatusRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).GetReplicaStatus(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_GetReplicaStatus_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).GetReplicaStatus(ctx, req.(*GetReplicaStatusRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_ListReplicas_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListReplicasRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).ListReplicas(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_ListReplicas_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).ListReplicas(ctx, req.(*ListReplicasRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_GetWALEntries_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetWALEntriesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).GetWALEntries(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_GetWALEntries_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).GetWALEntries(ctx, req.(*GetWALEntriesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_StreamWALEntries_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(StreamWALEntriesRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(ReplicationServiceServer).StreamWALEntries(m, &grpc.GenericServerStream[StreamWALEntriesRequest, WALEntryBatch]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_StreamWALEntriesServer = grpc.ServerStreamingServer[WALEntryBatch]
func _ReplicationService_ReportAppliedEntries_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ReportAppliedEntriesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ReplicationServiceServer).ReportAppliedEntries(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ReplicationService_ReportAppliedEntries_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ReplicationServiceServer).ReportAppliedEntries(ctx, req.(*ReportAppliedEntriesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ReplicationService_RequestBootstrap_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(BootstrapRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(ReplicationServiceServer).RequestBootstrap(m, &grpc.GenericServerStream[BootstrapRequest, BootstrapBatch]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_RequestBootstrapServer = grpc.ServerStreamingServer[BootstrapBatch]
// ReplicationService_ServiceDesc is the grpc.ServiceDesc for ReplicationService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var ReplicationService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "kevo.ReplicationService",
HandlerType: (*ReplicationServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "RegisterReplica",
Handler: _ReplicationService_RegisterReplica_Handler,
},
{
MethodName: "ReplicaHeartbeat",
Handler: _ReplicationService_ReplicaHeartbeat_Handler,
},
{
MethodName: "GetReplicaStatus",
Handler: _ReplicationService_GetReplicaStatus_Handler,
},
{
MethodName: "ListReplicas",
Handler: _ReplicationService_ListReplicas_Handler,
},
{
MethodName: "GetWALEntries",
Handler: _ReplicationService_GetWALEntries_Handler,
},
{
MethodName: "ReportAppliedEntries",
Handler: _ReplicationService_ReportAppliedEntries_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "StreamWALEntries",
Handler: _ReplicationService_StreamWALEntries_Handler,
ServerStreams: true,
},
{
StreamName: "RequestBootstrap",
Handler: _ReplicationService_RequestBootstrap_Handler,
ServerStreams: true,
},
},
Metadata: "proto/kevo/replication.proto",
}
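A minimal server-side sketch that embeds UnimplementedReplicationServiceServer by value, as the generated code expects. The handler bodies are placeholders rather than the repository's actual service implementation, and only two of the eight RPCs are shown:

import (
	"context"
	"net"

	"google.golang.org/grpc"

	kevo "github.com/KevoDB/kevo/proto/kevo"
)

type replicationServer struct {
	kevo.UnimplementedReplicationServiceServer
}

func (s *replicationServer) RegisterReplica(ctx context.Context, req *kevo.RegisterReplicaRequest) (*kevo.RegisterReplicaResponse, error) {
	// Placeholder: accept every replica and ask it to bootstrap first.
	return &kevo.RegisterReplicaResponse{
		Success:           true,
		CurrentLsn:        0,
		BootstrapRequired: true,
	}, nil
}

func (s *replicationServer) StreamWALEntries(req *kevo.StreamWALEntriesRequest, stream grpc.ServerStreamingServer[kevo.WALEntryBatch]) error {
	// Placeholder: send one empty batch anchored at the requested LSN and return.
	return stream.Send(&kevo.WALEntryBatch{FirstLsn: req.GetFromLsn(), LastLsn: req.GetFromLsn()})
}

func serveReplication(addr string) error {
	lis, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	s := grpc.NewServer()
	kevo.RegisterReplicationServiceServer(s, &replicationServer{})
	return s.Serve(lis)
}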

View File

@ -4,7 +4,7 @@
// protoc v3.20.3
// source: proto/kevo/service.proto
package proto
package kevo
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@ -1911,7 +1911,7 @@ const file_proto_kevo_service_proto_rawDesc = "" +
"\bTxDelete\x12\x15.kevo.TxDeleteRequest\x1a\x16.kevo.TxDeleteResponse\x125\n" +
"\x06TxScan\x12\x13.kevo.TxScanRequest\x1a\x14.kevo.TxScanResponse0\x01\x129\n" +
"\bGetStats\x12\x15.kevo.GetStatsRequest\x1a\x16.kevo.GetStatsResponse\x126\n" +
"\aCompact\x12\x14.kevo.CompactRequest\x1a\x15.kevo.CompactResponseB5Z3github.com/jeremytregunna/kevo/pkg/grpc/proto;protob\x06proto3"
"\aCompact\x12\x14.kevo.CompactRequest\x1a\x15.kevo.CompactResponseB#Z!github.com/KevoDB/kevo/proto/kevob\x06proto3"
var (
file_proto_kevo_service_proto_rawDescOnce sync.Once

View File

@ -2,7 +2,7 @@ syntax = "proto3";
package kevo;
option go_package = "github.com/jeremytregunna/kevo/pkg/grpc/proto;proto";
option go_package = "github.com/KevoDB/kevo/proto/kevo";
service KevoService {
// Key-Value Operations

View File

@ -4,7 +4,7 @@
// - protoc v3.20.3
// source: proto/kevo/service.proto
package proto
package kevo
import (
context "context"