Compare commits
c0bfd835f7 ... 5963538bc5

6 commits:
- 5963538bc5
- ed991ae00d
- 33ddfeeb64
- 5cd1f5c5f8
- 02febadf5d
- 5b2ecdd08c
pkg/engine/storage/retry.go (new file, 86 lines)
@@ -0,0 +1,86 @@
package storage

import (
	"math/rand"
	"time"

	"github.com/KevoDB/kevo/pkg/wal"
)

// RetryConfig defines parameters for retry operations
type RetryConfig struct {
	MaxRetries     int           // Maximum number of retries
	InitialBackoff time.Duration // Initial backoff duration
	MaxBackoff     time.Duration // Maximum backoff duration
}

// DefaultRetryConfig returns default retry configuration
func DefaultRetryConfig() *RetryConfig {
	return &RetryConfig{
		MaxRetries:     3,
		InitialBackoff: 5 * time.Millisecond,
		MaxBackoff:     50 * time.Millisecond,
	}
}

// RetryOnWALRotating retries the operation if it fails with ErrWALRotating
func (m *Manager) RetryOnWALRotating(operation func() error) error {
	config := DefaultRetryConfig()
	return m.RetryWithConfig(operation, config, isWALRotating)
}

// RetryWithSequence retries the operation if it fails with ErrWALRotating
// and returns the sequence number
func (m *Manager) RetryWithSequence(operation func() (uint64, error)) (uint64, error) {
	config := DefaultRetryConfig()
	var seq uint64

	err := m.RetryWithConfig(func() error {
		var opErr error
		seq, opErr = operation()
		return opErr
	}, config, isWALRotating)

	return seq, err
}

// RetryWithConfig retries an operation with the given configuration
func (m *Manager) RetryWithConfig(operation func() error, config *RetryConfig, isRetryable func(error) bool) error {
	backoff := config.InitialBackoff

	for i := 0; i <= config.MaxRetries; i++ {
		// Attempt the operation
		err := operation()
		if err == nil {
			return nil
		}

		// Check if we should retry
		if !isRetryable(err) || i == config.MaxRetries {
			return err
		}

		// Add some jitter to the backoff
		jitter := time.Duration(rand.Int63n(int64(backoff / 10)))
		backoff = backoff + jitter

		// Wait before retrying
		time.Sleep(backoff)

		// Increase backoff for next attempt, but cap it
		backoff = 2 * backoff
		if backoff > config.MaxBackoff {
			backoff = config.MaxBackoff
		}
	}

	// Should never get here, but just in case
	return nil
}

// isWALRotating checks if the error is due to WAL rotation or closure
func isWALRotating(err error) bool {
	// Both ErrWALRotating and ErrWALClosed can occur during WAL rotation
	// Since WAL rotation is a normal operation, we should retry in both cases
	return err == wal.ErrWALRotating || err == wal.ErrWALClosed
}
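A minimal usage sketch for the retry helper above (an illustrative addition, not part of the changeset): wrapping an arbitrary WAL-dependent operation with a custom RetryConfig instead of the defaults. The retryAggressively name and the op callback are assumptions made for the example; only RetryConfig, RetryWithConfig, and isWALRotating come from the file above.

	// retryAggressively retries op with a tighter backoff schedule than
	// DefaultRetryConfig; op stands in for any operation that may fail
	// with ErrWALRotating or ErrWALClosed during WAL rotation.
	func retryAggressively(m *Manager, op func() error) error {
		cfg := &RetryConfig{
			MaxRetries:     5,
			InitialBackoff: 2 * time.Millisecond,
			MaxBackoff:     100 * time.Millisecond,
		}
		return m.RetryWithConfig(op, cfg, isWALRotating)
	}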
pkg/grpc/service/replication_service.go (new file, 610 lines)
@@ -0,0 +1,610 @@
package service

import (
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/KevoDB/kevo/pkg/replication"
	"github.com/KevoDB/kevo/pkg/transport"
	"github.com/KevoDB/kevo/pkg/wal"
	"github.com/KevoDB/kevo/proto/kevo"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// ReplicationServiceServer implements the gRPC ReplicationService
type ReplicationServiceServer struct {
	kevo.UnimplementedReplicationServiceServer

	// Replication components
	replicator    *replication.WALReplicator
	applier       *replication.WALApplier
	serializer    *replication.EntrySerializer
	highestLSN    uint64
	replicas      map[string]*transport.ReplicaInfo
	replicasMutex sync.RWMutex

	// For snapshot/bootstrap
	storageSnapshot replication.StorageSnapshot
}

// NewReplicationService creates a new ReplicationService
func NewReplicationService(
	replicator *replication.WALReplicator,
	applier *replication.WALApplier,
	serializer *replication.EntrySerializer,
	storageSnapshot replication.StorageSnapshot,
) *ReplicationServiceServer {
	return &ReplicationServiceServer{
		replicator:      replicator,
		applier:         applier,
		serializer:      serializer,
		replicas:        make(map[string]*transport.ReplicaInfo),
		storageSnapshot: storageSnapshot,
	}
}

// RegisterReplica handles registration of a new replica
func (s *ReplicationServiceServer) RegisterReplica(
	ctx context.Context,
	req *kevo.RegisterReplicaRequest,
) (*kevo.RegisterReplicaResponse, error) {
	// Validate request
	if req.ReplicaId == "" {
		return nil, status.Error(codes.InvalidArgument, "replica_id is required")
	}
	if req.Address == "" {
		return nil, status.Error(codes.InvalidArgument, "address is required")
	}

	// Convert role enum to string
	role := transport.RoleReplica
	switch req.Role {
	case kevo.ReplicaRole_PRIMARY:
		role = transport.RolePrimary
	case kevo.ReplicaRole_REPLICA:
		role = transport.RoleReplica
	case kevo.ReplicaRole_READ_ONLY:
		role = transport.RoleReadOnly
	default:
		return nil, status.Error(codes.InvalidArgument, "invalid role")
	}

	// Register the replica
	s.replicasMutex.Lock()
	defer s.replicasMutex.Unlock()

	// If already registered, update address and role
	if replica, exists := s.replicas[req.ReplicaId]; exists {
		replica.Address = req.Address
		replica.Role = role
		replica.LastSeen = time.Now()
		replica.Status = transport.StatusConnecting
	} else {
		// Create new replica info
		s.replicas[req.ReplicaId] = &transport.ReplicaInfo{
			ID:       req.ReplicaId,
			Address:  req.Address,
			Role:     role,
			Status:   transport.StatusConnecting,
			LastSeen: time.Now(),
		}
	}

	// Determine if bootstrap is needed (for now always suggest bootstrap)
	bootstrapRequired := true

	// Return current highest LSN
	currentLSN := s.replicator.GetHighestTimestamp()

	return &kevo.RegisterReplicaResponse{
		Success:           true,
		CurrentLsn:        currentLSN,
		BootstrapRequired: bootstrapRequired,
	}, nil
}

// ReplicaHeartbeat handles heartbeat requests from replicas
func (s *ReplicationServiceServer) ReplicaHeartbeat(
	ctx context.Context,
	req *kevo.ReplicaHeartbeatRequest,
) (*kevo.ReplicaHeartbeatResponse, error) {
	// Validate request
	if req.ReplicaId == "" {
		return nil, status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Check if replica is registered
	s.replicasMutex.Lock()
	defer s.replicasMutex.Unlock()

	replica, exists := s.replicas[req.ReplicaId]
	if !exists {
		return nil, status.Error(codes.NotFound, "replica not registered")
	}

	// Update replica status
	replica.LastSeen = time.Now()

	// Convert status enum to string
	switch req.Status {
	case kevo.ReplicaStatus_CONNECTING:
		replica.Status = transport.StatusConnecting
	case kevo.ReplicaStatus_SYNCING:
		replica.Status = transport.StatusSyncing
	case kevo.ReplicaStatus_BOOTSTRAPPING:
		replica.Status = transport.StatusBootstrapping
	case kevo.ReplicaStatus_READY:
		replica.Status = transport.StatusReady
	case kevo.ReplicaStatus_DISCONNECTED:
		replica.Status = transport.StatusDisconnected
	case kevo.ReplicaStatus_ERROR:
		replica.Status = transport.StatusError
		replica.Error = fmt.Errorf("%s", req.ErrorMessage)
	default:
		return nil, status.Error(codes.InvalidArgument, "invalid status")
	}

	// Update replica LSN
	replica.CurrentLSN = req.CurrentLsn

	// Calculate replication lag
	primaryLSN := s.replicator.GetHighestTimestamp()
	var replicationLagMs int64 = 0
	if primaryLSN > req.CurrentLsn {
		// Simple lag calculation based on LSN difference
		// In a real system, we'd use timestamps for better accuracy
		replicationLagMs = int64(primaryLSN - req.CurrentLsn)
	}

	replica.ReplicationLag = time.Duration(replicationLagMs) * time.Millisecond

	return &kevo.ReplicaHeartbeatResponse{
		Success:          true,
		PrimaryLsn:       primaryLSN,
		ReplicationLagMs: replicationLagMs,
	}, nil
}

// GetReplicaStatus retrieves the status of a specific replica
func (s *ReplicationServiceServer) GetReplicaStatus(
	ctx context.Context,
	req *kevo.GetReplicaStatusRequest,
) (*kevo.GetReplicaStatusResponse, error) {
	// Validate request
	if req.ReplicaId == "" {
		return nil, status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Get replica info
	s.replicasMutex.RLock()
	defer s.replicasMutex.RUnlock()

	replica, exists := s.replicas[req.ReplicaId]
	if !exists {
		return nil, status.Error(codes.NotFound, "replica not found")
	}

	// Convert replica info to proto message
	pbReplica := convertReplicaInfoToProto(replica)

	return &kevo.GetReplicaStatusResponse{
		Replica: pbReplica,
	}, nil
}

// ListReplicas retrieves the status of all replicas
func (s *ReplicationServiceServer) ListReplicas(
	ctx context.Context,
	req *kevo.ListReplicasRequest,
) (*kevo.ListReplicasResponse, error) {
	s.replicasMutex.RLock()
	defer s.replicasMutex.RUnlock()

	// Convert all replicas to proto messages
	pbReplicas := make([]*kevo.ReplicaInfo, 0, len(s.replicas))
	for _, replica := range s.replicas {
		pbReplicas = append(pbReplicas, convertReplicaInfoToProto(replica))
	}

	return &kevo.ListReplicasResponse{
		Replicas: pbReplicas,
	}, nil
}

// GetWALEntries handles requests for WAL entries
func (s *ReplicationServiceServer) GetWALEntries(
	ctx context.Context,
	req *kevo.GetWALEntriesRequest,
) (*kevo.GetWALEntriesResponse, error) {
	// Validate request
	if req.ReplicaId == "" {
		return nil, status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Check if replica is registered
	s.replicasMutex.RLock()
	_, exists := s.replicas[req.ReplicaId]
	s.replicasMutex.RUnlock()

	if !exists {
		return nil, status.Error(codes.NotFound, "replica not registered")
	}

	// Get entries from replicator
	entries, err := s.replicator.GetEntriesAfter(replication.ReplicationPosition{Timestamp: req.FromLsn})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get entries: %v", err)
	}
	if len(entries) == 0 {
		return &kevo.GetWALEntriesResponse{
			Batch:   &kevo.WALEntryBatch{},
			HasMore: false,
		}, nil
	}

	// Convert entries to proto messages
	pbEntries := make([]*kevo.WALEntry, 0, len(entries))
	for _, entry := range entries {
		pbEntries = append(pbEntries, convertWALEntryToProto(entry))
	}

	// Create batch
	pbBatch := &kevo.WALEntryBatch{
		Entries:  pbEntries,
		FirstLsn: entries[0].SequenceNumber,
		LastLsn:  entries[len(entries)-1].SequenceNumber,
		Count:    uint32(len(entries)),
	}

	// Check if there are more entries
	hasMore := s.replicator.GetHighestTimestamp() > entries[len(entries)-1].SequenceNumber

	return &kevo.GetWALEntriesResponse{
		Batch:   pbBatch,
		HasMore: hasMore,
	}, nil
}

// StreamWALEntries handles streaming WAL entries to a replica
func (s *ReplicationServiceServer) StreamWALEntries(
	req *kevo.StreamWALEntriesRequest,
	stream kevo.ReplicationService_StreamWALEntriesServer,
) error {
	// Validate request
	if req.ReplicaId == "" {
		return status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Check if replica is registered
	s.replicasMutex.RLock()
	_, exists := s.replicas[req.ReplicaId]
	s.replicasMutex.RUnlock()

	if !exists {
		return status.Error(codes.NotFound, "replica not registered")
	}

	// Process entries in batches
	fromLSN := req.FromLsn
	batchSize := 100 // Can be configurable
	notifierCh := make(chan struct{}, 1)

	// Register for notifications of new entries
	var processor *entryNotifier
	if req.Continuous {
		processor = &entryNotifier{
			notifyCh: notifierCh,
			fromLSN:  fromLSN,
		}
		s.replicator.AddProcessor(processor)
		defer s.replicator.RemoveProcessor(processor)
	}

	// Initial send of available entries
	for {
		// Get batch of entries
		entries, err := s.replicator.GetEntriesAfter(replication.ReplicationPosition{Timestamp: fromLSN})
		if err != nil {
			return err
		}
		if len(entries) == 0 {
			// No more entries, check if continuous streaming
			if !req.Continuous {
				break
			}
			// Wait for notification of new entries
			select {
			case <-notifierCh:
				continue
			case <-stream.Context().Done():
				return stream.Context().Err()
			}
		}

		// Create batch message
		pbEntries := make([]*kevo.WALEntry, 0, len(entries))
		for _, entry := range entries {
			pbEntries = append(pbEntries, convertWALEntryToProto(entry))
		}

		pbBatch := &kevo.WALEntryBatch{
			Entries:  pbEntries,
			FirstLsn: entries[0].SequenceNumber,
			LastLsn:  entries[len(entries)-1].SequenceNumber,
			Count:    uint32(len(entries)),
		}

		// Send batch
		if err := stream.Send(pbBatch); err != nil {
			return err
		}

		// Update fromLSN for next batch
		fromLSN = entries[len(entries)-1].SequenceNumber + 1

		// If not continuous, break after sending all available entries
		if !req.Continuous && len(entries) < batchSize {
			break
		}
	}

	return nil
}

// ReportAppliedEntries handles reports of entries applied by replicas
func (s *ReplicationServiceServer) ReportAppliedEntries(
	ctx context.Context,
	req *kevo.ReportAppliedEntriesRequest,
) (*kevo.ReportAppliedEntriesResponse, error) {
	// Validate request
	if req.ReplicaId == "" {
		return nil, status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Update replica LSN
	s.replicasMutex.Lock()
	replica, exists := s.replicas[req.ReplicaId]
	if exists {
		replica.CurrentLSN = req.AppliedLsn
	}
	s.replicasMutex.Unlock()

	if !exists {
		return nil, status.Error(codes.NotFound, "replica not registered")
	}

	return &kevo.ReportAppliedEntriesResponse{
		Success:    true,
		PrimaryLsn: s.replicator.GetHighestTimestamp(),
	}, nil
}

// RequestBootstrap handles bootstrap requests from replicas
func (s *ReplicationServiceServer) RequestBootstrap(
	req *kevo.BootstrapRequest,
	stream kevo.ReplicationService_RequestBootstrapServer,
) error {
	// Validate request
	if req.ReplicaId == "" {
		return status.Error(codes.InvalidArgument, "replica_id is required")
	}

	// Check if replica is registered
	s.replicasMutex.RLock()
	replica, exists := s.replicas[req.ReplicaId]
	s.replicasMutex.RUnlock()

	if !exists {
		return status.Error(codes.NotFound, "replica not registered")
	}

	// Update replica status
	s.replicasMutex.Lock()
	replica.Status = transport.StatusBootstrapping
	s.replicasMutex.Unlock()

	// Create snapshot of current data
	snapshotLSN := s.replicator.GetHighestTimestamp()
	iterator, err := s.storageSnapshot.CreateSnapshotIterator()
	if err != nil {
		s.replicasMutex.Lock()
		replica.Status = transport.StatusError
		replica.Error = err
		s.replicasMutex.Unlock()
		return status.Errorf(codes.Internal, "failed to create snapshot: %v", err)
	}
	defer iterator.Close()

	// Stream key-value pairs in batches
	batchSize := 100 // Can be configurable
	totalCount := s.storageSnapshot.KeyCount()
	sentCount := 0
	batch := make([]*kevo.KeyValuePair, 0, batchSize)

	for {
		// Get next key-value pair
		key, value, err := iterator.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			s.replicasMutex.Lock()
			replica.Status = transport.StatusError
			replica.Error = err
			s.replicasMutex.Unlock()
			return status.Errorf(codes.Internal, "error reading snapshot: %v", err)
		}

		// Add to batch
		batch = append(batch, &kevo.KeyValuePair{
			Key:   key,
			Value: value,
		})

		// Send batch if full
		if len(batch) >= batchSize {
			progress := float32(sentCount) / float32(totalCount)
			if err := stream.Send(&kevo.BootstrapBatch{
				Pairs:       batch,
				Progress:    progress,
				IsLast:      false,
				SnapshotLsn: snapshotLSN,
			}); err != nil {
				return err
			}

			// Reset batch and update count
			sentCount += len(batch)
			batch = batch[:0]
		}
	}

	// Send final batch
	if len(batch) > 0 {
		sentCount += len(batch)
		progress := float32(sentCount) / float32(totalCount)
		if err := stream.Send(&kevo.BootstrapBatch{
			Pairs:       batch,
			Progress:    progress,
			IsLast:      true,
			SnapshotLsn: snapshotLSN,
		}); err != nil {
			return err
		}
	} else if sentCount > 0 {
		// Send empty final batch to mark the end
		if err := stream.Send(&kevo.BootstrapBatch{
			Pairs:       []*kevo.KeyValuePair{},
			Progress:    1.0,
			IsLast:      true,
			SnapshotLsn: snapshotLSN,
		}); err != nil {
			return err
		}
	}

	// Update replica status
	s.replicasMutex.Lock()
	replica.Status = transport.StatusSyncing
	replica.CurrentLSN = snapshotLSN
	s.replicasMutex.Unlock()

	return nil
}

// Helper to convert replica info to proto message
func convertReplicaInfoToProto(replica *transport.ReplicaInfo) *kevo.ReplicaInfo {
	// Convert status to proto enum
	var pbStatus kevo.ReplicaStatus
	switch replica.Status {
	case transport.StatusConnecting:
		pbStatus = kevo.ReplicaStatus_CONNECTING
	case transport.StatusSyncing:
		pbStatus = kevo.ReplicaStatus_SYNCING
	case transport.StatusBootstrapping:
		pbStatus = kevo.ReplicaStatus_BOOTSTRAPPING
	case transport.StatusReady:
		pbStatus = kevo.ReplicaStatus_READY
	case transport.StatusDisconnected:
		pbStatus = kevo.ReplicaStatus_DISCONNECTED
	case transport.StatusError:
		pbStatus = kevo.ReplicaStatus_ERROR
	default:
		pbStatus = kevo.ReplicaStatus_DISCONNECTED
	}

	// Convert role to proto enum
	var pbRole kevo.ReplicaRole
	switch replica.Role {
	case transport.RolePrimary:
		pbRole = kevo.ReplicaRole_PRIMARY
	case transport.RoleReplica:
		pbRole = kevo.ReplicaRole_REPLICA
	case transport.RoleReadOnly:
		pbRole = kevo.ReplicaRole_READ_ONLY
	default:
		pbRole = kevo.ReplicaRole_REPLICA
	}

	// Create proto message
	pbReplica := &kevo.ReplicaInfo{
		ReplicaId:        replica.ID,
		Address:          replica.Address,
		Role:             pbRole,
		Status:           pbStatus,
		LastSeenMs:       replica.LastSeen.UnixMilli(),
		CurrentLsn:       replica.CurrentLSN,
		ReplicationLagMs: replica.ReplicationLag.Milliseconds(),
	}

	// Add error message if any
	if replica.Error != nil {
		pbReplica.ErrorMessage = replica.Error.Error()
	}

	return pbReplica
}

// Convert WAL entry to proto message
func convertWALEntryToProto(entry *wal.Entry) *kevo.WALEntry {
	return &kevo.WALEntry{
		SequenceNumber: entry.SequenceNumber,
		Type:           uint32(entry.Type),
		Key:            entry.Key,
		Value:          entry.Value,
		// We'd normally calculate a checksum here
		Checksum: nil,
	}
}

// entryNotifier is a helper struct that implements replication.EntryProcessor
// to notify when new entries are available
type entryNotifier struct {
	notifyCh chan struct{}
	fromLSN  uint64
}

func (n *entryNotifier) ProcessEntry(entry *wal.Entry) error {
	if entry.SequenceNumber >= n.fromLSN {
		select {
		case n.notifyCh <- struct{}{}:
		default:
			// Channel already has a notification, no need to send another
		}
	}
	return nil
}

func (n *entryNotifier) ProcessBatch(entries []*wal.Entry) error {
	if len(entries) > 0 && entries[len(entries)-1].SequenceNumber >= n.fromLSN {
		select {
		case n.notifyCh <- struct{}{}:
		default:
			// Channel already has a notification, no need to send another
		}
	}
	return nil
}

// Define the interface for storage snapshot operations
// This would normally be implemented by the storage engine
type StorageSnapshot interface {
	// CreateSnapshotIterator creates an iterator for a storage snapshot
	CreateSnapshotIterator() (SnapshotIterator, error)

	// KeyCount returns the approximate number of keys in storage
	KeyCount() int64
}

// SnapshotIterator provides iteration over key-value pairs in storage
type SnapshotIterator interface {
	// Next returns the next key-value pair
	Next() (key []byte, value []byte, err error)

	// Close closes the iterator
	Close() error
}
pkg/grpc/transport/replication_client.go (new file, 494 lines)
@@ -0,0 +1,494 @@
package transport

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/KevoDB/kevo/pkg/replication"
	"github.com/KevoDB/kevo/pkg/transport"
	"github.com/KevoDB/kevo/pkg/wal"
	"github.com/KevoDB/kevo/proto/kevo"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
)

// ReplicationGRPCClient implements the ReplicationClient interface using gRPC
type ReplicationGRPCClient struct {
	conn              *grpc.ClientConn
	client            kevo.ReplicationServiceClient
	endpoint          string
	options           transport.TransportOptions
	replicaID         string
	status            transport.TransportStatus
	applier           *replication.WALApplier
	serializer        *replication.EntrySerializer
	highestAppliedLSN uint64
	currentLSN        uint64
	mu                sync.RWMutex
}

// NewReplicationGRPCClient creates a new ReplicationGRPCClient
func NewReplicationGRPCClient(
	endpoint string,
	options transport.TransportOptions,
	replicaID string,
	applier *replication.WALApplier,
	serializer *replication.EntrySerializer,
) (*ReplicationGRPCClient, error) {
	return &ReplicationGRPCClient{
		endpoint:   endpoint,
		options:    options,
		replicaID:  replicaID,
		applier:    applier,
		serializer: serializer,
		status: transport.TransportStatus{
			Connected:     false,
			LastConnected: time.Time{},
			LastError:     nil,
			BytesSent:     0,
			BytesReceived: 0,
			RTT:           0,
		},
	}, nil
}

// Connect establishes a connection to the server
func (c *ReplicationGRPCClient) Connect(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Set up connection options
	dialOptions := []grpc.DialOption{
		grpc.WithBlock(),
	}

	// Add TLS if configured - TODO: Add TLS support once TLS helpers are implemented
	if c.options.TLSEnabled {
		// We'll need to implement TLS credentials loading
		// For now, we'll skip TLS
		c.options.TLSEnabled = false
	} else {
		dialOptions = append(dialOptions, grpc.WithInsecure())
	}

	// Set timeout for connection
	dialCtx, cancel := context.WithTimeout(ctx, c.options.Timeout)
	defer cancel()

	// Establish connection
	conn, err := grpc.DialContext(dialCtx, c.endpoint, dialOptions...)
	if err != nil {
		c.status.LastError = err
		return err
	}

	c.conn = conn
	c.client = kevo.NewReplicationServiceClient(conn)
	c.status.Connected = true
	c.status.LastConnected = time.Now()

	return nil
}

// Close closes the connection
func (c *ReplicationGRPCClient) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn != nil {
		err := c.conn.Close()
		c.conn = nil
		c.client = nil
		c.status.Connected = false
		if err != nil {
			c.status.LastError = err
			return err
		}
	}

	return nil
}

// IsConnected returns whether the client is connected
func (c *ReplicationGRPCClient) IsConnected() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.conn == nil {
		return false
	}

	state := c.conn.GetState()
	return state == connectivity.Ready || state == connectivity.Idle
}

// Status returns the current status of the connection
func (c *ReplicationGRPCClient) Status() transport.TransportStatus {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.status
}

// Send sends a request and waits for a response
func (c *ReplicationGRPCClient) Send(ctx context.Context, request transport.Request) (transport.Response, error) {
	// Implementation depends on specific replication messages
	// This is a placeholder that would be completed for each message type
	return nil, transport.ErrInvalidRequest
}

// Stream opens a bidirectional stream
func (c *ReplicationGRPCClient) Stream(ctx context.Context) (transport.Stream, error) {
	// Not implemented for replication client
	return nil, transport.ErrInvalidRequest
}

// RegisterAsReplica registers this client as a replica with the primary
func (c *ReplicationGRPCClient) RegisterAsReplica(ctx context.Context, replicaID string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return transport.ErrNotConnected
	}

	// Create registration request
	req := &kevo.RegisterReplicaRequest{
		ReplicaId: replicaID,
		Address:   c.endpoint, // Use client endpoint as the replica address
		Role:      kevo.ReplicaRole_REPLICA,
	}

	// Call the service
	resp, err := c.client.RegisterReplica(ctx, req)
	if err != nil {
		c.status.LastError = err
		return err
	}

	// Update client info based on response
	c.replicaID = replicaID
	c.currentLSN = resp.CurrentLsn

	return nil
}

// SendHeartbeat sends a heartbeat to the primary
func (c *ReplicationGRPCClient) SendHeartbeat(ctx context.Context, info *transport.ReplicaInfo) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return transport.ErrNotConnected
	}

	// Convert status to proto enum
	var pbStatus kevo.ReplicaStatus
	switch info.Status {
	case transport.StatusConnecting:
		pbStatus = kevo.ReplicaStatus_CONNECTING
	case transport.StatusSyncing:
		pbStatus = kevo.ReplicaStatus_SYNCING
	case transport.StatusBootstrapping:
		pbStatus = kevo.ReplicaStatus_BOOTSTRAPPING
	case transport.StatusReady:
		pbStatus = kevo.ReplicaStatus_READY
	case transport.StatusDisconnected:
		pbStatus = kevo.ReplicaStatus_DISCONNECTED
	case transport.StatusError:
		pbStatus = kevo.ReplicaStatus_ERROR
	default:
		pbStatus = kevo.ReplicaStatus_DISCONNECTED
	}

	// Create heartbeat request
	req := &kevo.ReplicaHeartbeatRequest{
		ReplicaId:    c.replicaID,
		Status:       pbStatus,
		CurrentLsn:   c.highestAppliedLSN,
		ErrorMessage: "",
	}

	// Add error message if any
	if info.Error != nil {
		req.ErrorMessage = info.Error.Error()
	}

	// Call the service
	resp, err := c.client.ReplicaHeartbeat(ctx, req)
	if err != nil {
		c.status.LastError = err
		return err
	}

	// Update client info based on response
	c.currentLSN = resp.PrimaryLsn

	return nil
}

// RequestWALEntries requests WAL entries from the primary starting from a specific LSN
func (c *ReplicationGRPCClient) RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return nil, transport.ErrNotConnected
	}

	// Create request
	req := &kevo.GetWALEntriesRequest{
		ReplicaId:  c.replicaID,
		FromLsn:    fromLSN,
		MaxEntries: 1000, // Configurable
	}

	// Call the service
	resp, err := c.client.GetWALEntries(ctx, req)
	if err != nil {
		c.status.LastError = err
		return nil, err
	}

	// Convert proto entries to WAL entries
	entries := make([]*wal.Entry, 0, len(resp.Batch.Entries))
	for _, pbEntry := range resp.Batch.Entries {
		entry := &wal.Entry{
			SequenceNumber: pbEntry.SequenceNumber,
			Type:           uint8(pbEntry.Type),
			Key:            pbEntry.Key,
			Value:          pbEntry.Value,
		}
		entries = append(entries, entry)
	}

	return entries, nil
}

// RequestBootstrap requests a snapshot for bootstrap purposes
func (c *ReplicationGRPCClient) RequestBootstrap(ctx context.Context) (transport.BootstrapIterator, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return nil, transport.ErrNotConnected
	}

	// Create request
	req := &kevo.BootstrapRequest{
		ReplicaId: c.replicaID,
	}

	// Call the service
	stream, err := c.client.RequestBootstrap(ctx, req)
	if err != nil {
		c.status.LastError = err
		return nil, err
	}

	// Create and return bootstrap iterator
	return &GRPCBootstrapIterator{
		stream:     stream,
		totalPairs: 0,
		seenPairs:  0,
		progress:   0.0,
		mu:         &sync.Mutex{},
		applier:    c.applier,
	}, nil
}

// StartReplicationStream starts a stream for continuous replication
func (c *ReplicationGRPCClient) StartReplicationStream(ctx context.Context) error {
	if !c.IsConnected() {
		return transport.ErrNotConnected
	}

	// Get current highest applied LSN
	c.mu.Lock()
	fromLSN := c.highestAppliedLSN
	c.mu.Unlock()

	// Create stream request
	req := &kevo.StreamWALEntriesRequest{
		ReplicaId:  c.replicaID,
		FromLsn:    fromLSN,
		Continuous: true,
	}

	// Start streaming
	stream, err := c.client.StreamWALEntries(ctx, req)
	if err != nil {
		return err
	}

	// Process stream in a goroutine
	go c.processWALStream(ctx, stream)

	return nil
}

// processWALStream handles the incoming WAL entry stream
func (c *ReplicationGRPCClient) processWALStream(ctx context.Context, stream kevo.ReplicationService_StreamWALEntriesClient) {
	for {
		// Check if context is cancelled
		select {
		case <-ctx.Done():
			return
		default:
			// Continue processing
		}

		// Receive next batch
		batch, err := stream.Recv()
		if err == io.EOF {
			// Stream completed normally
			return
		}
		if err != nil {
			// Stream error
			c.mu.Lock()
			c.status.LastError = err
			c.mu.Unlock()
			return
		}

		// Process entries in batch
		entries := make([]*wal.Entry, 0, len(batch.Entries))
		for _, pbEntry := range batch.Entries {
			entry := &wal.Entry{
				SequenceNumber: pbEntry.SequenceNumber,
				Type:           uint8(pbEntry.Type),
				Key:            pbEntry.Key,
				Value:          pbEntry.Value,
			}
			entries = append(entries, entry)
		}

		// Apply entries
		if len(entries) > 0 {
			_, err = c.applier.ApplyBatch(entries)
			if err != nil {
				c.mu.Lock()
				c.status.LastError = err
				c.mu.Unlock()
				return
			}

			// Update highest applied LSN
			c.mu.Lock()
			c.highestAppliedLSN = batch.LastLsn
			c.mu.Unlock()

			// Report applied entries
			go c.reportAppliedEntries(context.Background(), batch.LastLsn)
		}
	}
}

// reportAppliedEntries reports the highest applied LSN to the primary
func (c *ReplicationGRPCClient) reportAppliedEntries(ctx context.Context, appliedLSN uint64) {
	c.mu.RLock()
	client := c.client
	replicaID := c.replicaID
	c.mu.RUnlock()

	if client == nil {
		return
	}

	// Create request
	req := &kevo.ReportAppliedEntriesRequest{
		ReplicaId:  replicaID,
		AppliedLsn: appliedLSN,
	}

	// Call the service
	_, err := client.ReportAppliedEntries(ctx, req)
	if err != nil {
		// Just log error, don't return it
		c.mu.Lock()
		c.status.LastError = err
		c.mu.Unlock()
	}
}

// GRPCBootstrapIterator implements the BootstrapIterator interface for gRPC
type GRPCBootstrapIterator struct {
	stream       kevo.ReplicationService_RequestBootstrapClient
	currentBatch *kevo.BootstrapBatch
	batchIndex   int
	totalPairs   int
	seenPairs    int
	progress     float64
	mu           *sync.Mutex
	applier      *replication.WALApplier
}

// Next returns the next key-value pair
func (it *GRPCBootstrapIterator) Next() ([]byte, []byte, error) {
	it.mu.Lock()
	defer it.mu.Unlock()

	// If we have a current batch and there are more pairs
	if it.currentBatch != nil && it.batchIndex < len(it.currentBatch.Pairs) {
		pair := it.currentBatch.Pairs[it.batchIndex]
		it.batchIndex++
		it.seenPairs++
		return pair.Key, pair.Value, nil
	}

	// Need to get a new batch
	batch, err := it.stream.Recv()
	if err == io.EOF {
		return nil, nil, io.EOF
	}
	if err != nil {
		return nil, nil, err
	}

	// Update progress
	it.currentBatch = batch
	it.batchIndex = 0
	it.progress = float64(batch.Progress)

	// If batch is empty and it's the last one
	if len(batch.Pairs) == 0 && batch.IsLast {
		// Store the snapshot LSN for later use
		if it.applier != nil {
			it.applier.ResetHighestApplied(batch.SnapshotLsn)
		}
		return nil, nil, io.EOF
	}

	// If batch is empty but not the last one, try again
	if len(batch.Pairs) == 0 {
		return it.Next()
	}

	// Return the first pair from the new batch
	pair := batch.Pairs[it.batchIndex]
	it.batchIndex++
	it.seenPairs++
	return pair.Key, pair.Value, nil
}

// Close closes the iterator
func (it *GRPCBootstrapIterator) Close() error {
	it.mu.Lock()
	defer it.mu.Unlock()

	// Store the snapshot LSN if we have a current batch and it's the last one
	if it.currentBatch != nil && it.currentBatch.IsLast && it.applier != nil {
		it.applier.ResetHighestApplied(it.currentBatch.SnapshotLsn)
	}

	return nil
}

// Progress returns the progress of the bootstrap operation (0.0-1.0)
func (it *GRPCBootstrapIterator) Progress() float64 {
	it.mu.Lock()
	defer it.mu.Unlock()
	return it.progress
}
pkg/grpc/transport/replication_server.go (new file, 200 lines)
@@ -0,0 +1,200 @@
package transport

import (
	"context"
	"fmt"
	"sync"

	"github.com/KevoDB/kevo/pkg/grpc/service"
	"github.com/KevoDB/kevo/pkg/replication"
	"github.com/KevoDB/kevo/pkg/transport"
)

// ReplicationGRPCServer implements the ReplicationServer interface using gRPC
type ReplicationGRPCServer struct {
	transportManager   *GRPCTransportManager
	replicationService *service.ReplicationServiceServer
	options            transport.TransportOptions
	replicas           map[string]*transport.ReplicaInfo
	mu                 sync.RWMutex
}

// NewReplicationGRPCServer creates a new ReplicationGRPCServer
func NewReplicationGRPCServer(
	transportManager *GRPCTransportManager,
	replicator *replication.WALReplicator,
	applier *replication.WALApplier,
	serializer *replication.EntrySerializer,
	storageSnapshot replication.StorageSnapshot,
	options transport.TransportOptions,
) (*ReplicationGRPCServer, error) {
	// Create replication service
	replicationService := service.NewReplicationService(
		replicator,
		applier,
		serializer,
		storageSnapshot,
	)

	return &ReplicationGRPCServer{
		transportManager:   transportManager,
		replicationService: replicationService,
		options:            options,
		replicas:           make(map[string]*transport.ReplicaInfo),
	}, nil
}

// Start starts the server and returns immediately
func (s *ReplicationGRPCServer) Start() error {
	// Register the replication service with the transport manager
	if err := s.transportManager.RegisterService(s.replicationService); err != nil {
		return fmt.Errorf("failed to register replication service: %w", err)
	}

	// Start the transport manager if it's not already started
	if err := s.transportManager.Start(); err != nil {
		return fmt.Errorf("failed to start transport manager: %w", err)
	}

	return nil
}

// Serve starts the server and blocks until it's stopped
func (s *ReplicationGRPCServer) Serve() error {
	if err := s.Start(); err != nil {
		return err
	}

	// This will block until the context is cancelled
	<-context.Background().Done()
	return nil
}

// Stop stops the server gracefully
func (s *ReplicationGRPCServer) Stop(ctx context.Context) error {
	return s.transportManager.Stop(ctx)
}

// SetRequestHandler sets the handler for incoming requests
// Not used for the replication server as it uses a dedicated service
func (s *ReplicationGRPCServer) SetRequestHandler(handler transport.RequestHandler) {
	// No-op for replication server
}

// RegisterReplica registers a new replica
func (s *ReplicationGRPCServer) RegisterReplica(replicaInfo *transport.ReplicaInfo) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.replicas[replicaInfo.ID] = replicaInfo
	return nil
}

// UpdateReplicaStatus updates the status of a replica
func (s *ReplicationGRPCServer) UpdateReplicaStatus(replicaID string, status transport.ReplicaStatus, lsn uint64) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	replica, exists := s.replicas[replicaID]
	if !exists {
		return fmt.Errorf("replica not found: %s", replicaID)
	}

	replica.Status = status
	replica.CurrentLSN = lsn
	return nil
}

// GetReplicaInfo returns information about a specific replica
func (s *ReplicationGRPCServer) GetReplicaInfo(replicaID string) (*transport.ReplicaInfo, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	replica, exists := s.replicas[replicaID]
	if !exists {
		return nil, fmt.Errorf("replica not found: %s", replicaID)
	}

	return replica, nil
}

// ListReplicas returns information about all connected replicas
func (s *ReplicationGRPCServer) ListReplicas() ([]*transport.ReplicaInfo, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]*transport.ReplicaInfo, 0, len(s.replicas))
	for _, replica := range s.replicas {
		result = append(result, replica)
	}

	return result, nil
}

// StreamWALEntriesToReplica streams WAL entries to a specific replica
func (s *ReplicationGRPCServer) StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error {
	// This is handled by the gRPC service directly
	return nil
}

// RegisterReplicationTransport registers the gRPC replication transport with the registry
func init() {
	// Register replication server factory
	transport.RegisterReplicationServerTransport("grpc", func(address string, options transport.TransportOptions) (transport.ReplicationServer, error) {
		// Create gRPC transport manager
		grpcOptions := &GRPCTransportOptions{
			ListenAddr:        address,
			ConnectionTimeout: options.Timeout,
			DialTimeout:       options.Timeout,
		}

		// Add TLS configuration if enabled
		if options.TLSEnabled {
			tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
			if err != nil {
				return nil, fmt.Errorf("failed to load TLS config: %w", err)
			}
			grpcOptions.TLSConfig = tlsConfig
		}

		// Create transport manager
		manager, err := NewGRPCTransportManager(grpcOptions)
		if err != nil {
			return nil, fmt.Errorf("failed to create gRPC transport manager: %w", err)
		}

		// For registration, we return a placeholder that will be properly initialized
		// by the caller with the required components
		return &ReplicationGRPCServer{
			transportManager: manager,
			options:          options,
			replicas:         make(map[string]*transport.ReplicaInfo),
		}, nil
	})

	// Register replication client factory
	transport.RegisterReplicationClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.ReplicationClient, error) {
		// For registration, we return a placeholder that will be properly initialized
		// by the caller with the required components
		return &ReplicationGRPCClient{
			endpoint: endpoint,
			options:  options,
		}, nil
	})
}

// WithReplicator adds a replicator to the replication server
func (s *ReplicationGRPCServer) WithReplicator(
	replicator *replication.WALReplicator,
	applier *replication.WALApplier,
	serializer *replication.EntrySerializer,
	storageSnapshot replication.StorageSnapshot,
) *ReplicationGRPCServer {
	s.replicationService = service.NewReplicationService(
		replicator,
		applier,
		serializer,
		storageSnapshot,
	)
	return s
}
pkg/replication/applier.go (new file, 268 lines)
@@ -0,0 +1,268 @@
package replication

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"

	"github.com/KevoDB/kevo/pkg/wal"
)

var (
	// ErrApplierClosed indicates the applier has been closed
	ErrApplierClosed = errors.New("applier is closed")

	// ErrOutOfOrderEntry indicates an entry was received out of order
	ErrOutOfOrderEntry = errors.New("out of order entry")

	// ErrInvalidEntryType indicates an unknown or invalid entry type
	ErrInvalidEntryType = errors.New("invalid entry type")

	// ErrStorageError indicates a storage-related error occurred
	ErrStorageError = errors.New("storage error")
)

// StorageInterface defines the minimum storage requirements for the WALApplier
type StorageInterface interface {
	Put(key, value []byte) error
	Delete(key []byte) error
}

// WALApplier applies WAL entries from a primary node to a replica's storage
type WALApplier struct {
	// storage is the local storage where entries will be applied
	storage StorageInterface

	// highestApplied is the highest Lamport timestamp of entries that have been applied
	highestApplied uint64

	// pendingEntries holds entries that were received out of order and are waiting for earlier entries
	pendingEntries map[uint64]*wal.Entry

	// appliedCount tracks the number of entries successfully applied
	appliedCount uint64

	// skippedCount tracks the number of entries skipped (already applied)
	skippedCount uint64

	// errorCount tracks the number of application errors
	errorCount uint64

	// closed indicates whether the applier is closed
	closed int32

	// concurrency management
	mu sync.RWMutex
}

// NewWALApplier creates a new WAL applier
func NewWALApplier(storage StorageInterface) *WALApplier {
	return &WALApplier{
		storage:        storage,
		highestApplied: 0,
		pendingEntries: make(map[uint64]*wal.Entry),
	}
}

// Apply applies a single WAL entry to the local storage
// Returns whether the entry was applied and any error
func (a *WALApplier) Apply(entry *wal.Entry) (bool, error) {
	if atomic.LoadInt32(&a.closed) == 1 {
		return false, ErrApplierClosed
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	// Check if this entry has already been applied
	if entry.SequenceNumber <= a.highestApplied {
		atomic.AddUint64(&a.skippedCount, 1)
		return false, nil
	}

	// Check if this is the next entry we're expecting
	if entry.SequenceNumber != a.highestApplied+1 {
		// Entry is out of order, store it for later application
		a.pendingEntries[entry.SequenceNumber] = entry
		return false, nil
	}

	// Apply this entry and any subsequent pending entries
	return a.applyEntryAndPending(entry)
}

// ApplyBatch applies a batch of WAL entries to the local storage
// Returns the number of entries applied and any error
func (a *WALApplier) ApplyBatch(entries []*wal.Entry) (int, error) {
	if atomic.LoadInt32(&a.closed) == 1 {
		return 0, ErrApplierClosed
	}

	if len(entries) == 0 {
		return 0, nil
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	// Sort entries by sequence number
	sortEntriesByTimestamp(entries)

	// Track how many entries we actually apply
	applied := 0

	// Process each entry
	for _, entry := range entries {
		// Skip already applied entries
		if entry.SequenceNumber <= a.highestApplied {
			atomic.AddUint64(&a.skippedCount, 1)
			continue
		}

		// Check if this is the next entry we're expecting
		if entry.SequenceNumber == a.highestApplied+1 {
			// Apply this entry and any subsequent pending entries
			wasApplied, err := a.applyEntryAndPending(entry)
			if err != nil {
				return applied, err
			}
			if wasApplied {
				applied++
			}
		} else {
			// Entry is out of order, store it for later application
			a.pendingEntries[entry.SequenceNumber] = entry
		}
	}

	return applied, nil
}

// GetHighestApplied returns the highest applied Lamport timestamp
func (a *WALApplier) GetHighestApplied() uint64 {
	a.mu.RLock()
	defer a.mu.RUnlock()
	return a.highestApplied
}

// GetStats returns statistics about the applier
func (a *WALApplier) GetStats() map[string]uint64 {
	a.mu.RLock()
	pendingCount := len(a.pendingEntries)
	highestApplied := a.highestApplied
	a.mu.RUnlock()

	return map[string]uint64{
		"appliedCount":   atomic.LoadUint64(&a.appliedCount),
		"skippedCount":   atomic.LoadUint64(&a.skippedCount),
		"errorCount":     atomic.LoadUint64(&a.errorCount),
		"pendingCount":   uint64(pendingCount),
		"highestApplied": highestApplied,
	}
}

// Close closes the applier
func (a *WALApplier) Close() error {
	if !atomic.CompareAndSwapInt32(&a.closed, 0, 1) {
		return nil // Already closed
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	// Clear any pending entries
	a.pendingEntries = make(map[uint64]*wal.Entry)

	return nil
}

// applyEntryAndPending applies the given entry and any pending entries that become applicable
// Caller must hold the lock
func (a *WALApplier) applyEntryAndPending(entry *wal.Entry) (bool, error) {
	// Apply the current entry
	if err := a.applyEntryToStorage(entry); err != nil {
		atomic.AddUint64(&a.errorCount, 1)
		return false, err
	}

	// Update highest applied timestamp
	a.highestApplied = entry.SequenceNumber
	atomic.AddUint64(&a.appliedCount, 1)

	// Check for pending entries that can now be applied
	nextSeq := a.highestApplied + 1
	for {
		nextEntry, exists := a.pendingEntries[nextSeq]
		if !exists {
			break
		}

		// Apply this pending entry
		if err := a.applyEntryToStorage(nextEntry); err != nil {
			atomic.AddUint64(&a.errorCount, 1)
			return true, err
		}

		// Update highest applied and remove from pending
		a.highestApplied = nextSeq
		delete(a.pendingEntries, nextSeq)
		atomic.AddUint64(&a.appliedCount, 1)

		// Look for the next one
		nextSeq++
	}

	return true, nil
}

// applyEntryToStorage applies a single WAL entry to the storage engine
// Caller must hold the lock
func (a *WALApplier) applyEntryToStorage(entry *wal.Entry) error {
	switch entry.Type {
	case wal.OpTypePut:
		if err := a.storage.Put(entry.Key, entry.Value); err != nil {
			return fmt.Errorf("%w: %v", ErrStorageError, err)
		}
	case wal.OpTypeDelete:
		if err := a.storage.Delete(entry.Key); err != nil {
			return fmt.Errorf("%w: %v", ErrStorageError, err)
		}
	case wal.OpTypeBatch:
		// In the WAL the batch entry itself doesn't contain the operations
		// Actual batch operations should be received as separate entries
		// with sequential sequence numbers
		return nil
	default:
		return fmt.Errorf("%w: type %d", ErrInvalidEntryType, entry.Type)
	}
	return nil
}

// PendingEntryCount returns the number of pending entries
func (a *WALApplier) PendingEntryCount() int {
	a.mu.RLock()
	defer a.mu.RUnlock()
	return len(a.pendingEntries)
}

// ResetHighestApplied allows manually setting a new highest applied value
// This should only be used during initialization or recovery
func (a *WALApplier) ResetHighestApplied(value uint64) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.highestApplied = value
}

// HasEntry checks if a specific entry timestamp has been applied or is pending
func (a *WALApplier) HasEntry(timestamp uint64) bool {
	a.mu.RLock()
	defer a.mu.RUnlock()

	if timestamp <= a.highestApplied {
		return true
	}

	_, exists := a.pendingEntries[timestamp]
	return exists
}
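A short usage sketch for the applier above (illustrative only, not part of the changeset): a replica would hold one long-lived WALApplier built with NewWALApplier around its local storage and feed it batches received from the primary. The function name and error handling here are assumptions; only NewWALApplier, ApplyBatch, and GetHighestApplied come from the file above.

	// applyReplicatedBatch feeds one batch received from the primary into a
	// long-lived WALApplier and reports the resulting highest applied LSN.
	// Out-of-order entries are parked internally until their predecessors arrive.
	func applyReplicatedBatch(applier *WALApplier, entries []*wal.Entry) (uint64, error) {
		if _, err := applier.ApplyBatch(entries); err != nil {
			return applier.GetHighestApplied(), err
		}
		return applier.GetHighestApplied(), nil
	}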
pkg/replication/applier_test.go (new file, 526 lines)
@@ -0,0 +1,526 @@
package replication

import (
	"errors"
	"sync"
	"testing"

	"github.com/KevoDB/kevo/pkg/common/iterator"
	"github.com/KevoDB/kevo/pkg/wal"
)

// MockStorage implements a simple mock storage for testing
type MockStorage struct {
	mu            sync.Mutex
	data          map[string][]byte
	putFail       bool
	deleteFail    bool
	putCount      int
	deleteCount   int
	lastPutKey    []byte
	lastPutValue  []byte
	lastDeleteKey []byte
}

func NewMockStorage() *MockStorage {
	return &MockStorage{
		data: make(map[string][]byte),
	}
}

func (m *MockStorage) Put(key, value []byte) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.putFail {
		return errors.New("simulated put failure")
	}

	m.putCount++
	m.lastPutKey = append([]byte{}, key...)
	m.lastPutValue = append([]byte{}, value...)
	m.data[string(key)] = append([]byte{}, value...)
	return nil
}

func (m *MockStorage) Get(key []byte) ([]byte, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	value, ok := m.data[string(key)]
	if !ok {
		return nil, errors.New("key not found")
	}
	return value, nil
}

func (m *MockStorage) Delete(key []byte) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.deleteFail {
		return errors.New("simulated delete failure")
	}

	m.deleteCount++
	m.lastDeleteKey = append([]byte{}, key...)
	delete(m.data, string(key))
	return nil
}

// Stub implementations for the rest of the interface
func (m *MockStorage) Close() error                                { return nil }
func (m *MockStorage) IsDeleted(key []byte) (bool, error)          { return false, nil }
func (m *MockStorage) GetIterator() (iterator.Iterator, error)     { return nil, nil }
func (m *MockStorage) GetRangeIterator(startKey, endKey []byte) (iterator.Iterator, error) { return nil, nil }
func (m *MockStorage) ApplyBatch(entries []*wal.Entry) error       { return nil }
func (m *MockStorage) FlushMemTables() error                       { return nil }
func (m *MockStorage) GetMemTableSize() uint64                     { return 0 }
func (m *MockStorage) IsFlushNeeded() bool                         { return false }
func (m *MockStorage) GetSSTables() []string                       { return nil }
func (m *MockStorage) ReloadSSTables() error                       { return nil }
func (m *MockStorage) RotateWAL() error                            { return nil }
func (m *MockStorage) GetStorageStats() map[string]interface{}     { return nil }

func TestWALApplierBasic(t *testing.T) {
	storage := NewMockStorage()
	applier := NewWALApplier(storage)
	defer applier.Close()

	// Create test entries
	entries := []*wal.Entry{
		{
			SequenceNumber: 1,
			Type:           wal.OpTypePut,
			Key:            []byte("key1"),
			Value:          []byte("value1"),
		},
		{
			SequenceNumber: 2,
			Type:           wal.OpTypePut,
			Key:            []byte("key2"),
			Value:          []byte("value2"),
		},
		{
			SequenceNumber: 3,
			Type:           wal.OpTypeDelete,
			Key:            []byte("key1"),
		},
	}

	// Apply entries one by one
	for i, entry := range entries {
		applied, err := applier.Apply(entry)
		if err != nil {
			t.Fatalf("Error applying entry %d: %v", i, err)
		}
		if !applied {
			t.Errorf("Entry %d should have been applied", i)
		}
	}

	// Check state
	if got := applier.GetHighestApplied(); got != 3 {
		t.Errorf("Expected highest applied 3, got %d", got)
	}

	// Check storage state
	if value, _ := storage.Get([]byte("key2")); string(value) != "value2" {
		t.Errorf("Expected key2=value2 in storage, got %q", value)
	}

	// key1 should be deleted
	if _, err := storage.Get([]byte("key1")); err == nil {
		t.Errorf("Expected key1 to be deleted")
	}

	// Check stats
	stats := applier.GetStats()
	if stats["appliedCount"] != 3 {
		t.Errorf("Expected appliedCount=3, got %d", stats["appliedCount"])
	}
	if stats["pendingCount"] != 0 {
		t.Errorf("Expected pendingCount=0, got %d", stats["pendingCount"])
	}
}

func TestWALApplierOutOfOrder(t *testing.T) {
	storage := NewMockStorage()
	applier := NewWALApplier(storage)
	defer applier.Close()

	// Apply entries out of order
	entries := []*wal.Entry{
		{
			SequenceNumber: 2,
			Type:           wal.OpTypePut,
			Key:            []byte("key2"),
			Value:          []byte("value2"),
		},
		{
			SequenceNumber: 3,
			Type:           wal.OpTypePut,
			Key:            []byte("key3"),
			Value:          []byte("value3"),
		},
		{
			SequenceNumber: 1,
			Type:           wal.OpTypePut,
			Key:            []byte("key1"),
			Value:          []byte("value1"),
		},
	}

	// Apply entry with sequence 2 - should be stored as pending
	applied, err := applier.Apply(entries[0])
	if err != nil {
		t.Fatalf("Error applying entry: %v", err)
	}
	if applied {
		t.Errorf("Entry with seq 2 should not have been applied yet")
	}

	// Apply entry with sequence 3 - should be stored as pending
	applied, err = applier.Apply(entries[1])
	if err != nil {
		t.Fatalf("Error applying entry: %v", err)
	}
	if applied {
		t.Errorf("Entry with seq 3 should not have been applied yet")
	}

	// Check pending count
	if pending := applier.PendingEntryCount(); pending != 2 {
		t.Errorf("Expected 2 pending entries, got %d", pending)
	}

	// Now apply entry with sequence 1 - should trigger all entries to be applied
	applied, err = applier.Apply(entries[2])
	if err != nil {
		t.Fatalf("Error applying entry: %v", err)
	}
	if !applied {
		t.Errorf("Entry with seq 1 should have been applied")
	}

	// Check state - all entries should be applied now
	if got := applier.GetHighestApplied(); got != 3 {
		t.Errorf("Expected highest applied 3, got %d", got)
	}

	// Pending count should be 0
	if pending := applier.PendingEntryCount(); pending != 0 {
		t.Errorf("Expected 0 pending entries, got %d", pending)
	}

	// Check storage contains all values
|
||||
values := []struct {
|
||||
key string
|
||||
value string
|
||||
}{
|
||||
{"key1", "value1"},
|
||||
{"key2", "value2"},
|
||||
{"key3", "value3"},
|
||||
}
|
||||
|
||||
for _, v := range values {
|
||||
if val, err := storage.Get([]byte(v.key)); err != nil || string(val) != v.value {
|
||||
t.Errorf("Expected %s=%s in storage, got %s, err=%v", v.key, v.value, val, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierBatch(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
// Create a batch of entries
|
||||
batch := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 3,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key3"),
|
||||
Value: []byte("value3"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 2,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key2"),
|
||||
Value: []byte("value2"),
|
||||
},
|
||||
}
|
||||
|
||||
// Apply batch - entries should be sorted by sequence number
|
||||
applied, err := applier.ApplyBatch(batch)
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying batch: %v", err)
|
||||
}
|
||||
|
||||
// All 3 entries should be applied
|
||||
if applied != 3 {
|
||||
t.Errorf("Expected 3 entries applied, got %d", applied)
|
||||
}
|
||||
|
||||
// Check highest applied
|
||||
if got := applier.GetHighestApplied(); got != 3 {
|
||||
t.Errorf("Expected highest applied 3, got %d", got)
|
||||
}
|
||||
|
||||
// Check all values in storage
|
||||
values := []struct {
|
||||
key string
|
||||
value string
|
||||
}{
|
||||
{"key1", "value1"},
|
||||
{"key2", "value2"},
|
||||
{"key3", "value3"},
|
||||
}
|
||||
|
||||
for _, v := range values {
|
||||
if val, err := storage.Get([]byte(v.key)); err != nil || string(val) != v.value {
|
||||
t.Errorf("Expected %s=%s in storage, got %s, err=%v", v.key, v.value, val, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierAlreadyApplied(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
// Apply an entry
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
applied, err := applier.Apply(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("Entry should have been applied")
|
||||
}
|
||||
|
||||
// Try to apply the same entry again
|
||||
applied, err = applier.Apply(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if applied {
|
||||
t.Errorf("Entry should not have been applied a second time")
|
||||
}
|
||||
|
||||
// Check stats
|
||||
stats := applier.GetStats()
|
||||
if stats["appliedCount"] != 1 {
|
||||
t.Errorf("Expected appliedCount=1, got %d", stats["appliedCount"])
|
||||
}
|
||||
if stats["skippedCount"] != 1 {
|
||||
t.Errorf("Expected skippedCount=1, got %d", stats["skippedCount"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierError(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
storage.putFail = true
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
// Apply should return an error
|
||||
_, err := applier.Apply(entry)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Apply, got nil")
|
||||
}
|
||||
|
||||
// Check error count
|
||||
stats := applier.GetStats()
|
||||
if stats["errorCount"] != 1 {
|
||||
t.Errorf("Expected errorCount=1, got %d", stats["errorCount"])
|
||||
}
|
||||
|
||||
// Fix storage and try again
|
||||
storage.putFail = false
|
||||
|
||||
// Apply should succeed
|
||||
applied, err := applier.Apply(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("Entry should have been applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierInvalidType(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: 99, // Invalid type
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
// Apply should return an error
|
||||
_, err := applier.Apply(entry)
|
||||
if err == nil || !errors.Is(err, ErrInvalidEntryType) {
|
||||
t.Errorf("Expected invalid entry type error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierClose(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
|
||||
// Apply an entry
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
applied, err := applier.Apply(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("Entry should have been applied")
|
||||
}
|
||||
|
||||
// Close the applier
|
||||
if err := applier.Close(); err != nil {
|
||||
t.Fatalf("Error closing applier: %v", err)
|
||||
}
|
||||
|
||||
// Try to apply another entry
|
||||
_, err = applier.Apply(&wal.Entry{
|
||||
SequenceNumber: 2,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key2"),
|
||||
Value: []byte("value2"),
|
||||
})
|
||||
|
||||
if err == nil || !errors.Is(err, ErrApplierClosed) {
|
||||
t.Errorf("Expected applier closed error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierResetHighest(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
// Manually set the highest applied to 10
|
||||
applier.ResetHighestApplied(10)
|
||||
|
||||
// Check value
|
||||
if got := applier.GetHighestApplied(); got != 10 {
|
||||
t.Errorf("Expected highest applied 10, got %d", got)
|
||||
}
|
||||
|
||||
// Try to apply an entry with sequence 10
|
||||
applied, err := applier.Apply(&wal.Entry{
|
||||
SequenceNumber: 10,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key10"),
|
||||
Value: []byte("value10"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if applied {
|
||||
t.Errorf("Entry with seq 10 should have been skipped")
|
||||
}
|
||||
|
||||
// Apply an entry with sequence 11
|
||||
applied, err = applier.Apply(&wal.Entry{
|
||||
SequenceNumber: 11,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key11"),
|
||||
Value: []byte("value11"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("Entry with seq 11 should have been applied")
|
||||
}
|
||||
|
||||
// Check new highest
|
||||
if got := applier.GetHighestApplied(); got != 11 {
|
||||
t.Errorf("Expected highest applied 11, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALApplierHasEntry(t *testing.T) {
|
||||
storage := NewMockStorage()
|
||||
applier := NewWALApplier(storage)
|
||||
defer applier.Close()
|
||||
|
||||
// Apply an entry with sequence 1
|
||||
applied, err := applier.Apply(&wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("Entry should have been applied")
|
||||
}
|
||||
|
||||
// Add a pending entry with sequence 3
|
||||
_, err = applier.Apply(&wal.Entry{
|
||||
SequenceNumber: 3,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key3"),
|
||||
Value: []byte("value3"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error applying entry: %v", err)
|
||||
}
|
||||
|
||||
// Check has entry
|
||||
testCases := []struct {
|
||||
timestamp uint64
|
||||
expected bool
|
||||
}{
|
||||
{0, true},
|
||||
{1, true},
|
||||
{2, false},
|
||||
{3, true},
|
||||
{4, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if got := applier.HasEntry(tc.timestamp); got != tc.expected {
|
||||
t.Errorf("HasEntry(%d) = %v, want %v", tc.timestamp, got, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
@@ -42,4 +42,4 @@ func (c *LamportClock) Current() uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.counter
}

@@ -7,12 +7,12 @@ import (

func TestLamportClockTick(t *testing.T) {
	clock := NewLamportClock()

	// Initial tick should return 1
	if ts := clock.Tick(); ts != 1 {
		t.Errorf("First tick should return 1, got %d", ts)
	}

	// Second tick should return 2
	if ts := clock.Tick(); ts != 2 {
		t.Errorf("Second tick should return 2, got %d", ts)

@@ -21,25 +21,25 @@ func TestLamportClockTick(t *testing.T) {

func TestLamportClockUpdate(t *testing.T) {
	clock := NewLamportClock()

	// Update with lower value should increment
	ts := clock.Update(0)
	if ts != 1 {
		t.Errorf("Update with lower value should return 1, got %d", ts)
	}

	// Update with same value should increment
	ts = clock.Update(1)
	if ts != 2 {
		t.Errorf("Update with same value should return 2, got %d", ts)
	}

	// Update with higher value should use that value and increment
	ts = clock.Update(10)
	if ts != 11 {
		t.Errorf("Update with higher value should return 11, got %d", ts)
	}

	// Subsequent tick should continue from updated value
	ts = clock.Tick()
	if ts != 12 {

@@ -49,18 +49,18 @@ func TestLamportClockUpdate(t *testing.T) {

func TestLamportClockCurrent(t *testing.T) {
	clock := NewLamportClock()

	// Initial current should be 0
	if ts := clock.Current(); ts != 0 {
		t.Errorf("Initial current should be 0, got %d", ts)
	}

	// After tick, current should reflect new value
	clock.Tick()
	if ts := clock.Current(); ts != 1 {
		t.Errorf("Current after tick should be 1, got %d", ts)
	}

	// Current should not increment the clock
	if ts := clock.Current(); ts != 1 {
		t.Errorf("Multiple calls to Current should return same value, got %d", ts)

@@ -71,7 +71,7 @@ func TestLamportClockConcurrency(t *testing.T) {
	clock := NewLamportClock()
	iterations := 1000
	var wg sync.WaitGroup

	// Run multiple goroutines calling Tick concurrently
	wg.Add(iterations)
	for i := 0; i < iterations; i++ {

@@ -81,10 +81,10 @@ func TestLamportClockConcurrency(t *testing.T) {
		}()
	}
	wg.Wait()

	// After iterations concurrent ticks, value should be iterations
	if ts := clock.Current(); ts != uint64(iterations) {
		t.Errorf("After %d concurrent ticks, expected value %d, got %d",
			iterations, iterations, ts)
	}
}
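The tests above pin down the Lamport clock's contract (Tick, Update, Current). A small sketch of how a node might use it when exchanging replication messages; only the three methods exercised by the tests are real, the surrounding wiring is illustrative:

	clock := NewLamportClock()

	local := clock.Tick()      // stamp a locally generated event (returns 1)
	merged := clock.Update(41) // saw a remote timestamp 41: advances to max(local, 41)+1 = 42
	now := clock.Current()     // read without advancing (still 42)
	_, _, _ = local, merged, now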
373
pkg/replication/replicator.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package replication
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/KevoDB/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrReplicatorClosed indicates the replicator has been closed and no longer accepts entries
|
||||
ErrReplicatorClosed = errors.New("replicator is closed")
|
||||
|
||||
// ErrReplicatorFull indicates the replicator's entry buffer is full
|
||||
ErrReplicatorFull = errors.New("replicator entry buffer is full")
|
||||
|
||||
// ErrInvalidPosition indicates an invalid replication position was provided
|
||||
ErrInvalidPosition = errors.New("invalid replication position")
|
||||
)
|
||||
|
||||
// EntryProcessor is an interface for components that process WAL entries for replication
|
||||
type EntryProcessor interface {
|
||||
// ProcessEntry processes a single WAL entry
|
||||
ProcessEntry(entry *wal.Entry) error
|
||||
|
||||
// ProcessBatch processes a batch of WAL entries
|
||||
ProcessBatch(entries []*wal.Entry) error
|
||||
}
|
||||
|
||||
// ReplicationPosition represents a position in the replication stream
|
||||
type ReplicationPosition struct {
|
||||
// Timestamp is the Lamport timestamp of the position
|
||||
Timestamp uint64
|
||||
}
|
||||
|
||||
// WALReplicator captures WAL entries and makes them available for replication
|
||||
type WALReplicator struct {
|
||||
// Entries is a map of timestamp -> entry for all captured entries
|
||||
entries map[uint64]*wal.Entry
|
||||
|
||||
// Batches is a map of batch start timestamp -> batch entries
|
||||
batches map[uint64][]*wal.Entry
|
||||
|
||||
// EntryChannel is a channel of captured entries for subscribers
|
||||
entryChannel chan *wal.Entry
|
||||
|
||||
// BatchChannel is a channel of captured batches for subscribers
|
||||
batchChannel chan []*wal.Entry
|
||||
|
||||
// Highest timestamp seen so far
|
||||
highestTimestamp uint64
|
||||
|
||||
// MaxBufferedEntries is the maximum number of entries to buffer
|
||||
maxBufferedEntries int
|
||||
|
||||
// Concurrency control
|
||||
mu sync.RWMutex
|
||||
|
||||
// Closed indicates if the replicator is closed
|
||||
closed int32
|
||||
|
||||
// EntryProcessors are components that process entries as they're captured
|
||||
processors []EntryProcessor
|
||||
}
|
||||
|
||||
// NewWALReplicator creates a new WAL replicator
|
||||
func NewWALReplicator(maxBufferedEntries int) *WALReplicator {
|
||||
if maxBufferedEntries <= 0 {
|
||||
maxBufferedEntries = 10000 // Default to 10,000 entries
|
||||
}
|
||||
|
||||
return &WALReplicator{
|
||||
entries: make(map[uint64]*wal.Entry),
|
||||
batches: make(map[uint64][]*wal.Entry),
|
||||
entryChannel: make(chan *wal.Entry, 1000),
|
||||
batchChannel: make(chan []*wal.Entry, 100),
|
||||
maxBufferedEntries: maxBufferedEntries,
|
||||
processors: make([]EntryProcessor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// OnEntryWritten implements the wal.ReplicationHook interface
|
||||
func (r *WALReplicator) OnEntryWritten(entry *wal.Entry) {
|
||||
if atomic.LoadInt32(&r.closed) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
|
||||
// Update highest timestamp
|
||||
if entry.SequenceNumber > r.highestTimestamp {
|
||||
r.highestTimestamp = entry.SequenceNumber
|
||||
}
|
||||
|
||||
// Store the entry (make a copy to avoid potential mutation)
|
||||
entryCopy := &wal.Entry{
|
||||
SequenceNumber: entry.SequenceNumber,
|
||||
Type: entry.Type,
|
||||
Key: append([]byte{}, entry.Key...),
|
||||
}
|
||||
if entry.Value != nil {
|
||||
entryCopy.Value = append([]byte{}, entry.Value...)
|
||||
}
|
||||
|
||||
r.entries[entryCopy.SequenceNumber] = entryCopy
|
||||
|
||||
// Cleanup old entries if we exceed the buffer size
|
||||
if len(r.entries) > r.maxBufferedEntries {
|
||||
r.cleanupOldestEntries(r.maxBufferedEntries / 10) // Remove ~10% of entries
|
||||
}
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
// Send to channel (non-blocking)
|
||||
select {
|
||||
case r.entryChannel <- entryCopy:
|
||||
// Successfully sent
|
||||
default:
|
||||
// Channel full, skip sending but entry is still stored
|
||||
}
|
||||
|
||||
// Process the entry
|
||||
r.processEntry(entryCopy)
|
||||
}
|
||||
|
||||
// OnBatchWritten implements the wal.ReplicationHook interface
|
||||
func (r *WALReplicator) OnBatchWritten(entries []*wal.Entry) {
|
||||
if atomic.LoadInt32(&r.closed) == 1 || len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
|
||||
// Make copies to avoid potential mutation
|
||||
entriesCopy := make([]*wal.Entry, len(entries))
|
||||
batchTimestamp := entries[0].SequenceNumber
|
||||
|
||||
for i, entry := range entries {
|
||||
entriesCopy[i] = &wal.Entry{
|
||||
SequenceNumber: entry.SequenceNumber,
|
||||
Type: entry.Type,
|
||||
Key: append([]byte{}, entry.Key...),
|
||||
}
|
||||
if entry.Value != nil {
|
||||
entriesCopy[i].Value = append([]byte{}, entry.Value...)
|
||||
}
|
||||
|
||||
// Store individual entry
|
||||
r.entries[entriesCopy[i].SequenceNumber] = entriesCopy[i]
|
||||
|
||||
// Update highest timestamp
|
||||
if entry.SequenceNumber > r.highestTimestamp {
|
||||
r.highestTimestamp = entry.SequenceNumber
|
||||
}
|
||||
}
|
||||
|
||||
// Store the batch
|
||||
r.batches[batchTimestamp] = entriesCopy
|
||||
|
||||
// Cleanup old entries if we exceed the buffer size
|
||||
if len(r.entries) > r.maxBufferedEntries {
|
||||
r.cleanupOldestEntries(r.maxBufferedEntries / 10)
|
||||
}
|
||||
|
||||
// Cleanup old batches if we have too many
|
||||
if len(r.batches) > r.maxBufferedEntries/10 {
|
||||
r.cleanupOldestBatches(r.maxBufferedEntries / 100)
|
||||
}
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
// Send to batch channel (non-blocking)
|
||||
select {
|
||||
case r.batchChannel <- entriesCopy:
|
||||
// Successfully sent
|
||||
default:
|
||||
// Channel full, skip sending but entries are still stored
|
||||
}
|
||||
|
||||
// Process the batch
|
||||
r.processBatch(entriesCopy)
|
||||
}
|
||||
|
||||
// GetHighestTimestamp returns the highest timestamp seen so far
|
||||
func (r *WALReplicator) GetHighestTimestamp() uint64 {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.highestTimestamp
|
||||
}
|
||||
|
||||
// GetEntriesAfter returns all entries with timestamps greater than the given position
|
||||
func (r *WALReplicator) GetEntriesAfter(position ReplicationPosition) ([]*wal.Entry, error) {
|
||||
if atomic.LoadInt32(&r.closed) == 1 {
|
||||
return nil, ErrReplicatorClosed
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Create a result slice with appropriate capacity
|
||||
result := make([]*wal.Entry, 0, min(100, len(r.entries)))
|
||||
|
||||
// Find all entries with timestamps greater than the position
|
||||
for timestamp, entry := range r.entries {
|
||||
if timestamp > position.Timestamp {
|
||||
result = append(result, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the entries by timestamp
|
||||
sortEntriesByTimestamp(result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetEntryCount returns the number of entries currently stored
|
||||
func (r *WALReplicator) GetEntryCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.entries)
|
||||
}
|
||||
|
||||
// GetBatchCount returns the number of batches currently stored
|
||||
func (r *WALReplicator) GetBatchCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.batches)
|
||||
}
|
||||
|
||||
// SubscribeToEntries returns a channel that receives entries as they're captured
|
||||
func (r *WALReplicator) SubscribeToEntries() <-chan *wal.Entry {
|
||||
return r.entryChannel
|
||||
}
|
||||
|
||||
// SubscribeToBatches returns a channel that receives batches as they're captured
|
||||
func (r *WALReplicator) SubscribeToBatches() <-chan []*wal.Entry {
|
||||
return r.batchChannel
|
||||
}
|
||||
|
||||
// AddProcessor adds an EntryProcessor to receive entries as they're captured
|
||||
func (r *WALReplicator) AddProcessor(processor EntryProcessor) {
|
||||
if atomic.LoadInt32(&r.closed) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.processors = append(r.processors, processor)
|
||||
}
|
||||
|
||||
// Close closes the replicator and its channels
|
||||
func (r *WALReplicator) Close() error {
|
||||
// Set closed flag
|
||||
if !atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
// Close channels
|
||||
close(r.entryChannel)
|
||||
close(r.batchChannel)
|
||||
|
||||
// Clear entries and batches
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.entries = make(map[uint64]*wal.Entry)
|
||||
r.batches = make(map[uint64][]*wal.Entry)
|
||||
r.processors = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldestEntries removes the oldest entries from the buffer
|
||||
func (r *WALReplicator) cleanupOldestEntries(count int) {
|
||||
// Find the oldest timestamps
|
||||
oldestTimestamps := findOldestTimestamps(r.entries, count)
|
||||
|
||||
// Remove the oldest entries
|
||||
for _, ts := range oldestTimestamps {
|
||||
delete(r.entries, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOldestBatches removes the oldest batches from the buffer
|
||||
func (r *WALReplicator) cleanupOldestBatches(count int) {
|
||||
// Find the oldest timestamps
|
||||
oldestTimestamps := findOldestTimestamps(r.batches, count)
|
||||
|
||||
// Remove the oldest batches
|
||||
for _, ts := range oldestTimestamps {
|
||||
delete(r.batches, ts)
|
||||
}
|
||||
}
|
||||
|
||||
// processEntry sends the entry to all registered processors
|
||||
func (r *WALReplicator) processEntry(entry *wal.Entry) {
|
||||
r.mu.RLock()
|
||||
processors := r.processors
|
||||
r.mu.RUnlock()
|
||||
|
||||
for _, processor := range processors {
|
||||
_ = processor.ProcessEntry(entry) // Ignore errors for now
|
||||
}
|
||||
}
|
||||
|
||||
// processBatch sends the batch to all registered processors
|
||||
func (r *WALReplicator) processBatch(entries []*wal.Entry) {
|
||||
r.mu.RLock()
|
||||
processors := r.processors
|
||||
r.mu.RUnlock()
|
||||
|
||||
for _, processor := range processors {
|
||||
_ = processor.ProcessBatch(entries) // Ignore errors for now
|
||||
}
|
||||
}
|
||||
|
||||
// findOldestTimestamps finds the n oldest timestamps in a map
|
||||
func findOldestTimestamps[T any](m map[uint64]T, n int) []uint64 {
|
||||
if len(m) <= n {
|
||||
// If we don't have enough entries, return all timestamps
|
||||
result := make([]uint64, 0, len(m))
|
||||
for ts := range m {
|
||||
result = append(result, ts)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Find the n smallest timestamps
|
||||
result := make([]uint64, 0, n)
|
||||
for ts := range m {
|
||||
if len(result) < n {
|
||||
// Add to result if we don't have enough yet
|
||||
result = append(result, ts)
|
||||
} else {
|
||||
// Find the largest timestamp in our result
|
||||
largestIdx := 0
|
||||
for i, t := range result {
|
||||
if t > result[largestIdx] {
|
||||
largestIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the largest with this one if it's smaller
|
||||
if ts < result[largestIdx] {
|
||||
result[largestIdx] = ts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// sortEntriesByTimestamp sorts a slice of entries by their timestamps
func sortEntriesByTimestamp(entries []*wal.Entry) {
	// Simple insertion sort for small slices
	for i := 1; i < len(entries); i++ {
		j := i
		for j > 0 && entries[j-1].SequenceNumber > entries[j].SequenceNumber {
			entries[j], entries[j-1] = entries[j-1], entries[j]
			j--
		}
	}
}

// min returns the smaller of two integers
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
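A rough sketch of how the replicator is meant to sit between the WAL and downstream consumers. Only the replicator methods defined above are real; how the WAL registers the hook is not part of this diff and is left abstract here:

	func exampleReplicatorWiring() {
		replicator := NewWALReplicator(10000)
		defer replicator.Close()

		// The replicator implements wal.ReplicationHook, so a WAL configured with it
		// calls OnEntryWritten/OnBatchWritten on every write (registration not shown).

		// Push-style consumption of captured entries.
		go func() {
			for entry := range replicator.SubscribeToEntries() {
				_ = entry // e.g. forward to a follower
			}
		}()

		// Pull-style catch-up from a known replication position.
		if entries, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 0}); err == nil {
			_ = entries
		}
	}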
46
pkg/replication/replicator_ext.go
Normal file
@@ -0,0 +1,46 @@
package replication

// No imports needed

// processorIndex finds the index of a processor in the processors slice
// Returns -1 if not found
func (r *WALReplicator) processorIndex(target EntryProcessor) int {
	r.mu.RLock()
	defer r.mu.RUnlock()

	for i, p := range r.processors {
		if p == target {
			return i
		}
	}
	return -1
}

// RemoveProcessor removes an EntryProcessor from the replicator
func (r *WALReplicator) RemoveProcessor(processor EntryProcessor) {
	if processor == nil {
		return
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Find the processor in the slice
	idx := -1
	for i, p := range r.processors {
		if p == processor {
			idx = i
			break
		}
	}

	// If found, remove it by swapping in the last element and truncating
	if idx >= 0 {
		lastIdx := len(r.processors) - 1
		if idx < lastIdx {
			r.processors[idx] = r.processors[lastIdx]
		}
		r.processors = r.processors[:lastIdx]
	}
}
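Note that RemoveProcessor swaps the last processor into the vacated slot, so removal is O(1) but registration order is not preserved. If ordering ever mattered, an order-preserving removal (illustrative alternative, not part of this change) would be:

	r.processors = append(r.processors[:idx], r.processors[idx+1:]...)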
401
pkg/replication/replicator_test.go
Normal file
@@ -0,0 +1,401 @@
|
||||
package replication
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/KevoDB/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
// MockEntryProcessor implements the EntryProcessor interface for testing
|
||||
type MockEntryProcessor struct {
|
||||
mu sync.Mutex
|
||||
processedEntries []*wal.Entry
|
||||
processedBatches [][]*wal.Entry
|
||||
entriesProcessed int
|
||||
batchesProcessed int
|
||||
failProcessEntry bool
|
||||
failProcessBatch bool
|
||||
}
|
||||
|
||||
func (m *MockEntryProcessor) ProcessEntry(entry *wal.Entry) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.processedEntries = append(m.processedEntries, entry)
|
||||
m.entriesProcessed++
|
||||
|
||||
if m.failProcessEntry {
|
||||
return ErrReplicatorClosed // Just use an existing error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockEntryProcessor) ProcessBatch(entries []*wal.Entry) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.processedBatches = append(m.processedBatches, entries)
|
||||
m.batchesProcessed++
|
||||
|
||||
if m.failProcessBatch {
|
||||
return ErrReplicatorClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockEntryProcessor) GetStats() (int, int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.entriesProcessed, m.batchesProcessed
|
||||
}
|
||||
|
||||
func TestWALReplicatorBasic(t *testing.T) {
|
||||
replicator := NewWALReplicator(1000)
|
||||
defer replicator.Close()
|
||||
|
||||
// Create some test entries
|
||||
entry1 := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
entry2 := &wal.Entry{
|
||||
SequenceNumber: 2,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key2"),
|
||||
Value: []byte("value2"),
|
||||
}
|
||||
|
||||
// Process some entries
|
||||
replicator.OnEntryWritten(entry1)
|
||||
replicator.OnEntryWritten(entry2)
|
||||
|
||||
// Check entry count
|
||||
if count := replicator.GetEntryCount(); count != 2 {
|
||||
t.Errorf("Expected 2 entries, got %d", count)
|
||||
}
|
||||
|
||||
// Check highest timestamp
|
||||
if ts := replicator.GetHighestTimestamp(); ts != 2 {
|
||||
t.Errorf("Expected highest timestamp 2, got %d", ts)
|
||||
}
|
||||
|
||||
// Get entries after timestamp 0
|
||||
entries, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting entries: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("Expected 2 entries after timestamp 0, got %d", len(entries))
|
||||
}
|
||||
|
||||
// Check entries are sorted by timestamp
|
||||
if entries[0].SequenceNumber != 1 || entries[1].SequenceNumber != 2 {
|
||||
t.Errorf("Entries not sorted by timestamp")
|
||||
}
|
||||
|
||||
// Get entries after timestamp 1
|
||||
entries, err = replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 1})
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting entries: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("Expected 1 entry after timestamp 1, got %d", len(entries))
|
||||
}
|
||||
|
||||
if entries[0].SequenceNumber != 2 {
|
||||
t.Errorf("Expected entry with timestamp 2, got %d", entries[0].SequenceNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALReplicatorBatches(t *testing.T) {
|
||||
replicator := NewWALReplicator(1000)
|
||||
defer replicator.Close()
|
||||
|
||||
// Create a batch of entries
|
||||
entries := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 10,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 11,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key2"),
|
||||
Value: []byte("value2"),
|
||||
},
|
||||
}
|
||||
|
||||
// Process the batch
|
||||
replicator.OnBatchWritten(entries)
|
||||
|
||||
// Check entry count
|
||||
if count := replicator.GetEntryCount(); count != 2 {
|
||||
t.Errorf("Expected 2 entries, got %d", count)
|
||||
}
|
||||
|
||||
// Check batch count
|
||||
if count := replicator.GetBatchCount(); count != 1 {
|
||||
t.Errorf("Expected 1 batch, got %d", count)
|
||||
}
|
||||
|
||||
// Check highest timestamp
|
||||
if ts := replicator.GetHighestTimestamp(); ts != 11 {
|
||||
t.Errorf("Expected highest timestamp 11, got %d", ts)
|
||||
}
|
||||
|
||||
// Get entries after timestamp 9
|
||||
result, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 9})
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting entries: %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Expected 2 entries after timestamp 9, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALReplicatorProcessors(t *testing.T) {
|
||||
replicator := NewWALReplicator(1000)
|
||||
defer replicator.Close()
|
||||
|
||||
// Create a processor
|
||||
processor := &MockEntryProcessor{}
|
||||
|
||||
// Add the processor
|
||||
replicator.AddProcessor(processor)
|
||||
|
||||
// Create an entry and a batch
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
batch := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 10,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key10"),
|
||||
Value: []byte("value10"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 11,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key11"),
|
||||
Value: []byte("value11"),
|
||||
},
|
||||
}
|
||||
|
||||
// Process the entry and batch
|
||||
replicator.OnEntryWritten(entry)
|
||||
replicator.OnBatchWritten(batch)
|
||||
|
||||
// Check processor stats
|
||||
entriesProcessed, batchesProcessed := processor.GetStats()
|
||||
if entriesProcessed != 1 {
|
||||
t.Errorf("Expected 1 entry processed, got %d", entriesProcessed)
|
||||
}
|
||||
if batchesProcessed != 1 {
|
||||
t.Errorf("Expected 1 batch processed, got %d", batchesProcessed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALReplicatorSubscribe(t *testing.T) {
|
||||
replicator := NewWALReplicator(1000)
|
||||
defer replicator.Close()
|
||||
|
||||
// Subscribe to entries and batches
|
||||
entryChannel := replicator.SubscribeToEntries()
|
||||
batchChannel := replicator.SubscribeToBatches()
|
||||
|
||||
// Create an entry and a batch
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
batch := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 10,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key10"),
|
||||
Value: []byte("value10"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 11,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key11"),
|
||||
Value: []byte("value11"),
|
||||
},
|
||||
}
|
||||
|
||||
// Create channels to receive the results
|
||||
entryReceived := make(chan *wal.Entry, 1)
|
||||
batchReceived := make(chan []*wal.Entry, 1)
|
||||
|
||||
// Start goroutines to receive from the channels
|
||||
go func() {
|
||||
select {
|
||||
case e := <-entryChannel:
|
||||
entryReceived <- e
|
||||
case <-time.After(time.Second):
|
||||
close(entryReceived)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case b := <-batchChannel:
|
||||
batchReceived <- b
|
||||
case <-time.After(time.Second):
|
||||
close(batchReceived)
|
||||
}
|
||||
}()
|
||||
|
||||
// Process the entry and batch
|
||||
replicator.OnEntryWritten(entry)
|
||||
replicator.OnBatchWritten(batch)
|
||||
|
||||
// Check that we received the entry
|
||||
select {
|
||||
case receivedEntry := <-entryReceived:
|
||||
if receivedEntry.SequenceNumber != 1 {
|
||||
t.Errorf("Expected entry with timestamp 1, got %d", receivedEntry.SequenceNumber)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Errorf("Timeout waiting for entry")
|
||||
}
|
||||
|
||||
// Check that we received the batch
|
||||
select {
|
||||
case receivedBatch := <-batchReceived:
|
||||
if len(receivedBatch) != 2 {
|
||||
t.Errorf("Expected batch with 2 entries, got %d", len(receivedBatch))
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Errorf("Timeout waiting for batch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALReplicatorCleanup(t *testing.T) {
|
||||
// Create a replicator with a small buffer
|
||||
replicator := NewWALReplicator(10)
|
||||
defer replicator.Close()
|
||||
|
||||
// Add more entries than the buffer can hold
|
||||
for i := 0; i < 20; i++ {
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: uint64(i),
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key"),
|
||||
Value: []byte("value"),
|
||||
}
|
||||
replicator.OnEntryWritten(entry)
|
||||
}
|
||||
|
||||
// Check that some entries were cleaned up
|
||||
count := replicator.GetEntryCount()
|
||||
if count >= 20 {
|
||||
t.Errorf("Expected fewer than 20 entries after cleanup, got %d", count)
|
||||
}
|
||||
|
||||
// The most recent entries should still be there
|
||||
entries, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 15})
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting entries: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
t.Errorf("Expected some entries after timestamp 15")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWALReplicatorClose(t *testing.T) {
|
||||
replicator := NewWALReplicator(1000)
|
||||
|
||||
// Add some entries
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 1,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key"),
|
||||
Value: []byte("value"),
|
||||
}
|
||||
replicator.OnEntryWritten(entry)
|
||||
|
||||
// Close the replicator
|
||||
if err := replicator.Close(); err != nil {
|
||||
t.Fatalf("Error closing replicator: %v", err)
|
||||
}
|
||||
|
||||
// Check that we can't add more entries
|
||||
replicator.OnEntryWritten(entry)
|
||||
|
||||
// Entry count should be 0: Close cleared the buffer and writes after close are ignored
|
||||
if count := replicator.GetEntryCount(); count != 0 {
|
||||
t.Errorf("Expected 0 entries after close, got %d", count)
|
||||
}
|
||||
|
||||
// Try to get entries (should return an error)
|
||||
_, err := replicator.GetEntriesAfter(ReplicationPosition{Timestamp: 0})
|
||||
if err != ErrReplicatorClosed {
|
||||
t.Errorf("Expected ErrReplicatorClosed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOldestTimestamps(t *testing.T) {
|
||||
// Create a map with some timestamps
|
||||
m := map[uint64]string{
|
||||
1: "one",
|
||||
2: "two",
|
||||
3: "three",
|
||||
4: "four",
|
||||
5: "five",
|
||||
}
|
||||
|
||||
// Find the 2 oldest timestamps
|
||||
result := findOldestTimestamps(m, 2)
|
||||
|
||||
// Check the result length
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Expected 2 timestamps, got %d", len(result))
|
||||
}
|
||||
|
||||
// Check that the result contains the 2 smallest timestamps
|
||||
for _, ts := range result {
|
||||
if ts != 1 && ts != 2 {
|
||||
t.Errorf("Expected timestamp 1 or 2, got %d", ts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortEntriesByTimestamp(t *testing.T) {
|
||||
// Create some entries with unsorted timestamps
|
||||
entries := []*wal.Entry{
|
||||
{SequenceNumber: 3},
|
||||
{SequenceNumber: 1},
|
||||
{SequenceNumber: 2},
|
||||
}
|
||||
|
||||
// Sort the entries
|
||||
sortEntriesByTimestamp(entries)
|
||||
|
||||
// Check that the entries are sorted
|
||||
for i := 0; i < len(entries)-1; i++ {
|
||||
if entries[i].SequenceNumber > entries[i+1].SequenceNumber {
|
||||
t.Errorf("Entries not sorted at index %d: %d > %d",
|
||||
i, entries[i].SequenceNumber, entries[i+1].SequenceNumber)
|
||||
}
|
||||
}
|
||||
}
|
358
pkg/replication/serialization.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package replication
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc32"
|
||||
|
||||
"github.com/KevoDB/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidChecksum indicates a checksum validation failure during deserialization
|
||||
ErrInvalidChecksum = errors.New("invalid checksum")
|
||||
|
||||
// ErrInvalidFormat indicates an invalid format of serialized data
|
||||
ErrInvalidFormat = errors.New("invalid entry format")
|
||||
|
||||
// ErrBufferTooSmall indicates the provided buffer is too small for serialization
|
||||
ErrBufferTooSmall = errors.New("buffer too small")
|
||||
)
|
||||
|
||||
const (
|
||||
// Entry serialization constants
|
||||
entryHeaderSize = 17 // checksum(4) + timestamp(8) + type(1) + keylen(4)
|
||||
// Additional 4 bytes for value length when not a delete operation
|
||||
)
|
||||
|
||||
// EntrySerializer handles serialization and deserialization of WAL entries
|
||||
type EntrySerializer struct {
|
||||
// ChecksumEnabled controls whether checksums are calculated/verified
|
||||
ChecksumEnabled bool
|
||||
}
|
||||
|
||||
// NewEntrySerializer creates a new entry serializer
|
||||
func NewEntrySerializer() *EntrySerializer {
|
||||
return &EntrySerializer{
|
||||
ChecksumEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// SerializeEntry converts a WAL entry to a byte slice
|
||||
func (s *EntrySerializer) SerializeEntry(entry *wal.Entry) []byte {
|
||||
// Calculate total size needed
|
||||
totalSize := entryHeaderSize + len(entry.Key)
|
||||
if entry.Value != nil {
|
||||
totalSize += 4 + len(entry.Value) // vallen(4) + value
|
||||
}
|
||||
|
||||
// Allocate buffer
|
||||
data := make([]byte, totalSize)
|
||||
offset := 4 // Skip first 4 bytes for checksum
|
||||
|
||||
// Write timestamp
|
||||
binary.LittleEndian.PutUint64(data[offset:offset+8], entry.SequenceNumber)
|
||||
offset += 8
|
||||
|
||||
// Write entry type
|
||||
data[offset] = entry.Type
|
||||
offset++
|
||||
|
||||
// Write key length and key
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(entry.Key)))
|
||||
offset += 4
|
||||
copy(data[offset:], entry.Key)
|
||||
offset += len(entry.Key)
|
||||
|
||||
// Write value length and value (if present)
|
||||
if entry.Value != nil {
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], uint32(len(entry.Value)))
|
||||
offset += 4
|
||||
copy(data[offset:], entry.Value)
|
||||
}
|
||||
|
||||
// Calculate and store checksum if enabled
|
||||
if s.ChecksumEnabled {
|
||||
checksum := crc32.ChecksumIEEE(data[4:])
|
||||
binary.LittleEndian.PutUint32(data[0:4], checksum)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// SerializeEntryToBuffer serializes a WAL entry to an existing buffer
|
||||
// Returns the number of bytes written or an error if the buffer is too small
|
||||
func (s *EntrySerializer) SerializeEntryToBuffer(entry *wal.Entry, buffer []byte) (int, error) {
|
||||
// Calculate total size needed
|
||||
totalSize := entryHeaderSize + len(entry.Key)
|
||||
if entry.Value != nil {
|
||||
totalSize += 4 + len(entry.Value) // vallen(4) + value
|
||||
}
|
||||
|
||||
// Check if buffer is large enough
|
||||
if len(buffer) < totalSize {
|
||||
return 0, ErrBufferTooSmall
|
||||
}
|
||||
|
||||
// Write to buffer
|
||||
offset := 4 // Skip first 4 bytes for checksum
|
||||
|
||||
// Write timestamp
|
||||
binary.LittleEndian.PutUint64(buffer[offset:offset+8], entry.SequenceNumber)
|
||||
offset += 8
|
||||
|
||||
// Write entry type
|
||||
buffer[offset] = entry.Type
|
||||
offset++
|
||||
|
||||
// Write key length and key
|
||||
binary.LittleEndian.PutUint32(buffer[offset:offset+4], uint32(len(entry.Key)))
|
||||
offset += 4
|
||||
copy(buffer[offset:], entry.Key)
|
||||
offset += len(entry.Key)
|
||||
|
||||
// Write value length and value (if present)
|
||||
if entry.Value != nil {
|
||||
binary.LittleEndian.PutUint32(buffer[offset:offset+4], uint32(len(entry.Value)))
|
||||
offset += 4
|
||||
copy(buffer[offset:], entry.Value)
|
||||
offset += len(entry.Value)
|
||||
}
|
||||
|
||||
// Calculate and store checksum if enabled
|
||||
if s.ChecksumEnabled {
|
||||
checksum := crc32.ChecksumIEEE(buffer[4:offset])
|
||||
binary.LittleEndian.PutUint32(buffer[0:4], checksum)
|
||||
}
|
||||
|
||||
return offset, nil
|
||||
}
|
||||
|
||||
// DeserializeEntry converts a byte slice back to a WAL entry
|
||||
func (s *EntrySerializer) DeserializeEntry(data []byte) (*wal.Entry, error) {
|
||||
// Validate minimum size
|
||||
if len(data) < entryHeaderSize {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
// Verify checksum if enabled
|
||||
if s.ChecksumEnabled {
|
||||
storedChecksum := binary.LittleEndian.Uint32(data[0:4])
|
||||
calculatedChecksum := crc32.ChecksumIEEE(data[4:])
|
||||
if storedChecksum != calculatedChecksum {
|
||||
return nil, ErrInvalidChecksum
|
||||
}
|
||||
}
|
||||
|
||||
offset := 4 // Skip checksum
|
||||
|
||||
// Read timestamp
|
||||
timestamp := binary.LittleEndian.Uint64(data[offset : offset+8])
|
||||
offset += 8
|
||||
|
||||
// Read entry type
|
||||
entryType := data[offset]
|
||||
offset++
|
||||
|
||||
// Read key length and key
|
||||
keyLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate key length
|
||||
if offset+int(keyLen) > len(data) {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
key := make([]byte, keyLen)
|
||||
copy(key, data[offset:offset+int(keyLen)])
|
||||
offset += int(keyLen)
|
||||
|
||||
// Read value length and value if present
|
||||
var value []byte
|
||||
if offset < len(data) {
|
||||
// Only read value if there's more data
|
||||
if offset+4 > len(data) {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
valueLen := binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate value length
|
||||
if offset+int(valueLen) > len(data) {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
value = make([]byte, valueLen)
|
||||
copy(value, data[offset:offset+int(valueLen)])
|
||||
}
|
||||
|
||||
// Create and return the entry
|
||||
return &wal.Entry{
|
||||
SequenceNumber: timestamp,
|
||||
Type: entryType,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BatchSerializer handles serialization of WAL entry batches
|
||||
type BatchSerializer struct {
|
||||
entrySerializer *EntrySerializer
|
||||
}
|
||||
|
||||
// NewBatchSerializer creates a new batch serializer
|
||||
func NewBatchSerializer() *BatchSerializer {
|
||||
return &BatchSerializer{
|
||||
entrySerializer: NewEntrySerializer(),
|
||||
}
|
||||
}
|
||||
|
||||
// SerializeBatch converts a batch of WAL entries to a byte slice
|
||||
func (s *BatchSerializer) SerializeBatch(entries []*wal.Entry) []byte {
|
||||
if len(entries) == 0 {
|
||||
// Empty batch - just return header with count 0
|
||||
result := make([]byte, 12) // checksum(4) + count(4) + timestamp(4)
|
||||
binary.LittleEndian.PutUint32(result[4:8], 0)
|
||||
|
||||
// Calculate and store checksum
|
||||
checksum := crc32.ChecksumIEEE(result[4:])
|
||||
binary.LittleEndian.PutUint32(result[0:4], checksum)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// First pass: calculate total size needed
|
||||
var totalSize int = 12 // header: checksum(4) + count(4) + base timestamp(4)
|
||||
|
||||
for _, entry := range entries {
|
||||
// For each entry: size(4) + serialized entry data
|
||||
entrySize := entryHeaderSize + len(entry.Key)
|
||||
if entry.Value != nil {
|
||||
entrySize += 4 + len(entry.Value)
|
||||
}
|
||||
|
||||
totalSize += 4 + entrySize
|
||||
}
|
||||
|
||||
// Allocate buffer
|
||||
result := make([]byte, totalSize)
|
||||
offset := 4 // Skip checksum for now
|
||||
|
||||
// Write entry count
|
||||
binary.LittleEndian.PutUint32(result[offset:offset+4], uint32(len(entries)))
|
||||
offset += 4
|
||||
|
||||
// Write base timestamp (from first entry)
|
||||
binary.LittleEndian.PutUint32(result[offset:offset+4], uint32(entries[0].SequenceNumber))
|
||||
offset += 4
|
||||
|
||||
// Write each entry
|
||||
for _, entry := range entries {
|
||||
// Reserve space for entry size
|
||||
sizeOffset := offset
|
||||
offset += 4
|
||||
|
||||
// Serialize entry directly into the buffer
|
||||
entrySize, err := s.entrySerializer.SerializeEntryToBuffer(entry, result[offset:])
|
||||
if err != nil {
|
||||
// This shouldn't happen since we pre-calculated the size,
|
||||
// but handle it gracefully just in case
|
||||
panic("buffer too small for entry serialization")
|
||||
}
|
||||
|
||||
offset += entrySize
|
||||
|
||||
// Write the actual entry size
|
||||
binary.LittleEndian.PutUint32(result[sizeOffset:sizeOffset+4], uint32(entrySize))
|
||||
}
|
||||
|
||||
// Calculate and store checksum
|
||||
checksum := crc32.ChecksumIEEE(result[4:offset])
|
||||
binary.LittleEndian.PutUint32(result[0:4], checksum)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// DeserializeBatch converts a byte slice back to a batch of WAL entries
|
||||
func (s *BatchSerializer) DeserializeBatch(data []byte) ([]*wal.Entry, error) {
|
||||
// Validate minimum size for batch header
|
||||
if len(data) < 12 {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
// Verify checksum
|
||||
storedChecksum := binary.LittleEndian.Uint32(data[0:4])
|
||||
calculatedChecksum := crc32.ChecksumIEEE(data[4:])
|
||||
if storedChecksum != calculatedChecksum {
|
||||
return nil, ErrInvalidChecksum
|
||||
}
|
||||
|
||||
offset := 4 // Skip checksum
|
||||
|
||||
// Read entry count
|
||||
count := binary.LittleEndian.Uint32(data[offset:offset+4])
|
||||
offset += 4
|
||||
|
||||
// Read base timestamp (we don't use this currently, but read past it)
|
||||
offset += 4 // Skip base timestamp
|
||||
|
||||
// Early return for empty batch
|
||||
if count == 0 {
|
||||
return []*wal.Entry{}, nil
|
||||
}
|
||||
|
||||
// Deserialize each entry
|
||||
entries := make([]*wal.Entry, count)
|
||||
for i := uint32(0); i < count; i++ {
|
||||
// Validate we have enough data for entry size
|
||||
if offset+4 > len(data) {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
// Read entry size
|
||||
entrySize := binary.LittleEndian.Uint32(data[offset:offset+4])
|
||||
offset += 4
|
||||
|
||||
// Validate entry size
|
||||
if offset+int(entrySize) > len(data) {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
// Deserialize entry
|
||||
entry, err := s.entrySerializer.DeserializeEntry(data[offset:offset+int(entrySize)])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries[i] = entry
|
||||
offset += int(entrySize)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// EstimateEntrySize estimates the serialized size of a WAL entry without actually serializing it
func EstimateEntrySize(entry *wal.Entry) int {
	size := entryHeaderSize + len(entry.Key)
	if entry.Value != nil {
		size += 4 + len(entry.Value)
	}
	return size
}

// EstimateBatchSize estimates the serialized size of a batch of WAL entries
func EstimateBatchSize(entries []*wal.Entry) int {
	if len(entries) == 0 {
		return 12 // Empty batch header
	}

	size := 12 // Batch header: checksum(4) + count(4) + base timestamp(4)

	for _, entry := range entries {
		entrySize := EstimateEntrySize(entry)
		size += 4 + entrySize // size field(4) + entry data
	}

	return size
}
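A quick round-trip sketch of the wire format defined above (checksum(4) + timestamp(8) + type(1) + keylen(4) + key, optionally followed by vallen(4) + value), using only the names declared in serialization.go:

	func exampleSerializationRoundTrip() error {
		s := NewEntrySerializer()
		entry := &wal.Entry{SequenceNumber: 7, Type: wal.OpTypePut, Key: []byte("k"), Value: []byte("v")}

		data := s.SerializeEntry(entry) // 17-byte header + 1-byte key + 4-byte vallen + 1-byte value = 23 bytes
		back, err := s.DeserializeEntry(data)
		if err != nil {
			return err
		}
		_ = back

		// Batches wrap the same per-entry encoding behind a checksummed batch header.
		bs := NewBatchSerializer()
		if _, err := bs.DeserializeBatch(bs.SerializeBatch([]*wal.Entry{entry})); err != nil {
			return err
		}
		return nil
	}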
420
pkg/replication/serialization_test.go
Normal file
@@ -0,0 +1,420 @@
|
||||
package replication
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"testing"
|
||||
|
||||
"github.com/KevoDB/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
func TestEntrySerializer(t *testing.T) {
|
||||
// Create a serializer
|
||||
serializer := NewEntrySerializer()
|
||||
|
||||
// Test different entry types
|
||||
testCases := []struct {
|
||||
name string
|
||||
entry *wal.Entry
|
||||
}{
|
||||
{
|
||||
name: "Put operation",
|
||||
entry: &wal.Entry{
|
||||
SequenceNumber: 123,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("test-key"),
|
||||
Value: []byte("test-value"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Delete operation",
|
||||
entry: &wal.Entry{
|
||||
SequenceNumber: 456,
|
||||
Type: wal.OpTypeDelete,
|
||||
Key: []byte("deleted-key"),
|
||||
Value: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Large entry",
|
||||
entry: &wal.Entry{
|
||||
SequenceNumber: 789,
|
||||
Type: wal.OpTypePut,
|
||||
Key: bytes.Repeat([]byte("K"), 1000),
|
||||
Value: bytes.Repeat([]byte("V"), 1000),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Serialize the entry
|
||||
data := serializer.SerializeEntry(tc.entry)
|
||||
|
||||
// Deserialize back
|
||||
result, err := serializer.DeserializeEntry(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deserializing entry: %v", err)
|
||||
}
|
||||
|
||||
// Compare entries
|
||||
if result.SequenceNumber != tc.entry.SequenceNumber {
|
||||
t.Errorf("Expected sequence number %d, got %d",
|
||||
tc.entry.SequenceNumber, result.SequenceNumber)
|
||||
}
|
||||
|
||||
if result.Type != tc.entry.Type {
|
||||
t.Errorf("Expected type %d, got %d", tc.entry.Type, result.Type)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result.Key, tc.entry.Key) {
|
||||
t.Errorf("Expected key %q, got %q", tc.entry.Key, result.Key)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result.Value, tc.entry.Value) {
|
||||
t.Errorf("Expected value %q, got %q", tc.entry.Value, result.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntrySerializerChecksum(t *testing.T) {
|
||||
// Create a serializer with checksums enabled
|
||||
serializer := NewEntrySerializer()
|
||||
serializer.ChecksumEnabled = true
|
||||
|
||||
// Create a test entry
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 123,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("test-key"),
|
||||
Value: []byte("test-value"),
|
||||
}
|
||||
|
||||
// Serialize the entry
|
||||
data := serializer.SerializeEntry(entry)
|
||||
|
||||
// Corrupt the data
|
||||
data[10]++
|
||||
|
||||
// Try to deserialize - should fail with checksum error
|
||||
_, err := serializer.DeserializeEntry(data)
|
||||
if err != ErrInvalidChecksum {
|
||||
t.Errorf("Expected checksum error, got %v", err)
|
||||
}
|
||||
|
||||
// Now disable checksum verification and try again
|
||||
serializer.ChecksumEnabled = false
|
||||
result, err := serializer.DeserializeEntry(data)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with checksums disabled, got %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected entry to be returned with checksums disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntrySerializerInvalidFormat(t *testing.T) {
|
||||
serializer := NewEntrySerializer()
|
||||
|
||||
// Test with empty data
|
||||
_, err := serializer.DeserializeEntry([]byte{})
|
||||
if err != ErrInvalidFormat {
|
||||
t.Errorf("Expected format error for empty data, got %v", err)
|
||||
}
|
||||
|
||||
// Test with insufficient data
|
||||
_, err = serializer.DeserializeEntry(make([]byte, 10))
|
||||
if err != ErrInvalidFormat {
|
||||
t.Errorf("Expected format error for insufficient data, got %v", err)
|
||||
}
|
||||
|
||||
// Test with invalid key length
|
||||
data := make([]byte, entryHeaderSize+4)
|
||||
offset := 4
|
||||
binary.LittleEndian.PutUint64(data[offset:offset+8], 123) // timestamp
|
||||
offset += 8
|
||||
data[offset] = wal.OpTypePut // type
|
||||
offset++
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], 1000) // key length (too large)
|
||||
|
||||
// Calculate a valid checksum for this data
|
||||
checksum := crc32.ChecksumIEEE(data[4:])
|
||||
binary.LittleEndian.PutUint32(data[0:4], checksum)
|
||||
|
||||
_, err = serializer.DeserializeEntry(data)
|
||||
if err != ErrInvalidFormat {
|
||||
t.Errorf("Expected format error for invalid key length, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchSerializer(t *testing.T) {
|
||||
// Create batch serializer
|
||||
serializer := NewBatchSerializer()
|
||||
|
||||
// Test batch with multiple entries
|
||||
entries := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 101,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
},
|
||||
{
|
||||
SequenceNumber: 102,
|
||||
Type: wal.OpTypeDelete,
|
||||
Key: []byte("key2"),
|
||||
Value: nil,
|
||||
},
|
||||
{
|
||||
SequenceNumber: 103,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key3"),
|
||||
Value: []byte("value3"),
|
||||
},
|
||||
}
|
||||
|
||||
// Serialize batch
|
||||
data := serializer.SerializeBatch(entries)
|
||||
|
||||
// Deserialize batch
|
||||
result, err := serializer.DeserializeBatch(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deserializing batch: %v", err)
|
||||
}
|
||||
|
||||
// Verify batch
|
||||
if len(result) != len(entries) {
|
||||
t.Fatalf("Expected %d entries, got %d", len(entries), len(result))
|
||||
}
|
||||
|
||||
for i, entry := range entries {
|
||||
if result[i].SequenceNumber != entry.SequenceNumber {
|
||||
t.Errorf("Entry %d: Expected sequence number %d, got %d",
|
||||
i, entry.SequenceNumber, result[i].SequenceNumber)
|
||||
}
|
||||
if result[i].Type != entry.Type {
|
||||
t.Errorf("Entry %d: Expected type %d, got %d",
|
||||
i, entry.Type, result[i].Type)
|
||||
}
|
||||
if !bytes.Equal(result[i].Key, entry.Key) {
|
||||
t.Errorf("Entry %d: Expected key %q, got %q",
|
||||
i, entry.Key, result[i].Key)
|
||||
}
|
||||
if !bytes.Equal(result[i].Value, entry.Value) {
|
||||
t.Errorf("Entry %d: Expected value %q, got %q",
|
||||
i, entry.Value, result[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyBatchSerialization(t *testing.T) {
|
||||
// Create batch serializer
|
||||
serializer := NewBatchSerializer()
|
||||
|
||||
// Test empty batch
|
||||
entries := []*wal.Entry{}
|
||||
|
||||
// Serialize batch
|
||||
data := serializer.SerializeBatch(entries)
|
||||
|
||||
// Deserialize batch
|
||||
result, err := serializer.DeserializeBatch(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deserializing empty batch: %v", err)
|
||||
}
|
||||
|
||||
// Verify result is empty
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty batch, got %d entries", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchSerializerChecksum(t *testing.T) {
|
||||
// Create batch serializer
|
||||
serializer := NewBatchSerializer()
|
||||
|
||||
// Test batch with single entry
|
||||
entries := []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 101,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
},
|
||||
}
|
||||
|
||||
// Serialize batch
|
||||
data := serializer.SerializeBatch(entries)
|
||||
|
||||
// Corrupt data
|
||||
data[8]++
|
||||
|
||||
// Attempt to deserialize - should fail
|
||||
_, err := serializer.DeserializeBatch(data)
|
||||
if err != ErrInvalidChecksum {
|
||||
t.Errorf("Expected checksum error for corrupted batch, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchSerializerInvalidFormat(t *testing.T) {
|
||||
serializer := NewBatchSerializer()
|
||||
|
||||
// Test with empty data
|
||||
_, err := serializer.DeserializeBatch([]byte{})
|
||||
if err != ErrInvalidFormat {
|
||||
t.Errorf("Expected format error for empty data, got %v", err)
|
||||
}
|
||||
|
||||
// Test with insufficient data
|
||||
_, err = serializer.DeserializeBatch(make([]byte, 10))
|
||||
if err != ErrInvalidFormat {
|
||||
t.Errorf("Expected format error for insufficient data, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateEntrySize(t *testing.T) {
|
||||
// Test entries with different sizes
|
||||
testCases := []struct {
|
||||
name string
|
||||
entry *wal.Entry
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Basic put entry",
|
||||
entry: &wal.Entry{
|
||||
SequenceNumber: 101,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key"),
|
||||
Value: []byte("value"),
|
||||
},
|
||||
expected: entryHeaderSize + 3 + 4 + 5, // header + key_len + value_len + value
|
||||
},
|
||||
{
|
||||
name: "Delete entry (no value)",
|
||||
entry: &wal.Entry{
|
||||
SequenceNumber: 102,
|
||||
Type: wal.OpTypeDelete,
|
||||
Key: []byte("delete-key"),
|
||||
Value: nil,
|
||||
},
|
||||
expected: entryHeaderSize + 10, // header + key_len
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
size := EstimateEntrySize(tc.entry)
|
||||
if size != tc.expected {
|
||||
t.Errorf("Expected size %d, got %d", tc.expected, size)
|
||||
}
|
||||
|
||||
// Verify estimate matches actual size
|
||||
serializer := NewEntrySerializer()
|
||||
data := serializer.SerializeEntry(tc.entry)
|
||||
if len(data) != size {
|
||||
t.Errorf("Estimated size %d doesn't match actual size %d",
|
||||
size, len(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateBatchSize(t *testing.T) {
|
||||
// Test batches with different contents
|
||||
testCases := []struct {
|
||||
name string
|
||||
entries []*wal.Entry
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Empty batch",
|
||||
entries: []*wal.Entry{},
|
||||
expected: 12, // Just batch header
|
||||
},
|
||||
{
|
||||
name: "Batch with one entry",
|
||||
entries: []*wal.Entry{
|
||||
{
|
||||
SequenceNumber: 101,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
},
|
||||
},
|
||||
expected: 12 + 4 + entryHeaderSize + 4 + 4 + 6, // batch header + entry size field + entry header + key + value size + value
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
size := EstimateBatchSize(tc.entries)
|
||||
if size != tc.expected {
|
||||
t.Errorf("Expected size %d, got %d", tc.expected, size)
|
||||
}
|
||||
|
||||
// Verify estimate matches actual size
|
||||
serializer := NewBatchSerializer()
|
||||
data := serializer.SerializeBatch(tc.entries)
|
||||
if len(data) != size {
|
||||
t.Errorf("Estimated size %d doesn't match actual size %d",
|
||||
size, len(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeToBuffer(t *testing.T) {
|
||||
serializer := NewEntrySerializer()
|
||||
|
||||
// Create a test entry
|
||||
entry := &wal.Entry{
|
||||
SequenceNumber: 101,
|
||||
Type: wal.OpTypePut,
|
||||
Key: []byte("key1"),
|
||||
Value: []byte("value1"),
|
||||
}
|
||||
|
||||
// Estimate the size
|
||||
estimatedSize := EstimateEntrySize(entry)
|
||||
|
||||
// Create a buffer of the estimated size
|
||||
buffer := make([]byte, estimatedSize)
|
||||
|
||||
// Serialize to buffer
|
||||
n, err := serializer.SerializeEntryToBuffer(entry, buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing to buffer: %v", err)
|
||||
}
|
||||
|
||||
// Check bytes written
|
||||
if n != estimatedSize {
|
||||
t.Errorf("Expected %d bytes written, got %d", estimatedSize, n)
|
||||
}
|
||||
|
||||
// Verify by deserializing
|
||||
result, err := serializer.DeserializeEntry(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deserializing from buffer: %v", err)
|
||||
}
|
||||
|
||||
// Check result
|
||||
if result.SequenceNumber != entry.SequenceNumber {
|
||||
t.Errorf("Expected sequence number %d, got %d",
|
||||
entry.SequenceNumber, result.SequenceNumber)
|
||||
}
|
||||
if !bytes.Equal(result.Key, entry.Key) {
|
||||
t.Errorf("Expected key %q, got %q", entry.Key, result.Key)
|
||||
}
|
||||
if !bytes.Equal(result.Value, entry.Value) {
|
||||
t.Errorf("Expected value %q, got %q", entry.Value, result.Value)
|
||||
}
|
||||
|
||||
// Test with too small buffer
|
||||
smallBuffer := make([]byte, estimatedSize - 1)
|
||||
_, err = serializer.SerializeEntryToBuffer(entry, smallBuffer)
|
||||
if err != ErrBufferTooSmall {
|
||||
t.Errorf("Expected buffer too small error, got %v", err)
|
||||
}
|
||||
}
|
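A minimal round-trip sketch of the serializer API exercised by the tests above. It assumes the snippet lives in the same package as NewEntrySerializer (so the serializer type and its error values are in scope); the function name is hypothetical.

package replication

import (
	"fmt"

	"github.com/KevoDB/kevo/pkg/wal"
)

// exampleEntryRoundTrip serializes one WAL entry and reads it back,
// demonstrating the checksum-protected format the tests above verify.
func exampleEntryRoundTrip() error {
	serializer := NewEntrySerializer()
	serializer.ChecksumEnabled = true // verify the CRC32 on deserialize

	entry := &wal.Entry{
		SequenceNumber: 42,
		Type:           wal.OpTypePut,
		Key:            []byte("example-key"),
		Value:          []byte("example-value"),
	}

	data := serializer.SerializeEntry(entry)

	decoded, err := serializer.DeserializeEntry(data)
	if err != nil {
		return fmt.Errorf("deserialize failed: %w", err)
	}

	fmt.Printf("seq=%d key=%s value=%s\n",
		decoded.SequenceNumber, decoded.Key, decoded.Value)
	return nil
}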
87
pkg/replication/storage_snapshot.go
Normal file
@ -0,0 +1,87 @@
package replication

import (
	"io"
)

// StorageSnapshot provides an interface for taking a snapshot of the storage
// for replication bootstrap purposes.
type StorageSnapshot interface {
	// CreateSnapshotIterator creates an iterator for a storage snapshot
	CreateSnapshotIterator() (SnapshotIterator, error)

	// KeyCount returns the approximate number of keys in storage
	KeyCount() int64
}

// SnapshotIterator provides iteration over key-value pairs in storage
type SnapshotIterator interface {
	// Next returns the next key-value pair
	// Returns io.EOF when there are no more items
	Next() (key []byte, value []byte, err error)

	// Close closes the iterator
	Close() error
}

// StorageSnapshotProvider is implemented by storage engines that support snapshots
type StorageSnapshotProvider interface {
	// CreateSnapshot creates a snapshot of the current storage state
	CreateSnapshot() (StorageSnapshot, error)
}

// MemoryStorageSnapshot is a simple in-memory implementation of StorageSnapshot
// Useful for testing or small datasets
type MemoryStorageSnapshot struct {
	Pairs    []KeyValuePair
	position int
}

// KeyValuePair represents a key-value pair in storage
type KeyValuePair struct {
	Key   []byte
	Value []byte
}

// CreateSnapshotIterator creates an iterator for a memory storage snapshot
func (m *MemoryStorageSnapshot) CreateSnapshotIterator() (SnapshotIterator, error) {
	return &MemorySnapshotIterator{
		snapshot: m,
		position: 0,
	}, nil
}

// KeyCount returns the number of keys in the snapshot
func (m *MemoryStorageSnapshot) KeyCount() int64 {
	return int64(len(m.Pairs))
}

// MemorySnapshotIterator is an iterator for MemoryStorageSnapshot
type MemorySnapshotIterator struct {
	snapshot *MemoryStorageSnapshot
	position int
}

// Next returns the next key-value pair
func (it *MemorySnapshotIterator) Next() ([]byte, []byte, error) {
	if it.position >= len(it.snapshot.Pairs) {
		return nil, nil, io.EOF
	}

	pair := it.snapshot.Pairs[it.position]
	it.position++

	return pair.Key, pair.Value, nil
}

// Close closes the iterator
func (it *MemorySnapshotIterator) Close() error {
	return nil
}

// NewMemoryStorageSnapshot creates a new in-memory storage snapshot
func NewMemoryStorageSnapshot(pairs []KeyValuePair) *MemoryStorageSnapshot {
	return &MemoryStorageSnapshot{
		Pairs: pairs,
	}
}
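A short usage sketch of the snapshot interfaces above, assuming nothing beyond what the interface comments state (Next returns io.EOF when the snapshot is exhausted); drainSnapshot and exampleMemorySnapshot are hypothetical helpers, not part of the change.

package replication

import (
	"fmt"
	"io"
)

// drainSnapshot iterates a snapshot to completion, the same access pattern
// a bootstrap sender would use to stream key-value pairs to a replica.
func drainSnapshot(s StorageSnapshot) error {
	it, err := s.CreateSnapshotIterator()
	if err != nil {
		return err
	}
	defer it.Close()

	for {
		key, value, err := it.Next()
		if err == io.EOF {
			return nil // snapshot fully consumed
		}
		if err != nil {
			return err
		}
		fmt.Printf("%s => %s\n", key, value)
	}
}

// exampleMemorySnapshot exercises drainSnapshot with the in-memory implementation.
func exampleMemorySnapshot() error {
	snap := NewMemoryStorageSnapshot([]KeyValuePair{
		{Key: []byte("a"), Value: []byte("1")},
		{Key: []byte("b"), Value: []byte("2")},
	})
	return drainSnapshot(snap)
}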
@ -138,12 +138,24 @@ type Registry interface {
|
||||
// RegisterServer adds a new server implementation to the registry
|
||||
RegisterServer(name string, factory ServerFactory)
|
||||
|
||||
// RegisterReplicationClient adds a new replication client implementation
|
||||
RegisterReplicationClient(name string, factory ReplicationClientFactory)
|
||||
|
||||
// RegisterReplicationServer adds a new replication server implementation
|
||||
RegisterReplicationServer(name string, factory ReplicationServerFactory)
|
||||
|
||||
// CreateClient instantiates a client by name
|
||||
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
|
||||
|
||||
// CreateServer instantiates a server by name
|
||||
CreateServer(name, address string, options TransportOptions) (Server, error)
|
||||
|
||||
// CreateReplicationClient instantiates a replication client by name
|
||||
CreateReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error)
|
||||
|
||||
// CreateReplicationServer instantiates a replication server by name
|
||||
CreateReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error)
|
||||
|
||||
// ListTransports returns all available transport names
|
||||
ListTransports() []string
|
||||
}
|
||||
|
@ -7,9 +7,11 @@ import (
|
||||
|
||||
// registry implements the Registry interface
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
clientFactories map[string]ClientFactory
|
||||
serverFactories map[string]ServerFactory
|
||||
mu sync.RWMutex
|
||||
clientFactories map[string]ClientFactory
|
||||
serverFactories map[string]ServerFactory
|
||||
replicationClientFactories map[string]ReplicationClientFactory
|
||||
replicationServerFactories map[string]ReplicationServerFactory
|
||||
}
|
||||
|
||||
// NewRegistry creates a new transport registry
|
||||
@ -17,6 +19,8 @@ func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
clientFactories: make(map[string]ClientFactory),
|
||||
serverFactories: make(map[string]ServerFactory),
|
||||
replicationClientFactories: make(map[string]ReplicationClientFactory),
|
||||
replicationServerFactories: make(map[string]ReplicationServerFactory),
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,6 +67,46 @@ func (r *registry) CreateServer(name, address string, options TransportOptions)
|
||||
return factory(address, options)
|
||||
}
|
||||
|
||||
// RegisterReplicationClient adds a new replication client implementation to the registry
|
||||
func (r *registry) RegisterReplicationClient(name string, factory ReplicationClientFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.replicationClientFactories[name] = factory
|
||||
}
|
||||
|
||||
// RegisterReplicationServer adds a new replication server implementation to the registry
|
||||
func (r *registry) RegisterReplicationServer(name string, factory ReplicationServerFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.replicationServerFactories[name] = factory
|
||||
}
|
||||
|
||||
// CreateReplicationClient instantiates a replication client by name
|
||||
func (r *registry) CreateReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error) {
|
||||
r.mu.RLock()
|
||||
factory, exists := r.replicationClientFactories[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("replication client %q not registered", name)
|
||||
}
|
||||
|
||||
return factory(endpoint, options)
|
||||
}
|
||||
|
||||
// CreateReplicationServer instantiates a replication server by name
|
||||
func (r *registry) CreateReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error) {
|
||||
r.mu.RLock()
|
||||
factory, exists := r.replicationServerFactories[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("replication server %q not registered", name)
|
||||
}
|
||||
|
||||
return factory(address, options)
|
||||
}
|
||||
|
||||
// ListTransports returns all available transport names
|
||||
func (r *registry) ListTransports() []string {
|
||||
r.mu.RLock()
|
||||
@ -76,6 +120,12 @@ func (r *registry) ListTransports() []string {
|
||||
for name := range r.serverFactories {
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
for name := range r.replicationClientFactories {
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
for name := range r.replicationServerFactories {
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
|
||||
// Convert to slice
|
||||
result := make([]string, 0, len(names))
|
||||
@ -98,6 +148,16 @@ func RegisterServerTransport(name string, factory ServerFactory) {
|
||||
DefaultRegistry.RegisterServer(name, factory)
|
||||
}
|
||||
|
||||
// RegisterReplicationClientTransport registers a replication client with the default registry
|
||||
func RegisterReplicationClientTransport(name string, factory ReplicationClientFactory) {
|
||||
DefaultRegistry.RegisterReplicationClient(name, factory)
|
||||
}
|
||||
|
||||
// RegisterReplicationServerTransport registers a replication server with the default registry
|
||||
func RegisterReplicationServerTransport(name string, factory ReplicationServerFactory) {
|
||||
DefaultRegistry.RegisterReplicationServer(name, factory)
|
||||
}
|
||||
|
||||
// GetClient creates a client using the default registry
|
||||
func GetClient(name, endpoint string, options TransportOptions) (Client, error) {
|
||||
return DefaultRegistry.CreateClient(name, endpoint, options)
|
||||
@ -108,6 +168,16 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
|
||||
return DefaultRegistry.CreateServer(name, address, options)
|
||||
}
|
||||
|
||||
// GetReplicationClient creates a replication client using the default registry
|
||||
func GetReplicationClient(name, endpoint string, options TransportOptions) (ReplicationClient, error) {
|
||||
return DefaultRegistry.CreateReplicationClient(name, endpoint, options)
|
||||
}
|
||||
|
||||
// GetReplicationServer creates a replication server using the default registry
|
||||
func GetReplicationServer(name, address string, options TransportOptions) (ReplicationServer, error) {
|
||||
return DefaultRegistry.CreateReplicationServer(name, address, options)
|
||||
}
|
||||
|
||||
// AvailableTransports lists all available transports in the default registry
|
||||
func AvailableTransports() []string {
|
||||
return DefaultRegistry.ListTransports()
|
||||
|
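A sketch of how the new default-registry helpers above are meant to be used, assuming a zero-value TransportOptions is acceptable for illustration; grpcReplicationClientFactory and exampleRegistryUsage are hypothetical stand-ins.

package transport

import "fmt"

// grpcReplicationClientFactory is a placeholder factory; any function matching
// ReplicationClientFactory can be registered under a transport name.
func grpcReplicationClientFactory(endpoint string, options TransportOptions) (ReplicationClient, error) {
	return nil, fmt.Errorf("not implemented in this sketch")
}

func exampleRegistryUsage() error {
	// Register the implementation under a transport name with the default registry...
	RegisterReplicationClientTransport("grpc", grpcReplicationClientFactory)

	// ...and later resolve it by that name.
	client, err := GetReplicationClient("grpc", "localhost:50051", TransportOptions{})
	if err != nil {
		return err
	}
	_ = client
	return nil
}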
183
pkg/transport/replication.go
Normal file
@ -0,0 +1,183 @@
package transport

import (
	"context"
	"time"

	"github.com/KevoDB/kevo/pkg/wal"
)

// Standard constants for replication message types
const (
	// Request types
	TypeReplicaRegister  = "replica_register"
	TypeReplicaHeartbeat = "replica_heartbeat"
	TypeReplicaWALSync   = "replica_wal_sync"
	TypeReplicaBootstrap = "replica_bootstrap"
	TypeReplicaStatus    = "replica_status"

	// Response types
	TypeReplicaACK           = "replica_ack"
	TypeReplicaWALEntries    = "replica_wal_entries"
	TypeReplicaBootstrapData = "replica_bootstrap_data"
	TypeReplicaStatusData    = "replica_status_data"
)

// ReplicaRole defines the role of a node in replication
type ReplicaRole string

// Replica roles
const (
	RolePrimary  ReplicaRole = "primary"
	RoleReplica  ReplicaRole = "replica"
	RoleReadOnly ReplicaRole = "readonly"
)

// ReplicaStatus defines the current status of a replica
type ReplicaStatus string

// Replica statuses
const (
	StatusConnecting    ReplicaStatus = "connecting"
	StatusSyncing       ReplicaStatus = "syncing"
	StatusBootstrapping ReplicaStatus = "bootstrapping"
	StatusReady         ReplicaStatus = "ready"
	StatusDisconnected  ReplicaStatus = "disconnected"
	StatusError         ReplicaStatus = "error"
)

// ReplicaInfo contains information about a replica
type ReplicaInfo struct {
	ID             string
	Address        string
	Role           ReplicaRole
	Status         ReplicaStatus
	LastSeen       time.Time
	CurrentLSN     uint64 // Lamport Sequence Number
	ReplicationLag time.Duration
	Error          error
}

// ReplicationStreamDirection defines the direction of a replication stream
type ReplicationStreamDirection int

const (
	DirectionPrimaryToReplica ReplicationStreamDirection = iota
	DirectionReplicaToPrimary
	DirectionBidirectional
)

// ReplicationConnection provides methods specific to replication connections
type ReplicationConnection interface {
	Connection

	// GetReplicaInfo returns information about the remote replica
	GetReplicaInfo() (*ReplicaInfo, error)

	// SendWALEntries sends WAL entries to the replica
	SendWALEntries(ctx context.Context, entries []*wal.Entry) error

	// ReceiveWALEntries receives WAL entries from the replica
	ReceiveWALEntries(ctx context.Context) ([]*wal.Entry, error)

	// StartReplicationStream starts a stream for WAL entries
	StartReplicationStream(ctx context.Context, direction ReplicationStreamDirection) (ReplicationStream, error)
}

// ReplicationStream provides a bidirectional stream of WAL entries
type ReplicationStream interface {
	// SendEntries sends WAL entries through the stream
	SendEntries(entries []*wal.Entry) error

	// ReceiveEntries receives WAL entries from the stream
	ReceiveEntries() ([]*wal.Entry, error)

	// Close closes the replication stream
	Close() error

	// SetHighWatermark updates the highest applied Lamport sequence number
	SetHighWatermark(lsn uint64) error

	// GetHighWatermark returns the highest applied Lamport sequence number
	GetHighWatermark() (uint64, error)
}

// ReplicationClient extends the Client interface with replication-specific methods
type ReplicationClient interface {
	Client

	// RegisterAsReplica registers this client as a replica with the primary
	RegisterAsReplica(ctx context.Context, replicaID string) error

	// SendHeartbeat sends a heartbeat to the primary
	SendHeartbeat(ctx context.Context, status *ReplicaInfo) error

	// RequestWALEntries requests WAL entries from the primary starting from a specific LSN
	RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error)

	// RequestBootstrap requests a snapshot for bootstrap purposes
	RequestBootstrap(ctx context.Context) (BootstrapIterator, error)
}

// ReplicationServer extends the Server interface with replication-specific methods
type ReplicationServer interface {
	Server

	// RegisterReplica registers a new replica
	RegisterReplica(replicaInfo *ReplicaInfo) error

	// UpdateReplicaStatus updates the status of a replica
	UpdateReplicaStatus(replicaID string, status ReplicaStatus, lsn uint64) error

	// GetReplicaInfo returns information about a specific replica
	GetReplicaInfo(replicaID string) (*ReplicaInfo, error)

	// ListReplicas returns information about all connected replicas
	ListReplicas() ([]*ReplicaInfo, error)

	// StreamWALEntriesToReplica streams WAL entries to a specific replica
	StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error
}

// BootstrapIterator provides an iterator over key-value pairs for bootstrapping a replica
type BootstrapIterator interface {
	// Next returns the next key-value pair
	Next() (key []byte, value []byte, err error)

	// Close closes the iterator
	Close() error

	// Progress returns the progress of the bootstrap operation (0.0-1.0)
	Progress() float64
}

// ReplicationRequestHandler processes replication-specific requests
type ReplicationRequestHandler interface {
	// HandleReplicaRegister handles replica registration requests
	HandleReplicaRegister(ctx context.Context, replicaID string, address string) error

	// HandleReplicaHeartbeat handles heartbeat requests
	HandleReplicaHeartbeat(ctx context.Context, status *ReplicaInfo) error

	// HandleWALRequest handles requests for WAL entries
	HandleWALRequest(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error)

	// HandleBootstrapRequest handles bootstrap requests
	HandleBootstrapRequest(ctx context.Context) (BootstrapIterator, error)
}

// ReplicationClientFactory creates a new replication client
type ReplicationClientFactory func(endpoint string, options TransportOptions) (ReplicationClient, error)

// ReplicationServerFactory creates a new replication server
type ReplicationServerFactory func(address string, options TransportOptions) (ReplicationServer, error)

// RegisterReplicationClient registers a replication client implementation
func RegisterReplicationClient(name string, factory ReplicationClientFactory) {
	// This would be implemented to register with the transport registry
}

// RegisterReplicationServer registers a replication server implementation
func RegisterReplicationServer(name string, factory ReplicationServerFactory) {
	// This would be implemented to register with the transport registry
}
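A sketch of the replica-side bootstrap flow built only from the interfaces above; runBootstrap and applyPair are hypothetical, and the loop treats a nil key as end-of-snapshot (the mock iterator in the tests below behaves this way; a real implementation might return io.EOF instead).

package transport

import "context"

// runBootstrap drives the replica side of a bootstrap: connect, register,
// pull the snapshot, then request any WAL entries written after it.
func runBootstrap(ctx context.Context, c ReplicationClient, replicaID string, applyPair func(k, v []byte) error) error {
	if err := c.Connect(ctx); err != nil {
		return err
	}
	defer c.Close()

	if err := c.RegisterAsReplica(ctx, replicaID); err != nil {
		return err
	}

	it, err := c.RequestBootstrap(ctx)
	if err != nil {
		return err
	}
	defer it.Close()

	for {
		key, value, err := it.Next()
		if err != nil {
			return err
		}
		if key == nil {
			break // end of snapshot in this sketch
		}
		if err := applyPair(key, value); err != nil {
			return err
		}
	}

	// Catch up on anything written after the snapshot was taken.
	_, err = c.RequestWALEntries(ctx, 0)
	return err
}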
401
pkg/transport/replication_test.go
Normal file
@ -0,0 +1,401 @@
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/KevoDB/kevo/pkg/wal"
|
||||
)
|
||||
|
||||
// MockReplicationClient implements ReplicationClient for testing
|
||||
type MockReplicationClient struct {
|
||||
connected bool
|
||||
registeredAsReplica bool
|
||||
heartbeatSent bool
|
||||
walEntriesRequested bool
|
||||
bootstrapRequested bool
|
||||
replicaID string
|
||||
walEntries []*wal.Entry
|
||||
bootstrapIterator BootstrapIterator
|
||||
status TransportStatus
|
||||
}
|
||||
|
||||
func NewMockReplicationClient() *MockReplicationClient {
|
||||
return &MockReplicationClient{
|
||||
connected: false,
|
||||
registeredAsReplica: false,
|
||||
heartbeatSent: false,
|
||||
walEntriesRequested: false,
|
||||
bootstrapRequested: false,
|
||||
status: TransportStatus{
|
||||
Connected: false,
|
||||
LastConnected: time.Time{},
|
||||
LastError: nil,
|
||||
BytesSent: 0,
|
||||
BytesReceived: 0,
|
||||
RTT: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) Connect(ctx context.Context) error {
|
||||
c.connected = true
|
||||
c.status.Connected = true
|
||||
c.status.LastConnected = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) Close() error {
|
||||
c.connected = false
|
||||
c.status.Connected = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) IsConnected() bool {
|
||||
return c.connected
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) Status() TransportStatus {
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) Send(ctx context.Context, request Request) (Response, error) {
|
||||
return nil, ErrInvalidRequest
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) Stream(ctx context.Context) (Stream, error) {
|
||||
return nil, ErrInvalidRequest
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) RegisterAsReplica(ctx context.Context, replicaID string) error {
|
||||
c.registeredAsReplica = true
|
||||
c.replicaID = replicaID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) SendHeartbeat(ctx context.Context, status *ReplicaInfo) error {
|
||||
c.heartbeatSent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) RequestWALEntries(ctx context.Context, fromLSN uint64) ([]*wal.Entry, error) {
|
||||
c.walEntriesRequested = true
|
||||
return c.walEntries, nil
|
||||
}
|
||||
|
||||
func (c *MockReplicationClient) RequestBootstrap(ctx context.Context) (BootstrapIterator, error) {
|
||||
c.bootstrapRequested = true
|
||||
return c.bootstrapIterator, nil
|
||||
}
|
||||
|
||||
// MockBootstrapIterator implements BootstrapIterator for testing
|
||||
type MockBootstrapIterator struct {
|
||||
pairs []struct{ key, value []byte }
|
||||
position int
|
||||
progress float64
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewMockBootstrapIterator() *MockBootstrapIterator {
|
||||
return &MockBootstrapIterator{
|
||||
pairs: []struct{ key, value []byte }{
|
||||
{[]byte("key1"), []byte("value1")},
|
||||
{[]byte("key2"), []byte("value2")},
|
||||
{[]byte("key3"), []byte("value3")},
|
||||
},
|
||||
position: 0,
|
||||
progress: 0.0,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (it *MockBootstrapIterator) Next() ([]byte, []byte, error) {
|
||||
if it.position >= len(it.pairs) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
pair := it.pairs[it.position]
|
||||
it.position++
|
||||
it.progress = float64(it.position) / float64(len(it.pairs))
|
||||
|
||||
return pair.key, pair.value, nil
|
||||
}
|
||||
|
||||
func (it *MockBootstrapIterator) Close() error {
|
||||
it.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *MockBootstrapIterator) Progress() float64 {
|
||||
return it.progress
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func TestReplicationClientInterface(t *testing.T) {
|
||||
// Create a mock client
|
||||
client := NewMockReplicationClient()
|
||||
|
||||
// Test Connect
|
||||
ctx := context.Background()
|
||||
err := client.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Connect failed: %v", err)
|
||||
}
|
||||
|
||||
// Test IsConnected
|
||||
if !client.IsConnected() {
|
||||
t.Errorf("Expected client to be connected")
|
||||
}
|
||||
|
||||
// Test Status
|
||||
status := client.Status()
|
||||
if !status.Connected {
|
||||
t.Errorf("Expected status.Connected to be true")
|
||||
}
|
||||
|
||||
// Test RegisterAsReplica
|
||||
err = client.RegisterAsReplica(ctx, "replica1")
|
||||
if err != nil {
|
||||
t.Errorf("RegisterAsReplica failed: %v", err)
|
||||
}
|
||||
if !client.registeredAsReplica {
|
||||
t.Errorf("Expected client to be registered as replica")
|
||||
}
|
||||
if client.replicaID != "replica1" {
|
||||
t.Errorf("Expected replicaID to be 'replica1', got '%s'", client.replicaID)
|
||||
}
|
||||
|
||||
// Test SendHeartbeat
|
||||
replicaInfo := &ReplicaInfo{
|
||||
ID: "replica1",
|
||||
Address: "localhost:50051",
|
||||
Role: RoleReplica,
|
||||
Status: StatusReady,
|
||||
LastSeen: time.Now(),
|
||||
CurrentLSN: 100,
|
||||
ReplicationLag: 0,
|
||||
}
|
||||
err = client.SendHeartbeat(ctx, replicaInfo)
|
||||
if err != nil {
|
||||
t.Errorf("SendHeartbeat failed: %v", err)
|
||||
}
|
||||
if !client.heartbeatSent {
|
||||
t.Errorf("Expected heartbeat to be sent")
|
||||
}
|
||||
|
||||
// Test RequestWALEntries
|
||||
client.walEntries = []*wal.Entry{
|
||||
{SequenceNumber: 101, Type: 1, Key: []byte("key1"), Value: []byte("value1")},
|
||||
{SequenceNumber: 102, Type: 1, Key: []byte("key2"), Value: []byte("value2")},
|
||||
}
|
||||
entries, err := client.RequestWALEntries(ctx, 100)
|
||||
if err != nil {
|
||||
t.Errorf("RequestWALEntries failed: %v", err)
|
||||
}
|
||||
if !client.walEntriesRequested {
|
||||
t.Errorf("Expected WAL entries to be requested")
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("Expected 2 entries, got %d", len(entries))
|
||||
}
|
||||
|
||||
// Test RequestBootstrap
|
||||
client.bootstrapIterator = NewMockBootstrapIterator()
|
||||
iterator, err := client.RequestBootstrap(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("RequestBootstrap failed: %v", err)
|
||||
}
|
||||
if !client.bootstrapRequested {
|
||||
t.Errorf("Expected bootstrap to be requested")
|
||||
}
|
||||
|
||||
// Test iterator
|
||||
key, value, err := iterator.Next()
|
||||
if err != nil {
|
||||
t.Errorf("Iterator.Next failed: %v", err)
|
||||
}
|
||||
if string(key) != "key1" || string(value) != "value1" {
|
||||
t.Errorf("Expected key1/value1, got %s/%s", string(key), string(value))
|
||||
}
|
||||
|
||||
progress := iterator.Progress()
|
||||
if progress != 1.0/3.0 {
|
||||
t.Errorf("Expected progress to be 1/3, got %f", progress)
|
||||
}
|
||||
|
||||
// Test Close
|
||||
err = client.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
if client.IsConnected() {
|
||||
t.Errorf("Expected client to be disconnected")
|
||||
}
|
||||
|
||||
// Test iterator Close
|
||||
err = iterator.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Iterator.Close failed: %v", err)
|
||||
}
|
||||
mockIter := iterator.(*MockBootstrapIterator)
|
||||
if !mockIter.closed {
|
||||
t.Errorf("Expected iterator to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
// MockReplicationServer implements ReplicationServer for testing
|
||||
type MockReplicationServer struct {
|
||||
started bool
|
||||
stopped bool
|
||||
replicas map[string]*ReplicaInfo
|
||||
streamingReplicas map[string]bool
|
||||
}
|
||||
|
||||
func NewMockReplicationServer() *MockReplicationServer {
|
||||
return &MockReplicationServer{
|
||||
started: false,
|
||||
stopped: false,
|
||||
replicas: make(map[string]*ReplicaInfo),
|
||||
streamingReplicas: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) Start() error {
|
||||
s.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) Serve() error {
|
||||
s.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) Stop(ctx context.Context) error {
|
||||
s.stopped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) SetRequestHandler(handler RequestHandler) {
|
||||
// No-op for testing
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) RegisterReplica(replicaInfo *ReplicaInfo) error {
|
||||
s.replicas[replicaInfo.ID] = replicaInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) UpdateReplicaStatus(replicaID string, status ReplicaStatus, lsn uint64) error {
|
||||
replica, exists := s.replicas[replicaID]
|
||||
if !exists {
|
||||
return ErrInvalidRequest
|
||||
}
|
||||
|
||||
replica.Status = status
|
||||
replica.CurrentLSN = lsn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) GetReplicaInfo(replicaID string) (*ReplicaInfo, error) {
|
||||
replica, exists := s.replicas[replicaID]
|
||||
if !exists {
|
||||
return nil, ErrInvalidRequest
|
||||
}
|
||||
|
||||
return replica, nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) ListReplicas() ([]*ReplicaInfo, error) {
|
||||
result := make([]*ReplicaInfo, 0, len(s.replicas))
|
||||
for _, replica := range s.replicas {
|
||||
result = append(result, replica)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *MockReplicationServer) StreamWALEntriesToReplica(ctx context.Context, replicaID string, fromLSN uint64) error {
|
||||
_, exists := s.replicas[replicaID]
|
||||
if !exists {
|
||||
return ErrInvalidRequest
|
||||
}
|
||||
|
||||
s.streamingReplicas[replicaID] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestReplicationServerInterface(t *testing.T) {
|
||||
// Create a mock server
|
||||
server := NewMockReplicationServer()
|
||||
|
||||
// Test Start
|
||||
err := server.Start()
|
||||
if err != nil {
|
||||
t.Errorf("Start failed: %v", err)
|
||||
}
|
||||
if !server.started {
|
||||
t.Errorf("Expected server to be started")
|
||||
}
|
||||
|
||||
// Test RegisterReplica
|
||||
replica1 := &ReplicaInfo{
|
||||
ID: "replica1",
|
||||
Address: "localhost:50051",
|
||||
Role: RoleReplica,
|
||||
Status: StatusConnecting,
|
||||
LastSeen: time.Now(),
|
||||
CurrentLSN: 0,
|
||||
ReplicationLag: 0,
|
||||
}
|
||||
err = server.RegisterReplica(replica1)
|
||||
if err != nil {
|
||||
t.Errorf("RegisterReplica failed: %v", err)
|
||||
}
|
||||
|
||||
// Test UpdateReplicaStatus
|
||||
err = server.UpdateReplicaStatus("replica1", StatusReady, 100)
|
||||
if err != nil {
|
||||
t.Errorf("UpdateReplicaStatus failed: %v", err)
|
||||
}
|
||||
|
||||
// Test GetReplicaInfo
|
||||
replica, err := server.GetReplicaInfo("replica1")
|
||||
if err != nil {
|
||||
t.Errorf("GetReplicaInfo failed: %v", err)
|
||||
}
|
||||
if replica.Status != StatusReady {
|
||||
t.Errorf("Expected status to be StatusReady, got %v", replica.Status)
|
||||
}
|
||||
if replica.CurrentLSN != 100 {
|
||||
t.Errorf("Expected LSN to be 100, got %d", replica.CurrentLSN)
|
||||
}
|
||||
|
||||
// Test ListReplicas
|
||||
replicas, err := server.ListReplicas()
|
||||
if err != nil {
|
||||
t.Errorf("ListReplicas failed: %v", err)
|
||||
}
|
||||
if len(replicas) != 1 {
|
||||
t.Errorf("Expected 1 replica, got %d", len(replicas))
|
||||
}
|
||||
|
||||
// Test StreamWALEntriesToReplica
|
||||
ctx := context.Background()
|
||||
err = server.StreamWALEntriesToReplica(ctx, "replica1", 0)
|
||||
if err != nil {
|
||||
t.Errorf("StreamWALEntriesToReplica failed: %v", err)
|
||||
}
|
||||
if !server.streamingReplicas["replica1"] {
|
||||
t.Errorf("Expected replica1 to be streaming")
|
||||
}
|
||||
|
||||
// Test Stop
|
||||
err = server.Stop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Stop failed: %v", err)
|
||||
}
|
||||
if !server.stopped {
|
||||
t.Errorf("Expected server to be stopped")
|
||||
}
|
||||
}
|
@ -138,13 +138,25 @@ func (b *Batch) Write(w *WAL) error {
		return ErrWALClosed
	}

	// Set the sequence number
	b.Seq = w.nextSequence
	// Set the sequence number - use Lamport clock if available
	var seqNum uint64
	if w.clock != nil {
		// Generate Lamport timestamp for the batch
		seqNum = w.clock.Tick()
		// Keep the nextSequence in sync with the highest timestamp
		if seqNum >= w.nextSequence {
			w.nextSequence = seqNum + 1
		}
	} else {
		// Use traditional sequence number
		seqNum = w.nextSequence
		// Increment sequence for future operations
		w.nextSequence += uint64(len(b.Operations))
	}

	b.Seq = seqNum
	binary.LittleEndian.PutUint64(data[4:12], b.Seq)

	// Increment sequence for future operations
	w.nextSequence += uint64(len(b.Operations))

	// Write as a batch entry
	if err := w.writeRecord(uint8(RecordTypeFull), OpTypeBatch, b.Seq, data, nil); err != nil {
		return err
@ -47,7 +47,11 @@ func TestBatchEncoding(t *testing.T) {
	defer os.RemoveAll(dir)

	cfg := createTestConfig()
	wal, err := NewWAL(cfg, dir)

	// Create a mock Lamport clock for the test
	clock := &MockLamportClock{counter: 0}

	wal, err := NewWALWithReplication(cfg, dir, clock, nil)
	if err != nil {
		t.Fatalf("Failed to create WAL: %v", err)
	}
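A small sketch of writing a batch through a WAL created with a Lamport clock, using only calls that appear in this diff (NewWALWithReplication, AppendBatch, config.NewDefaultConfig); the function itself is hypothetical.

package wal

import "github.com/KevoDB/kevo/pkg/config"

// exampleBatchWithClock writes a two-entry batch through a WAL created with
// a Lamport clock, so the batch sequence number comes from clock.Tick()
// rather than the plain counter.
func exampleBatchWithClock(dir string, clock LamportClock) (uint64, error) {
	cfg := config.NewDefaultConfig(dir)

	w, err := NewWALWithReplication(cfg, dir, clock, nil)
	if err != nil {
		return 0, err
	}
	defer w.Close()

	entries := []*Entry{
		{Type: OpTypePut, Key: []byte("k1"), Value: []byte("v1")},
		{Type: OpTypePut, Key: []byte("k2"), Value: []byte("v2")},
	}

	// The returned sequence is the Lamport timestamp assigned to the batch;
	// individual entries receive consecutive numbers starting from it.
	return w.AppendBatch(entries)
}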
@ -5,12 +5,12 @@ package wal
type LamportClock interface {
	// Tick increments the clock and returns the new timestamp value
	Tick() uint64

	// Update updates the clock based on a received timestamp,
	// ensuring the local clock is at least as large as the received timestamp,
	// then increments and returns the new value
	Update(received uint64) uint64

	// Current returns the current timestamp without incrementing the clock
	Current() uint64
}
@ -23,4 +23,4 @@ type ReplicationHook interface {

	// OnBatchWritten is called when a batch of WAL entries is written
	OnBatchWritten(entries []*Entry)
}
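A minimal sketch of a LamportClock implementation backed by sync/atomic, satisfying the interface above; the clock actually shipped with the replication code may differ.

package wal

import "sync/atomic"

// atomicLamportClock is a sketch implementation of LamportClock backed by a
// single atomic counter.
type atomicLamportClock struct {
	counter uint64
}

// Tick increments the clock and returns the new timestamp.
func (c *atomicLamportClock) Tick() uint64 {
	return atomic.AddUint64(&c.counter, 1)
}

// Update advances the clock to at least the received timestamp, then ticks,
// implementing the standard Lamport receive rule: local = max(local, received) + 1.
func (c *atomicLamportClock) Update(received uint64) uint64 {
	for {
		current := atomic.LoadUint64(&c.counter)
		next := received
		if current > next {
			next = current
		}
		next++ // always move past both values
		if atomic.CompareAndSwapUint64(&c.counter, current, next) {
			return next
		}
	}
}

// Current returns the current timestamp without advancing the clock.
func (c *atomicLamportClock) Current() uint64 {
	return atomic.LoadUint64(&c.counter)
}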
@ -53,7 +53,7 @@ type MockReplicationHook struct {
|
||||
func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
|
||||
// Make a deep copy of the entry to ensure tests are not affected by later modifications
|
||||
entryCopy := &Entry{
|
||||
SequenceNumber: entry.SequenceNumber,
|
||||
@ -63,7 +63,7 @@ func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
|
||||
if entry.Value != nil {
|
||||
entryCopy.Value = append([]byte{}, entry.Value...)
|
||||
}
|
||||
|
||||
|
||||
m.entries = append(m.entries, entryCopy)
|
||||
m.entriesReceived++
|
||||
}
|
||||
@ -71,7 +71,7 @@ func (m *MockReplicationHook) OnEntryWritten(entry *Entry) {
|
||||
func (m *MockReplicationHook) OnBatchWritten(entries []*Entry) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
|
||||
// Make a deep copy of all entries
|
||||
entriesCopy := make([]*Entry, len(entries))
|
||||
for i, entry := range entries {
|
||||
@ -84,7 +84,7 @@ func (m *MockReplicationHook) OnBatchWritten(entries []*Entry) {
|
||||
entriesCopy[i].Value = append([]byte{}, entry.Value...)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
m.batchEntries = append(m.batchEntries, entriesCopy)
|
||||
m.batchesReceived++
|
||||
}
|
||||
@ -117,10 +117,10 @@ func TestWALReplicationHook(t *testing.T) {
|
||||
|
||||
// Create a mock replication hook
|
||||
hook := &MockReplicationHook{}
|
||||
|
||||
|
||||
// Create a Lamport clock
|
||||
clock := NewMockLamportClock()
|
||||
|
||||
|
||||
// Create a WAL with the replication hook
|
||||
cfg := config.NewDefaultConfig(dir)
|
||||
wal, err := NewWALWithReplication(cfg, dir, clock, hook)
|
||||
@ -132,18 +132,18 @@ func TestWALReplicationHook(t *testing.T) {
|
||||
// Test single entry writes
|
||||
key1 := []byte("key1")
|
||||
value1 := []byte("value1")
|
||||
|
||||
|
||||
seq1, err := wal.Append(OpTypePut, key1, value1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Test that the hook received the entry
|
||||
entries := hook.GetEntries()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("Expected 1 entry, got %d", len(entries))
|
||||
}
|
||||
|
||||
|
||||
entry := entries[0]
|
||||
if entry.SequenceNumber != seq1 {
|
||||
t.Errorf("Expected sequence number %d, got %d", seq1, entry.SequenceNumber)
|
||||
@ -163,28 +163,28 @@ func TestWALReplicationHook(t *testing.T) {
|
||||
value2 := []byte("value2")
|
||||
key3 := []byte("key3")
|
||||
value3 := []byte("value3")
|
||||
|
||||
|
||||
batchEntries := []*Entry{
|
||||
{Type: OpTypePut, Key: key2, Value: value2},
|
||||
{Type: OpTypePut, Key: key3, Value: value3},
|
||||
}
|
||||
|
||||
|
||||
batchSeq, err := wal.AppendBatch(batchEntries)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append batch to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Test that the hook received the batch
|
||||
batches := hook.GetBatchEntries()
|
||||
if len(batches) != 1 {
|
||||
t.Fatalf("Expected 1 batch, got %d", len(batches))
|
||||
}
|
||||
|
||||
|
||||
batch := batches[0]
|
||||
if len(batch) != 2 {
|
||||
t.Fatalf("Expected 2 entries in batch, got %d", len(batch))
|
||||
}
|
||||
|
||||
|
||||
// Check first entry in batch
|
||||
if batch[0].SequenceNumber != batchSeq {
|
||||
t.Errorf("Expected sequence number %d, got %d", batchSeq, batch[0].SequenceNumber)
|
||||
@ -198,7 +198,7 @@ func TestWALReplicationHook(t *testing.T) {
|
||||
if string(batch[0].Value) != string(value2) {
|
||||
t.Errorf("Expected value %q, got %q", value2, batch[0].Value)
|
||||
}
|
||||
|
||||
|
||||
// Check second entry in batch
|
||||
if batch[1].SequenceNumber != batchSeq+1 {
|
||||
t.Errorf("Expected sequence number %d, got %d", batchSeq+1, batch[1].SequenceNumber)
|
||||
@ -212,7 +212,7 @@ func TestWALReplicationHook(t *testing.T) {
|
||||
if string(batch[1].Value) != string(value3) {
|
||||
t.Errorf("Expected value %q, got %q", value3, batch[1].Value)
|
||||
}
|
||||
|
||||
|
||||
// Check call counts
|
||||
entriesReceived, batchesReceived := hook.GetStats()
|
||||
if entriesReceived != 1 {
|
||||
@ -233,7 +233,7 @@ func TestWALWithLamportClock(t *testing.T) {
|
||||
|
||||
// Create a Lamport clock
|
||||
clock := NewMockLamportClock()
|
||||
|
||||
|
||||
// Create a WAL with the Lamport clock but no hook
|
||||
cfg := config.NewDefaultConfig(dir)
|
||||
wal, err := NewWALWithReplication(cfg, dir, clock, nil)
|
||||
@ -246,7 +246,7 @@ func TestWALWithLamportClock(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
clock.Tick()
|
||||
}
|
||||
|
||||
|
||||
// Current clock value should be 5
|
||||
if clock.Current() != 5 {
|
||||
t.Fatalf("Expected clock value 5, got %d", clock.Current())
|
||||
@ -255,38 +255,38 @@ func TestWALWithLamportClock(t *testing.T) {
|
||||
// Test that the WAL uses the Lamport clock for sequence numbers
|
||||
key1 := []byte("key1")
|
||||
value1 := []byte("value1")
|
||||
|
||||
|
||||
seq1, err := wal.Append(OpTypePut, key1, value1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Sequence number should be 6 (previous 5 + 1 for this operation)
|
||||
if seq1 != 6 {
|
||||
t.Errorf("Expected sequence number 6, got %d", seq1)
|
||||
}
|
||||
|
||||
|
||||
// Clock should have incremented
|
||||
if clock.Current() != 6 {
|
||||
t.Errorf("Expected clock value 6, got %d", clock.Current())
|
||||
}
|
||||
|
||||
|
||||
// Test with a batch
|
||||
entries := []*Entry{
|
||||
{Type: OpTypePut, Key: []byte("key2"), Value: []byte("value2")},
|
||||
{Type: OpTypePut, Key: []byte("key3"), Value: []byte("value3")},
|
||||
}
|
||||
|
||||
|
||||
batchSeq, err := wal.AppendBatch(entries)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append batch to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Batch sequence should be 7
|
||||
if batchSeq != 7 {
|
||||
t.Errorf("Expected batch sequence number 7, got %d", batchSeq)
|
||||
}
|
||||
|
||||
|
||||
// Clock should have incremented again
|
||||
if clock.Current() != 7 {
|
||||
t.Errorf("Expected clock value 7, got %d", clock.Current())
|
||||
@ -312,45 +312,45 @@ func TestWALHookAfterCreation(t *testing.T) {
|
||||
// Write an entry before adding a hook
|
||||
key1 := []byte("key1")
|
||||
value1 := []byte("value1")
|
||||
|
||||
|
||||
_, err = wal.Append(OpTypePut, key1, value1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Create and add a hook after the fact
|
||||
hook := &MockReplicationHook{}
|
||||
wal.SetReplicationHook(hook)
|
||||
|
||||
|
||||
// Create and add a Lamport clock after the fact
|
||||
clock := NewMockLamportClock()
|
||||
wal.SetLamportClock(clock)
|
||||
|
||||
|
||||
// Write another entry, this should trigger the hook
|
||||
key2 := []byte("key2")
|
||||
value2 := []byte("value2")
|
||||
|
||||
|
||||
seq2, err := wal.Append(OpTypePut, key2, value2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append to WAL: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Verify hook received the entry
|
||||
entries := hook.GetEntries()
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("Expected 1 entry in hook, got %d", len(entries))
|
||||
}
|
||||
|
||||
|
||||
if entries[0].SequenceNumber != seq2 {
|
||||
t.Errorf("Expected sequence number %d, got %d", seq2, entries[0].SequenceNumber)
|
||||
}
|
||||
|
||||
|
||||
if string(entries[0].Key) != string(key2) {
|
||||
t.Errorf("Expected key %q, got %q", key2, entries[0].Key)
|
||||
}
|
||||
|
||||
|
||||
// Verify the clock was used
|
||||
if seq2 != 1 { // First tick of the clock
|
||||
t.Errorf("Expected sequence from clock to be 1, got %d", seq2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -81,9 +81,9 @@ type WAL struct {
|
||||
status int32 // Using atomic int32 for status flags
|
||||
closed int32 // Atomic flag indicating if WAL is closed
|
||||
mu sync.Mutex
|
||||
|
||||
|
||||
// Replication support
|
||||
clock LamportClock // Lamport clock for logical timestamps
|
||||
clock LamportClock // Lamport clock for logical timestamps
|
||||
replicationHook ReplicationHook // Hook for replication events
|
||||
}
|
||||
|
||||
@ -226,11 +226,15 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
|
||||
if w.clock != nil {
|
||||
// Generate Lamport timestamp (reusing SequenceNumber field)
|
||||
seqNum = w.clock.Tick()
|
||||
// Keep the nextSequence in sync with the highest used sequence number
|
||||
if seqNum >= w.nextSequence {
|
||||
w.nextSequence = seqNum + 1
|
||||
}
|
||||
} else {
|
||||
// Use traditional sequence number
|
||||
seqNum = w.nextSequence
|
||||
w.nextSequence++
|
||||
}
|
||||
w.nextSequence = seqNum + 1
|
||||
|
||||
// Encode the entry
|
||||
// Format: type(1) + seq(8) + keylen(4) + key + vallen(4) + val
|
||||
@ -257,11 +261,11 @@ func (w *WAL) Append(entryType uint8, key, value []byte) (uint64, error) {
|
||||
if err := w.maybeSync(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
||||
// Notify replication hook if available
|
||||
if w.replicationHook != nil {
|
||||
entry := &Entry{
|
||||
SequenceNumber: seqNum, // This now represents the Lamport timestamp
|
||||
SequenceNumber: seqNum, // This now represents the Lamport timestamp
|
||||
Type: entryType,
|
||||
Key: key,
|
||||
Value: value,
|
||||
@ -511,6 +515,10 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
|
||||
if w.clock != nil {
|
||||
// Generate Lamport timestamp for the batch
|
||||
startSeqNum = w.clock.Tick()
|
||||
// Keep the nextSequence in sync with the highest timestamp
|
||||
if startSeqNum >= w.nextSequence {
|
||||
w.nextSequence = startSeqNum + 1
|
||||
}
|
||||
} else {
|
||||
// Use traditional sequence number
|
||||
startSeqNum = w.nextSequence
|
||||
@ -543,7 +551,7 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
|
||||
for i, entry := range entries {
|
||||
// Assign sequential sequence numbers to each entry
|
||||
seqNum := startSeqNum + uint64(i)
|
||||
|
||||
|
||||
// Save sequence number in the entry for replication
|
||||
entry.SequenceNumber = seqNum
|
||||
entriesForReplication[i] = entry
|
||||
@ -569,7 +577,7 @@ func (w *WAL) AppendBatch(entries []*Entry) (uint64, error) {
|
||||
if err := w.maybeSync(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
||||
// Notify replication hook if available
|
||||
if w.replicationHook != nil {
|
||||
w.replicationHook.OnBatchWritten(entriesForReplication)
|
||||
@ -599,7 +607,6 @@ func (w *WAL) Close() error {
|
||||
if err := w.file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close WAL file: %w", err)
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&w.status, WALStatusClosed)
|
||||
return nil
|
||||
}
|
||||
@ -629,7 +636,7 @@ func (w *WAL) UpdateNextSequence(nextSeq uint64) {
|
||||
func (w *WAL) SetReplicationHook(hook ReplicationHook) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
||||
w.replicationHook = hook
|
||||
}
|
||||
|
||||
@ -637,7 +644,7 @@ func (w *WAL) SetReplicationHook(hook ReplicationHook) {
|
||||
func (w *WAL) SetLamportClock(clock LamportClock) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
||||
w.clock = clock
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func createTestConfig() *config.Config {
|
||||
return config.NewDefaultConfig("/tmp/gostorage_test")
|
||||
cfg := config.NewDefaultConfig("/tmp/gostorage_test")
|
||||
// Force immediate sync for tests
|
||||
cfg.WALSyncMode = config.SyncImmediate
|
||||
return cfg
|
||||
}
|
||||
|
||||
func createTempDir(t *testing.T) string {
|
||||
@ -370,12 +373,13 @@ func TestWALRecovery(t *testing.T) {
|
||||
|
||||
func TestWALSyncModes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
syncMode config.SyncMode
|
||||
name string
|
||||
syncMode config.SyncMode
|
||||
expectedEntries int // Expected number of entries after crash (without explicit sync)
|
||||
}{
|
||||
{"SyncNone", config.SyncNone},
|
||||
{"SyncBatch", config.SyncBatch},
|
||||
{"SyncImmediate", config.SyncImmediate},
|
||||
{"SyncNone", config.SyncNone, 0}, // No entries should be recovered without explicit sync
|
||||
{"SyncBatch", config.SyncBatch, 0}, // No entries should be recovered if batch threshold not reached
|
||||
{"SyncImmediate", config.SyncImmediate, 10}, // All entries should be recovered
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -386,6 +390,10 @@ func TestWALSyncModes(t *testing.T) {
|
||||
// Create config with specific sync mode
|
||||
cfg := createTestConfig()
|
||||
cfg.WALSyncMode = tc.syncMode
|
||||
// Set a high sync threshold for batch mode to ensure it doesn't auto-sync
|
||||
if tc.syncMode == config.SyncBatch {
|
||||
cfg.WALSyncBytes = 100 * 1024 * 1024 // 100MB, high enough to not trigger
|
||||
}
|
||||
|
||||
wal, err := NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
@ -403,6 +411,8 @@ func TestWALSyncModes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Skip explicit sync to simulate a crash
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
@ -421,8 +431,54 @@ func TestWALSyncModes(t *testing.T) {
|
||||
t.Fatalf("Failed to replay WAL: %v", err)
|
||||
}
|
||||
|
||||
if count != 10 {
|
||||
t.Errorf("Expected 10 entries, got %d", count)
|
||||
// Check that the number of recovered entries matches expectations for this sync mode
|
||||
if count != tc.expectedEntries {
|
||||
t.Errorf("Expected %d entries for %s mode, got %d", tc.expectedEntries, tc.name, count)
|
||||
}
|
||||
|
||||
// Now test with explicit sync - all entries should be recoverable
|
||||
wal, err = NewWAL(cfg, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WAL: %v", err)
|
||||
}
|
||||
|
||||
// Write some more entries
|
||||
for i := 0; i < 10; i++ {
|
||||
key := []byte(fmt.Sprintf("explicit_key%d", i))
|
||||
value := []byte(fmt.Sprintf("explicit_value%d", i))
|
||||
|
||||
_, err := wal.Append(OpTypePut, key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to append entry: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Explicitly sync
|
||||
if err := wal.Sync(); err != nil {
|
||||
t.Fatalf("Failed to sync WAL: %v", err)
|
||||
}
|
||||
|
||||
// Close the WAL
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatalf("Failed to close WAL: %v", err)
|
||||
}
|
||||
|
||||
// Verify entries by replaying
|
||||
explicitCount := 0
|
||||
_, err = ReplayWALDir(dir, func(entry *Entry) error {
|
||||
if entry.Type == OpTypePut && bytes.HasPrefix(entry.Key, []byte("explicit_")) {
|
||||
explicitCount++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to replay WAL after explicit sync: %v", err)
|
||||
}
|
||||
|
||||
// After explicit sync, all 10 new entries should be recovered regardless of mode
|
||||
if explicitCount != 10 {
|
||||
t.Errorf("Expected 10 entries after explicit sync, got %d", explicitCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
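A short sketch of the durability pattern the revised sync-mode test asserts: outside SyncImmediate mode, appended entries only become recoverable after an explicit Sync. It uses only configuration fields and WAL calls that appear in this diff; the function is hypothetical.

package wal

import "github.com/KevoDB/kevo/pkg/config"

// exampleDurableWrite appends an entry in SyncBatch mode and makes it durable
// with an explicit Sync, mirroring what TestWALSyncModes verifies.
func exampleDurableWrite(dir string) error {
	cfg := config.NewDefaultConfig(dir)
	cfg.WALSyncMode = config.SyncBatch
	cfg.WALSyncBytes = 100 * 1024 * 1024 // large threshold: no implicit sync

	w, err := NewWAL(cfg, dir)
	if err != nil {
		return err
	}
	defer w.Close()

	if _, err := w.Append(OpTypePut, []byte("key"), []byte("value")); err != nil {
		return err
	}

	// Without this call the entry may be lost on a crash in SyncBatch mode.
	return w.Sync()
}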
1464
proto/kevo/replication.pb.go
Normal file
File diff suppressed because it is too large
152
proto/kevo/replication.proto
Normal file
@ -0,0 +1,152 @@
syntax = "proto3";
|
||||
|
||||
package kevo;
|
||||
|
||||
option go_package = "github.com/KevoDB/kevo/proto/kevo";
|
||||
|
||||
service ReplicationService {
|
||||
// Replica Registration and Status
|
||||
rpc RegisterReplica(RegisterReplicaRequest) returns (RegisterReplicaResponse);
|
||||
rpc ReplicaHeartbeat(ReplicaHeartbeatRequest) returns (ReplicaHeartbeatResponse);
|
||||
rpc GetReplicaStatus(GetReplicaStatusRequest) returns (GetReplicaStatusResponse);
|
||||
rpc ListReplicas(ListReplicasRequest) returns (ListReplicasResponse);
|
||||
|
||||
// WAL Replication
|
||||
rpc GetWALEntries(GetWALEntriesRequest) returns (GetWALEntriesResponse);
|
||||
rpc StreamWALEntries(StreamWALEntriesRequest) returns (stream WALEntryBatch);
|
||||
rpc ReportAppliedEntries(ReportAppliedEntriesRequest) returns (ReportAppliedEntriesResponse);
|
||||
|
||||
// Bootstrap Operations
|
||||
rpc RequestBootstrap(BootstrapRequest) returns (stream BootstrapBatch);
|
||||
}
|
||||
|
||||
// Replication status enum
|
||||
enum ReplicaRole {
|
||||
PRIMARY = 0;
|
||||
REPLICA = 1;
|
||||
READ_ONLY = 2;
|
||||
}
|
||||
|
||||
enum ReplicaStatus {
|
||||
CONNECTING = 0;
|
||||
SYNCING = 1;
|
||||
BOOTSTRAPPING = 2;
|
||||
READY = 3;
|
||||
DISCONNECTED = 4;
|
||||
ERROR = 5;
|
||||
}
|
||||
|
||||
// Replica Registration messages
|
||||
message RegisterReplicaRequest {
|
||||
string replica_id = 1;
|
||||
string address = 2;
|
||||
ReplicaRole role = 3;
|
||||
}
|
||||
|
||||
message RegisterReplicaResponse {
|
||||
bool success = 1;
|
||||
string error_message = 2;
|
||||
uint64 current_lsn = 3; // Current Lamport Sequence Number on primary
|
||||
bool bootstrap_required = 4;
|
||||
}
|
||||
|
||||
// Heartbeat messages
|
||||
message ReplicaHeartbeatRequest {
|
||||
string replica_id = 1;
|
||||
ReplicaStatus status = 2;
|
||||
uint64 current_lsn = 3; // Current Lamport Sequence Number on replica
|
||||
string error_message = 4; // If status is ERROR
|
||||
}
|
||||
|
||||
message ReplicaHeartbeatResponse {
|
||||
bool success = 1;
|
||||
uint64 primary_lsn = 2; // Current Lamport Sequence Number on primary
|
||||
int64 replication_lag_ms = 3; // Estimated lag in milliseconds
|
||||
}
|
||||
|
||||
// Status messages
|
||||
message GetReplicaStatusRequest {
|
||||
string replica_id = 1;
|
||||
}
|
||||
|
||||
message ReplicaInfo {
|
||||
string replica_id = 1;
|
||||
string address = 2;
|
||||
ReplicaRole role = 3;
|
||||
ReplicaStatus status = 4;
|
||||
int64 last_seen_ms = 5; // Timestamp of last heartbeat in milliseconds since epoch
|
||||
uint64 current_lsn = 6; // Current Lamport Sequence Number
|
||||
int64 replication_lag_ms = 7; // Estimated lag in milliseconds
|
||||
string error_message = 8; // If status is ERROR
|
||||
}
|
||||
|
||||
message GetReplicaStatusResponse {
|
||||
ReplicaInfo replica = 1;
|
||||
}
|
||||
|
||||
message ListReplicasRequest {}
|
||||
|
||||
message ListReplicasResponse {
|
||||
repeated ReplicaInfo replicas = 1;
|
||||
}
|
||||
|
||||
// WAL Replication messages
|
||||
message WALEntry {
|
||||
uint64 sequence_number = 1; // Lamport Sequence Number
|
||||
uint32 type = 2; // Entry type (put, delete, etc.)
|
||||
bytes key = 3;
|
||||
bytes value = 4; // Optional, depending on type
|
||||
bytes checksum = 5; // Checksum for data integrity
|
||||
}
|
||||
|
||||
message WALEntryBatch {
|
||||
repeated WALEntry entries = 1;
|
||||
uint64 first_lsn = 2; // LSN of the first entry in the batch
|
||||
uint64 last_lsn = 3; // LSN of the last entry in the batch
|
||||
uint32 count = 4; // Number of entries in the batch
|
||||
bytes checksum = 5; // Checksum of the entire batch
|
||||
}
|
||||
|
||||
message GetWALEntriesRequest {
|
||||
string replica_id = 1;
|
||||
uint64 from_lsn = 2; // Request entries starting from this LSN
|
||||
uint32 max_entries = 3; // Maximum number of entries to return (0 for no limit)
|
||||
}
|
||||
|
||||
message GetWALEntriesResponse {
|
||||
WALEntryBatch batch = 1;
|
||||
bool has_more = 2; // True if there are more entries available
|
||||
}
|
||||
|
||||
message StreamWALEntriesRequest {
|
||||
string replica_id = 1;
|
||||
uint64 from_lsn = 2; // Request entries starting from this LSN
|
||||
bool continuous = 3; // If true, keep streaming as new entries arrive
|
||||
}
|
||||
|
||||
message ReportAppliedEntriesRequest {
|
||||
string replica_id = 1;
|
||||
uint64 applied_lsn = 2; // Highest LSN successfully applied on the replica
|
||||
}
|
||||
|
||||
message ReportAppliedEntriesResponse {
|
||||
bool success = 1;
|
||||
uint64 primary_lsn = 2; // Current LSN on primary
|
||||
}
|
||||
|
||||
// Bootstrap messages
|
||||
message BootstrapRequest {
|
||||
string replica_id = 1;
|
||||
}
|
||||
|
||||
message KeyValuePair {
|
||||
bytes key = 1;
|
||||
bytes value = 2;
|
||||
}
|
||||
|
||||
message BootstrapBatch {
|
||||
repeated KeyValuePair pairs = 1;
|
||||
float progress = 2; // Progress from 0.0 to 1.0
|
||||
bool is_last = 3; // True if this is the last batch
|
||||
uint64 snapshot_lsn = 4; // LSN at which this snapshot was taken
|
||||
}
|
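For orientation, here is a minimal client-side sketch (not part of this change) showing how a replica process might drive the RPCs declared above using the generated Go client. It assumes the standard protoc-gen-go field naming for the messages defined here (ReplicaId, FromLsn, ReplicaRole_REPLICA, and so on), a plaintext connection, and illustrative addresses; error handling is deliberately simplistic.

package main

import (
	"context"
	"io"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/KevoDB/kevo/proto/kevo"
)

func main() {
	// Connect to the primary (address is illustrative).
	conn, err := grpc.NewClient("localhost:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial primary: %v", err)
	}
	defer conn.Close()

	client := kevo.NewReplicationServiceClient(conn)
	ctx := context.Background()

	// Register this replica with the primary.
	reg, err := client.RegisterReplica(ctx, &kevo.RegisterReplicaRequest{
		ReplicaId: "replica-1",
		Address:   "localhost:50052",
		Role:      kevo.ReplicaRole_REPLICA,
	})
	if err != nil {
		log.Fatalf("register: %v", err)
	}
	log.Printf("primary LSN %d, bootstrap required: %v", reg.CurrentLsn, reg.BootstrapRequired)

	// Stream WAL entry batches from LSN 0 and acknowledge each one.
	stream, err := client.StreamWALEntries(ctx, &kevo.StreamWALEntriesRequest{
		ReplicaId:  "replica-1",
		FromLsn:    0,
		Continuous: true,
	})
	if err != nil {
		log.Fatalf("stream: %v", err)
	}
	for {
		batch, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatalf("recv: %v", err)
		}
		// Apply batch.Entries locally (omitted), then report progress.
		if _, err := client.ReportAppliedEntries(ctx, &kevo.ReportAppliedEntriesRequest{
			ReplicaId:  "replica-1",
			AppliedLsn: batch.LastLsn,
		}); err != nil {
			log.Printf("report applied: %v", err)
		}
	}
}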
400
proto/kevo/replication_grpc.pb.go
Normal file
@@ -0,0 +1,400 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v3.20.3
// source: proto/kevo/replication.proto

package kevo

import (
	context "context"
	grpc "google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	status "google.golang.org/grpc/status"
)

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9

const (
	ReplicationService_RegisterReplica_FullMethodName      = "/kevo.ReplicationService/RegisterReplica"
	ReplicationService_ReplicaHeartbeat_FullMethodName     = "/kevo.ReplicationService/ReplicaHeartbeat"
	ReplicationService_GetReplicaStatus_FullMethodName     = "/kevo.ReplicationService/GetReplicaStatus"
	ReplicationService_ListReplicas_FullMethodName         = "/kevo.ReplicationService/ListReplicas"
	ReplicationService_GetWALEntries_FullMethodName        = "/kevo.ReplicationService/GetWALEntries"
	ReplicationService_StreamWALEntries_FullMethodName     = "/kevo.ReplicationService/StreamWALEntries"
	ReplicationService_ReportAppliedEntries_FullMethodName = "/kevo.ReplicationService/ReportAppliedEntries"
	ReplicationService_RequestBootstrap_FullMethodName     = "/kevo.ReplicationService/RequestBootstrap"
)

// ReplicationServiceClient is the client API for ReplicationService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ReplicationServiceClient interface {
	// Replica Registration and Status
	RegisterReplica(ctx context.Context, in *RegisterReplicaRequest, opts ...grpc.CallOption) (*RegisterReplicaResponse, error)
	ReplicaHeartbeat(ctx context.Context, in *ReplicaHeartbeatRequest, opts ...grpc.CallOption) (*ReplicaHeartbeatResponse, error)
	GetReplicaStatus(ctx context.Context, in *GetReplicaStatusRequest, opts ...grpc.CallOption) (*GetReplicaStatusResponse, error)
	ListReplicas(ctx context.Context, in *ListReplicasRequest, opts ...grpc.CallOption) (*ListReplicasResponse, error)
	// WAL Replication
	GetWALEntries(ctx context.Context, in *GetWALEntriesRequest, opts ...grpc.CallOption) (*GetWALEntriesResponse, error)
	StreamWALEntries(ctx context.Context, in *StreamWALEntriesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WALEntryBatch], error)
	ReportAppliedEntries(ctx context.Context, in *ReportAppliedEntriesRequest, opts ...grpc.CallOption) (*ReportAppliedEntriesResponse, error)
	// Bootstrap Operations
	RequestBootstrap(ctx context.Context, in *BootstrapRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[BootstrapBatch], error)
}

type replicationServiceClient struct {
	cc grpc.ClientConnInterface
}

func NewReplicationServiceClient(cc grpc.ClientConnInterface) ReplicationServiceClient {
	return &replicationServiceClient{cc}
}
func (c *replicationServiceClient) RegisterReplica(ctx context.Context, in *RegisterReplicaRequest, opts ...grpc.CallOption) (*RegisterReplicaResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(RegisterReplicaResponse)
	err := c.cc.Invoke(ctx, ReplicationService_RegisterReplica_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) ReplicaHeartbeat(ctx context.Context, in *ReplicaHeartbeatRequest, opts ...grpc.CallOption) (*ReplicaHeartbeatResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ReplicaHeartbeatResponse)
	err := c.cc.Invoke(ctx, ReplicationService_ReplicaHeartbeat_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) GetReplicaStatus(ctx context.Context, in *GetReplicaStatusRequest, opts ...grpc.CallOption) (*GetReplicaStatusResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetReplicaStatusResponse)
	err := c.cc.Invoke(ctx, ReplicationService_GetReplicaStatus_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) ListReplicas(ctx context.Context, in *ListReplicasRequest, opts ...grpc.CallOption) (*ListReplicasResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ListReplicasResponse)
	err := c.cc.Invoke(ctx, ReplicationService_ListReplicas_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) GetWALEntries(ctx context.Context, in *GetWALEntriesRequest, opts ...grpc.CallOption) (*GetWALEntriesResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(GetWALEntriesResponse)
	err := c.cc.Invoke(ctx, ReplicationService_GetWALEntries_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) StreamWALEntries(ctx context.Context, in *StreamWALEntriesRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WALEntryBatch], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &ReplicationService_ServiceDesc.Streams[0], ReplicationService_StreamWALEntries_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[StreamWALEntriesRequest, WALEntryBatch]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_StreamWALEntriesClient = grpc.ServerStreamingClient[WALEntryBatch]

func (c *replicationServiceClient) ReportAppliedEntries(ctx context.Context, in *ReportAppliedEntriesRequest, opts ...grpc.CallOption) (*ReportAppliedEntriesResponse, error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	out := new(ReportAppliedEntriesResponse)
	err := c.cc.Invoke(ctx, ReplicationService_ReportAppliedEntries_FullMethodName, in, out, cOpts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func (c *replicationServiceClient) RequestBootstrap(ctx context.Context, in *BootstrapRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[BootstrapBatch], error) {
	cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
	stream, err := c.cc.NewStream(ctx, &ReplicationService_ServiceDesc.Streams[1], ReplicationService_RequestBootstrap_FullMethodName, cOpts...)
	if err != nil {
		return nil, err
	}
	x := &grpc.GenericClientStream[BootstrapRequest, BootstrapBatch]{ClientStream: stream}
	if err := x.ClientStream.SendMsg(in); err != nil {
		return nil, err
	}
	if err := x.ClientStream.CloseSend(); err != nil {
		return nil, err
	}
	return x, nil
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_RequestBootstrapClient = grpc.ServerStreamingClient[BootstrapBatch]
// ReplicationServiceServer is the server API for ReplicationService service.
// All implementations must embed UnimplementedReplicationServiceServer
// for forward compatibility.
type ReplicationServiceServer interface {
	// Replica Registration and Status
	RegisterReplica(context.Context, *RegisterReplicaRequest) (*RegisterReplicaResponse, error)
	ReplicaHeartbeat(context.Context, *ReplicaHeartbeatRequest) (*ReplicaHeartbeatResponse, error)
	GetReplicaStatus(context.Context, *GetReplicaStatusRequest) (*GetReplicaStatusResponse, error)
	ListReplicas(context.Context, *ListReplicasRequest) (*ListReplicasResponse, error)
	// WAL Replication
	GetWALEntries(context.Context, *GetWALEntriesRequest) (*GetWALEntriesResponse, error)
	StreamWALEntries(*StreamWALEntriesRequest, grpc.ServerStreamingServer[WALEntryBatch]) error
	ReportAppliedEntries(context.Context, *ReportAppliedEntriesRequest) (*ReportAppliedEntriesResponse, error)
	// Bootstrap Operations
	RequestBootstrap(*BootstrapRequest, grpc.ServerStreamingServer[BootstrapBatch]) error
	mustEmbedUnimplementedReplicationServiceServer()
}

// UnimplementedReplicationServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedReplicationServiceServer struct{}

func (UnimplementedReplicationServiceServer) RegisterReplica(context.Context, *RegisterReplicaRequest) (*RegisterReplicaResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method RegisterReplica not implemented")
}
func (UnimplementedReplicationServiceServer) ReplicaHeartbeat(context.Context, *ReplicaHeartbeatRequest) (*ReplicaHeartbeatResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method ReplicaHeartbeat not implemented")
}
func (UnimplementedReplicationServiceServer) GetReplicaStatus(context.Context, *GetReplicaStatusRequest) (*GetReplicaStatusResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method GetReplicaStatus not implemented")
}
func (UnimplementedReplicationServiceServer) ListReplicas(context.Context, *ListReplicasRequest) (*ListReplicasResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method ListReplicas not implemented")
}
func (UnimplementedReplicationServiceServer) GetWALEntries(context.Context, *GetWALEntriesRequest) (*GetWALEntriesResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method GetWALEntries not implemented")
}
func (UnimplementedReplicationServiceServer) StreamWALEntries(*StreamWALEntriesRequest, grpc.ServerStreamingServer[WALEntryBatch]) error {
	return status.Errorf(codes.Unimplemented, "method StreamWALEntries not implemented")
}
func (UnimplementedReplicationServiceServer) ReportAppliedEntries(context.Context, *ReportAppliedEntriesRequest) (*ReportAppliedEntriesResponse, error) {
	return nil, status.Errorf(codes.Unimplemented, "method ReportAppliedEntries not implemented")
}
func (UnimplementedReplicationServiceServer) RequestBootstrap(*BootstrapRequest, grpc.ServerStreamingServer[BootstrapBatch]) error {
	return status.Errorf(codes.Unimplemented, "method RequestBootstrap not implemented")
}
func (UnimplementedReplicationServiceServer) mustEmbedUnimplementedReplicationServiceServer() {}
func (UnimplementedReplicationServiceServer) testEmbeddedByValue()                            {}
// UnsafeReplicationServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ReplicationServiceServer will
// result in compilation errors.
type UnsafeReplicationServiceServer interface {
	mustEmbedUnimplementedReplicationServiceServer()
}

func RegisterReplicationServiceServer(s grpc.ServiceRegistrar, srv ReplicationServiceServer) {
	// If the following call panics, it indicates UnimplementedReplicationServiceServer was
	// embedded by pointer and is nil. This will cause panics if an
	// unimplemented method is ever invoked, so we test this at initialization
	// time to prevent it from happening at runtime later due to I/O.
	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
		t.testEmbeddedByValue()
	}
	s.RegisterService(&ReplicationService_ServiceDesc, srv)
}
func _ReplicationService_RegisterReplica_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(RegisterReplicaRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).RegisterReplica(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_RegisterReplica_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).RegisterReplica(ctx, req.(*RegisterReplicaRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_ReplicaHeartbeat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ReplicaHeartbeatRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).ReplicaHeartbeat(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_ReplicaHeartbeat_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).ReplicaHeartbeat(ctx, req.(*ReplicaHeartbeatRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_GetReplicaStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetReplicaStatusRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).GetReplicaStatus(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_GetReplicaStatus_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).GetReplicaStatus(ctx, req.(*GetReplicaStatusRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_ListReplicas_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ListReplicasRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).ListReplicas(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_ListReplicas_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).ListReplicas(ctx, req.(*ListReplicasRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_GetWALEntries_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(GetWALEntriesRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).GetWALEntries(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_GetWALEntries_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).GetWALEntries(ctx, req.(*GetWALEntriesRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_StreamWALEntries_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(StreamWALEntriesRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(ReplicationServiceServer).StreamWALEntries(m, &grpc.GenericServerStream[StreamWALEntriesRequest, WALEntryBatch]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_StreamWALEntriesServer = grpc.ServerStreamingServer[WALEntryBatch]

func _ReplicationService_ReportAppliedEntries_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(ReportAppliedEntriesRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(ReplicationServiceServer).ReportAppliedEntries(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server:     srv,
		FullMethod: ReplicationService_ReportAppliedEntries_FullMethodName,
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(ReplicationServiceServer).ReportAppliedEntries(ctx, req.(*ReportAppliedEntriesRequest))
	}
	return interceptor(ctx, in, info, handler)
}

func _ReplicationService_RequestBootstrap_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(BootstrapRequest)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	return srv.(ReplicationServiceServer).RequestBootstrap(m, &grpc.GenericServerStream[BootstrapRequest, BootstrapBatch]{ServerStream: stream})
}

// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ReplicationService_RequestBootstrapServer = grpc.ServerStreamingServer[BootstrapBatch]
// ReplicationService_ServiceDesc is the grpc.ServiceDesc for ReplicationService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var ReplicationService_ServiceDesc = grpc.ServiceDesc{
	ServiceName: "kevo.ReplicationService",
	HandlerType: (*ReplicationServiceServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "RegisterReplica",
			Handler:    _ReplicationService_RegisterReplica_Handler,
		},
		{
			MethodName: "ReplicaHeartbeat",
			Handler:    _ReplicationService_ReplicaHeartbeat_Handler,
		},
		{
			MethodName: "GetReplicaStatus",
			Handler:    _ReplicationService_GetReplicaStatus_Handler,
		},
		{
			MethodName: "ListReplicas",
			Handler:    _ReplicationService_ListReplicas_Handler,
		},
		{
			MethodName: "GetWALEntries",
			Handler:    _ReplicationService_GetWALEntries_Handler,
		},
		{
			MethodName: "ReportAppliedEntries",
			Handler:    _ReplicationService_ReportAppliedEntries_Handler,
		},
	},
	Streams: []grpc.StreamDesc{
		{
			StreamName:    "StreamWALEntries",
			Handler:       _ReplicationService_StreamWALEntries_Handler,
			ServerStreams: true,
		},
		{
			StreamName:    "RequestBootstrap",
			Handler:       _ReplicationService_RequestBootstrap_Handler,
			ServerStreams: true,
		},
	},
	Metadata: "proto/kevo/replication.proto",
}
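As a usage note (not from this diff), a concrete server embeds UnimplementedReplicationServiceServer, overrides the methods it supports, and registers itself via RegisterReplicationServiceServer. A rough sketch, with the listen address and the wiring to the actual replication manager left as assumptions:

package main

import (
	"context"
	"log"
	"net"

	"google.golang.org/grpc"

	"github.com/KevoDB/kevo/proto/kevo"
)

// primaryServer is a hypothetical implementation; only RegisterReplica is
// shown, and the remaining RPCs fall back to the embedded Unimplemented stubs.
type primaryServer struct {
	kevo.UnimplementedReplicationServiceServer
}

func (s *primaryServer) RegisterReplica(ctx context.Context, req *kevo.RegisterReplicaRequest) (*kevo.RegisterReplicaResponse, error) {
	log.Printf("replica %s registered from %s", req.ReplicaId, req.Address)
	return &kevo.RegisterReplicaResponse{Success: true}, nil
}

func main() {
	lis, err := net.Listen("tcp", ":50051") // address is illustrative
	if err != nil {
		log.Fatalf("listen: %v", err)
	}
	s := grpc.NewServer()
	kevo.RegisterReplicationServiceServer(s, &primaryServer{})
	if err := s.Serve(lis); err != nil {
		log.Fatalf("serve: %v", err)
	}
}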
@@ -4,7 +4,7 @@
 // protoc v3.20.3
 // source: proto/kevo/service.proto
 
-package proto
+package kevo
 
 import (
 	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@@ -1911,7 +1911,7 @@ const file_proto_kevo_service_proto_rawDesc = "" +
 	"\bTxDelete\x12\x15.kevo.TxDeleteRequest\x1a\x16.kevo.TxDeleteResponse\x125\n" +
 	"\x06TxScan\x12\x13.kevo.TxScanRequest\x1a\x14.kevo.TxScanResponse0\x01\x129\n" +
 	"\bGetStats\x12\x15.kevo.GetStatsRequest\x1a\x16.kevo.GetStatsResponse\x126\n" +
-	"\aCompact\x12\x14.kevo.CompactRequest\x1a\x15.kevo.CompactResponseB5Z3github.com/jeremytregunna/kevo/pkg/grpc/proto;protob\x06proto3"
+	"\aCompact\x12\x14.kevo.CompactRequest\x1a\x15.kevo.CompactResponseB#Z!github.com/KevoDB/kevo/proto/kevob\x06proto3"
 
 var (
 	file_proto_kevo_service_proto_rawDescOnce sync.Once
@@ -2,7 +2,7 @@ syntax = "proto3";
 
 package kevo;
 
-option go_package = "github.com/jeremytregunna/kevo/pkg/grpc/proto;proto";
+option go_package = "github.com/KevoDB/kevo/proto/kevo";
 
 service KevoService {
   // Key-Value Operations
@@ -4,7 +4,7 @@
 // - protoc v3.20.3
 // source: proto/kevo/service.proto
 
-package proto
+package kevo
 
 import (
 	context "context"
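The remaining hunks only rename the generated package from proto to kevo and point go_package at github.com/KevoDB/kevo/proto/kevo, so existing callers of KevoService must switch their import path. A hedged before/after sketch; the client constructor names follow the standard protoc-gen-go-grpc convention for "service KevoService" and are not shown in this diff:

import (
	// before: pb "github.com/jeremytregunna/kevo/pkg/grpc/proto"
	pb "github.com/KevoDB/kevo/proto/kevo"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// newKevoClient dials a node and returns the generated KevoService client
// under the new package path; the caller is responsible for closing conn.
func newKevoClient(addr string) (pb.KevoServiceClient, *grpc.ClientConn, error) {
	conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, nil, err
	}
	return pb.NewKevoServiceClient(conn), conn, nil
}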