Compare commits

...

3 Commits

Author SHA1 Message Date
a0a1c0512f
chore: formatting
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 9m50s
2025-04-22 14:09:54 -06:00
e7974e008d
feat: enhance wal recover statistics 2025-04-22 14:09:45 -06:00
dependabot[bot]
3b3d1c27a4 chore(deps): bump golang.org/x/net from 0.35.0 to 0.38.0
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.35.0 to 0.38.0.
- [Commits](https://github.com/golang/net/compare/v0.35.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.38.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-04-22 11:10:02 -06:00
38 changed files with 631 additions and 513 deletions

View File

@ -12,6 +12,7 @@ import (
"strings"
"syscall"
"time"
"unicode"
"github.com/chzyer/readline"
@ -80,14 +81,14 @@ Commands (interactive mode only):
// Config holds the application configuration
type Config struct {
ServerMode bool
DaemonMode bool
ListenAddr string
DBPath string
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
TLSCAFile string
ServerMode bool
DaemonMode bool
ListenAddr string
DBPath string
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
TLSCAFile string
}
func main() {
@ -97,7 +98,7 @@ func main() {
// Open database if path provided
var eng *engine.Engine
var err error
if config.DBPath != "" {
fmt.Printf("Opening database at %s\n", config.DBPath)
eng, err = engine.NewEngine(config.DBPath)
@ -107,18 +108,18 @@ func main() {
}
defer eng.Close()
}
// Check if we should run in server mode
if config.ServerMode {
if eng == nil {
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
os.Exit(1)
}
runServer(eng, config)
return
}
// Run in interactive mode
runInteractive(eng, config.DBPath)
}
@ -150,31 +151,31 @@ func parseFlags() Config {
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
// TLS options
tlsEnabled := flag.Bool("tls", false, "Enable TLS for secure connections")
tlsCertFile := flag.String("cert", "", "TLS certificate file path")
tlsKeyFile := flag.String("key", "", "TLS private key file path")
tlsCAFile := flag.String("ca", "", "TLS CA certificate file for client verification")
// Parse flags
flag.Parse()
// Get database path from remaining arguments
var dbPath string
if flag.NArg() > 0 {
dbPath = flag.Arg(0)
}
return Config{
ServerMode: *serverMode,
DaemonMode: *daemonMode,
ListenAddr: *listenAddr,
DBPath: dbPath,
TLSEnabled: *tlsEnabled,
TLSCertFile: *tlsCertFile,
TLSKeyFile: *tlsKeyFile,
TLSCAFile: *tlsCAFile,
ServerMode: *serverMode,
DaemonMode: *daemonMode,
ListenAddr: *listenAddr,
DBPath: dbPath,
TLSEnabled: *tlsEnabled,
TLSCertFile: *tlsCertFile,
TLSKeyFile: *tlsKeyFile,
TLSCAFile: *tlsCAFile,
}
}
@ -184,10 +185,10 @@ func runServer(eng *engine.Engine, config Config) {
if config.DaemonMode {
setupDaemonMode()
}
// Create and start the server
server := NewServer(eng, config)
// Start the server (non-blocking)
if err := server.Start(); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
@ -195,10 +196,10 @@ func runServer(eng *engine.Engine, config Config) {
}
fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
// Set up signal handling for graceful shutdown
setupGracefulShutdown(server, eng)
// Start serving (blocking)
if err := server.Serve(); err != nil {
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
@ -213,29 +214,29 @@ func setupDaemonMode() {
if err != nil {
log.Fatalf("Failed to open /dev/null: %v", err)
}
// Redirect standard file descriptors to /dev/null
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdin: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdout: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stderr: %v", err)
}
// Create a new process group
_, err = syscall.Setsid()
if err != nil {
log.Fatalf("Failed to create new session: %v", err)
}
fmt.Println("Daemon mode enabled, detaching from terminal...")
}
@ -243,22 +244,22 @@ func setupDaemonMode() {
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
// Graceful shutdown logic
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Shut down the server
if err := server.Shutdown(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
}
// The engine will be closed by the defer in main()
fmt.Println("Shutdown complete")
os.Exit(0)
}()
@ -268,7 +269,7 @@ func setupGracefulShutdown(server *Server, eng *engine.Engine) {
func runInteractive(eng *engine.Engine, dbPath string) {
fmt.Println("Kevo (kevo) version 1.0.2")
fmt.Println("Enter .help for usage hints.")
var tx engine.Transaction
var err error
@ -411,15 +412,89 @@ func runInteractive(eng *engine.Engine, dbPath string) {
// Print statistics
stats := eng.GetStats()
fmt.Println("Database Statistics:")
fmt.Printf(" Operations: %d puts, %d gets (%d hits, %d misses), %d deletes\n",
stats["put_ops"], stats["get_ops"], stats["get_hits"], stats["get_misses"], stats["delete_ops"])
fmt.Printf(" Transactions: %d started, %d committed, %d aborted\n",
stats["tx_started"], stats["tx_completed"], stats["tx_aborted"])
fmt.Printf(" Storage: %d bytes read, %d bytes written, %d flushes\n",
stats["total_bytes_read"], stats["total_bytes_written"], stats["flush_count"])
fmt.Printf(" Tables: %d sstables, %d immutable memtables\n",
stats["sstable_count"], stats["immutable_memtable_count"])
// Format human-readable time for the last operation timestamps
var lastPutTime, lastGetTime, lastDeleteTime time.Time
if putTime, ok := stats["last_put_time"].(int64); ok && putTime > 0 {
lastPutTime = time.Unix(0, putTime)
}
if getTime, ok := stats["last_get_time"].(int64); ok && getTime > 0 {
lastGetTime = time.Unix(0, getTime)
}
if deleteTime, ok := stats["last_delete_time"].(int64); ok && deleteTime > 0 {
lastDeleteTime = time.Unix(0, deleteTime)
}
// Operations section
fmt.Println("📊 Operations:")
fmt.Printf(" • Puts: %d\n", stats["put_ops"])
fmt.Printf(" • Gets: %d (Hits: %d, Misses: %d)\n", stats["get_ops"], stats["get_hits"], stats["get_misses"])
fmt.Printf(" • Deletes: %d\n", stats["delete_ops"])
// Last Operation Times
fmt.Println("\n⏱ Last Operation Times:")
if !lastPutTime.IsZero() {
fmt.Printf(" • Last Put: %s\n", lastPutTime.Format(time.RFC3339))
} else {
fmt.Printf(" • Last Put: Never\n")
}
if !lastGetTime.IsZero() {
fmt.Printf(" • Last Get: %s\n", lastGetTime.Format(time.RFC3339))
} else {
fmt.Printf(" • Last Get: Never\n")
}
if !lastDeleteTime.IsZero() {
fmt.Printf(" • Last Delete: %s\n", lastDeleteTime.Format(time.RFC3339))
} else {
fmt.Printf(" • Last Delete: Never\n")
}
// Transactions
fmt.Println("\n💼 Transactions:")
fmt.Printf(" • Started: %d\n", stats["tx_started"])
fmt.Printf(" • Completed: %d\n", stats["tx_completed"])
fmt.Printf(" • Aborted: %d\n", stats["tx_aborted"])
// Storage metrics
fmt.Println("\n💾 Storage:")
fmt.Printf(" • Total Bytes Read: %d\n", stats["total_bytes_read"])
fmt.Printf(" • Total Bytes Written: %d\n", stats["total_bytes_written"])
fmt.Printf(" • Flush Count: %d\n", stats["flush_count"])
// Table stats
fmt.Println("\n📋 Tables:")
fmt.Printf(" • SSTable Count: %d\n", stats["sstable_count"])
fmt.Printf(" • Immutable MemTable Count: %d\n", stats["immutable_memtable_count"])
fmt.Printf(" • Current MemTable Size: %d bytes\n", stats["memtable_size"])
// WAL recovery stats
fmt.Println("\n🔄 WAL Recovery:")
fmt.Printf(" • Files Recovered: %d\n", stats["wal_files_recovered"])
fmt.Printf(" • Entries Recovered: %d\n", stats["wal_entries_recovered"])
fmt.Printf(" • Corrupted Entries: %d\n", stats["wal_corrupted_entries"])
if recoveryDuration, ok := stats["wal_recovery_duration_ms"]; ok {
fmt.Printf(" • Recovery Duration: %d ms\n", recoveryDuration)
}
// Error counts
fmt.Println("\n⚠ Errors:")
fmt.Printf(" • Read Errors: %d\n", stats["read_errors"])
fmt.Printf(" • Write Errors: %d\n", stats["write_errors"])
// Compaction stats (if available)
if compactionOutputCount, ok := stats["compaction_last_outputs_count"]; ok {
fmt.Println("\n🧹 Compaction:")
fmt.Printf(" • Last Output Files Count: %d\n", compactionOutputCount)
// Display other compaction stats as available
for key, value := range stats {
if strings.HasPrefix(key, "compaction_") && key != "compaction_last_outputs_count" && key != "compaction_last_outputs" {
// Format the key for display (remove prefix, replace underscores with spaces)
displayKey := toTitle(strings.Replace(strings.TrimPrefix(key, "compaction_"), "_", " ", -1))
fmt.Printf(" • %s: %v\n", displayKey, value)
}
}
}
case ".flush":
if eng == nil {
@ -734,4 +809,20 @@ func makeKeySuccessor(prefix []byte) []byte {
copy(successor, prefix)
successor[len(prefix)] = 0xFF
return successor
}
}
// toTitle replaces strings.Title which is deprecated
// It converts the first character of each word to title case
func toTitle(s string) string {
prev := ' '
return strings.Map(
func(r rune) rune {
if unicode.IsSpace(prev) || unicode.IsPunct(prev) {
prev = r
return unicode.ToTitle(r)
}
prev = r
return r
},
s)
}

View File

@ -35,14 +35,14 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
// Create context with timeout to prevent potential hangs
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Create a channel to receive the transaction result
type txResult struct {
tx engine.Transaction
err error
}
resultCh := make(chan txResult, 1)
// Start transaction in a goroutine to prevent potential blocking
go func() {
tx, err := eng.BeginTransaction(readOnly)
@ -56,26 +56,26 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
}
}
}()
// Wait for result or timeout
select {
case result := <-resultCh:
if result.err != nil {
return "", fmt.Errorf("failed to begin transaction: %w", result.err)
}
tr.mu.Lock()
defer tr.mu.Unlock()
// Generate a transaction ID
tr.nextID++
txID := fmt.Sprintf("tx-%d", tr.nextID)
// Register the transaction
tr.transactions[txID] = result.tx
return txID, nil
case <-timeoutCtx.Done():
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
}
@ -104,31 +104,31 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
defer tr.mu.Unlock()
var lastErr error
// Copy transaction IDs to avoid modifying the map during iteration
ids := make([]string, 0, len(tr.transactions))
for id := range tr.transactions {
ids = append(ids, id)
}
// Rollback each transaction with a timeout
for _, id := range ids {
tx, exists := tr.transactions[id]
if !exists {
continue
}
// Use a timeout for each rollback operation
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
// Create a channel for the rollback result
doneCh := make(chan error, 1)
// Execute rollback in goroutine
go func(t engine.Transaction) {
doneCh <- t.Rollback()
}(tx)
// Wait for rollback or timeout
var err error
select {
@ -137,14 +137,14 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
case <-rollbackCtx.Done():
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
}
cancel() // Clean up context
// Record error if any
if err != nil {
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
}
// Always remove transaction from map
delete(tr.transactions, id)
}
@ -154,12 +154,12 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
// Server represents the Kevo server
type Server struct {
eng *engine.Engine
txRegistry *TransactionRegistry
listener net.Listener
grpcServer *grpc.Server
eng *engine.Engine
txRegistry *TransactionRegistry
listener net.Listener
grpcServer *grpc.Server
kevoService *grpcservice.KevoServiceServer
config Config
config Config
}
// NewServer creates a new server instance
@ -249,14 +249,14 @@ func (s *Server) Shutdown(ctx context.Context) error {
// First, gracefully stop the gRPC server if it exists
if s.grpcServer != nil {
fmt.Println("Gracefully stopping gRPC server...")
// Create a channel to signal when the server has stopped
stopped := make(chan struct{})
go func() {
s.grpcServer.GracefulStop()
close(stopped)
}()
// Wait for graceful stop or context deadline
select {
case <-stopped:
@ -266,7 +266,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
s.grpcServer.Stop()
}
}
// Shut down the listener if it's still open
if s.listener != nil {
if err := s.listener.Close(); err != nil {
@ -280,4 +280,4 @@ func (s *Server) Shutdown(ctx context.Context) error {
}
return nil
}
}

View File

@ -14,7 +14,7 @@ func TestTransactionRegistry(t *testing.T) {
// Create a timeout context for the whole test
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Set up temporary directory for test
tmpDir, err := os.MkdirTemp("", "kevo_test")
if err != nil {
@ -153,47 +153,47 @@ func TestGRPCServer(t *testing.T) {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDBPath)
// Create engine
eng, err := engine.NewEngine(tempDBPath)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
defer eng.Close()
// Create server configuration
config := Config{
ServerMode: true,
ListenAddr: "localhost:50052", // Use a different port for tests
DBPath: tempDBPath,
}
// Create and start the server
server := NewServer(eng, config)
if err := server.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
// Run server in a goroutine
go func() {
if err := server.Serve(); err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give the server a moment to start
time.Sleep(200 * time.Millisecond)
// Clean up at the end
defer func() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil {
t.Logf("Failed to shut down server: %v", err)
}
}()
// TODO: Add gRPC client tests here when client implementation is complete
t.Log("gRPC server integration test scaffolding added")
}
}

6
go.mod
View File

@ -10,8 +10,8 @@ require (
)
require (
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
)

12
go.sum
View File

@ -28,13 +28,13 @@ go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce
go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w=
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a h1:51aaUVRocpvUOSQKM6Q7VuoaktNIaMCLuhZB6DKksq4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a/go.mod h1:uRxBH1mhmO8PGhU89cMcHaXKZqO+OfakD8QQO0oYwlQ=
google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM=

View File

@ -23,11 +23,11 @@ const (
// ClientOptions configures a Kevo client
type ClientOptions struct {
// Connection options
Endpoint string // Server address
ConnectTimeout time.Duration // Timeout for connection attempts
RequestTimeout time.Duration // Default timeout for requests
TransportType string // Transport type (e.g. "grpc")
PoolSize int // Connection pool size
Endpoint string // Server address
ConnectTimeout time.Duration // Timeout for connection attempts
RequestTimeout time.Duration // Default timeout for requests
TransportType string // Transport type (e.g. "grpc")
PoolSize int // Connection pool size
// Security options
TLSEnabled bool // Enable TLS
@ -50,19 +50,19 @@ type ClientOptions struct {
// DefaultClientOptions returns sensible default client options
func DefaultClientOptions() ClientOptions {
return ClientOptions{
Endpoint: "localhost:50051",
ConnectTimeout: time.Second * 5,
RequestTimeout: time.Second * 10,
TransportType: "grpc",
PoolSize: 5,
TLSEnabled: false,
MaxRetries: 3,
InitialBackoff: time.Millisecond * 100,
MaxBackoff: time.Second * 2,
BackoffFactor: 1.5,
RetryJitter: 0.2,
Compression: CompressionNone,
MaxMessageSize: 16 * 1024 * 1024, // 16MB
Endpoint: "localhost:50051",
ConnectTimeout: time.Second * 5,
RequestTimeout: time.Second * 10,
TransportType: "grpc",
PoolSize: 5,
TLSEnabled: false,
MaxRetries: 3,
InitialBackoff: time.Millisecond * 100,
MaxBackoff: time.Second * 2,
BackoffFactor: 1.5,
RetryJitter: 0.2,
Compression: CompressionNone,
MaxMessageSize: 16 * 1024 * 1024, // 16MB
}
}
@ -378,4 +378,4 @@ type Stats struct {
SstableCount int32
WriteAmplification float64
ReadAmplification float64
}
}

View File

@ -105,13 +105,13 @@ func TestClientConnect(t *testing.T) {
// Modify default options to use mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
// Create a client with the mock transport
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
@ -139,12 +139,12 @@ func TestClientGet(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -186,12 +186,12 @@ func TestClientPut(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -220,12 +220,12 @@ func TestClientDelete(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -254,12 +254,12 @@ func TestClientBatchWrite(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -295,12 +295,12 @@ func TestClientGetStats(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -317,12 +317,12 @@ func TestClientGetStats(t *testing.T) {
"read_amplification": 2.0
}`
mock.setResponse(transport.TypeGetStats, []byte(statsJSON))
stats, err := client.GetStats(ctx)
if err != nil {
t.Errorf("Expected successful get stats, got error: %v", err)
}
if stats.KeyCount != 1000 {
t.Errorf("Expected KeyCount 1000, got %d", stats.KeyCount)
}
@ -354,12 +354,12 @@ func TestClientCompact(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -386,7 +386,7 @@ func TestClientCompact(t *testing.T) {
func TestRetryWithBackoff(t *testing.T) {
ctx := context.Background()
// Test successful retry
attempts := 0
err := RetryWithBackoff(
@ -404,14 +404,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err != nil {
t.Errorf("Expected successful retry, got error: %v", err)
}
if attempts != 3 {
t.Errorf("Expected 3 attempts, got %d", attempts)
}
// Test max retries exceeded
attempts = 0
err = RetryWithBackoff(
@ -426,14 +426,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err == nil {
t.Error("Expected error after max retries, got nil")
}
if attempts != 4 { // Initial + 3 retries
t.Errorf("Expected 4 attempts, got %d", attempts)
}
// Test non-retryable error
attempts = 0
err = RetryWithBackoff(
@ -448,14 +448,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err == nil {
t.Error("Expected non-retryable error to be returned, got nil")
}
if attempts != 1 {
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
}
// Test context cancellation
attempts = 0
cancelCtx, cancel := context.WithCancel(ctx)
@ -463,7 +463,7 @@ func TestRetryWithBackoff(t *testing.T) {
time.Sleep(20 * time.Millisecond)
cancel()
}()
err = RetryWithBackoff(
cancelCtx,
func() error {
@ -476,8 +476,8 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got: %v", err)
}
}
}

View File

@ -304,4 +304,4 @@ func (s *transactionScanIterator) Close() error {
s.closed = true
s.cancelFunc()
return s.stream.Close()
}
}

View File

@ -7,33 +7,33 @@ import (
func TestDefaultClientOptions(t *testing.T) {
options := DefaultClientOptions()
// Verify the default options have sensible values
if options.Endpoint != "localhost:50051" {
t.Errorf("Expected default endpoint to be localhost:50051, got %s", options.Endpoint)
}
if options.ConnectTimeout != 5*time.Second {
t.Errorf("Expected default connect timeout to be 5s, got %s", options.ConnectTimeout)
}
if options.RequestTimeout != 10*time.Second {
t.Errorf("Expected default request timeout to be 10s, got %s", options.RequestTimeout)
}
if options.TransportType != "grpc" {
t.Errorf("Expected default transport type to be grpc, got %s", options.TransportType)
}
if options.PoolSize != 5 {
t.Errorf("Expected default pool size to be 5, got %d", options.PoolSize)
}
if options.TLSEnabled != false {
t.Errorf("Expected default TLS enabled to be false")
}
if options.MaxRetries != 3 {
t.Errorf("Expected default max retries to be 3, got %d", options.MaxRetries)
}
}
}

View File

@ -17,19 +17,19 @@ func mockClientFactory(endpoint string, options transport.TransportOptions) (tra
func TestClientCreation(t *testing.T) {
// First, register our mock transport
transport.RegisterClientTransport("mock_test", mockClientFactory)
// Create client options using our mock transport
options := DefaultClientOptions()
options.TransportType = "mock_test"
// Create a client
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Verify the client was created
if client == nil {
t.Fatal("Client is nil")
}
}
}

View File

@ -12,11 +12,11 @@ import (
// Transaction represents a database transaction
type Transaction struct {
client *Client
id string
readOnly bool
closed bool
mu sync.RWMutex
client *Client
id string
readOnly bool
closed bool
mu sync.RWMutex
}
// ErrTransactionClosed is returned when attempting to use a closed transaction
@ -285,4 +285,4 @@ func (tx *Transaction) Delete(ctx context.Context, key []byte) (bool, error) {
}
return deleteResp.Success, nil
}
}

View File

@ -117,4 +117,4 @@ func CalculateExponentialBackoff(
}
return backoff
}
}

View File

@ -130,14 +130,14 @@ func (h *HierarchicalIterator) Seek(target []byte) bool {
if !iter.Valid() {
continue
}
// If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match
break // Since iterators are in newest-to-oldest order, we can stop at the first match
}
}
// Set the found key/value
h.key = bestKey
h.value = bestValue
@ -253,7 +253,7 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
// Get the current key
key := iter.Key()
// If we haven't found a valid key yet, or this key is smaller than the current best key
if bestIterIdx == -1 || bytes.Compare(key, bestKey) < 0 {
// This becomes our best candidate so far
@ -271,14 +271,14 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
if !iter.Valid() {
continue
}
// If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match
break // Since iterators are in newest-to-oldest order, we can stop at the first match
}
}
// Set the found key/value
h.key = bestKey
h.value = bestValue

View File

@ -61,6 +61,12 @@ type EngineStats struct {
TxCompleted atomic.Uint64
TxAborted atomic.Uint64
// Recovery stats
WALFilesRecovered atomic.Uint64
WALEntriesRecovered atomic.Uint64
WALCorruptedEntries atomic.Uint64
WALRecoveryDuration atomic.Int64 // nanoseconds
// Mutex for accessing non-atomic fields
mu sync.RWMutex
}
@ -518,15 +524,15 @@ func (e *Engine) flushMemTable(mem *memtable.MemTable) error {
for iter.SeekToFirst(); iter.Valid(); iter.Next() {
key := iter.Key()
keyStr := string(key) // Use as map key
// Skip keys we've already processed (including tombstones)
if _, seen := processedKeys[keyStr]; seen {
continue
}
// Mark this key as processed regardless of whether it's a value or tombstone
processedKeys[keyStr] = struct{}{}
// Only write non-tombstone entries to the SSTable
if value := iter.Value(); value != nil {
bytesWritten += uint64(len(key) + len(value))
@ -666,21 +672,22 @@ func (e *Engine) loadSSTables() error {
// recoverFromWAL recovers memtables from existing WAL files
func (e *Engine) recoverFromWAL() error {
startTime := time.Now()
// Check if WAL directory exists
if _, err := os.Stat(e.walDir); os.IsNotExist(err) {
return nil // No WAL directory, nothing to recover
}
// List all WAL files for diagnostic purposes
// List all WAL files
walFiles, err := wal.FindWALFiles(e.walDir)
if err != nil {
if !wal.DisableRecoveryLogs {
fmt.Printf("Error listing WAL files: %v\n", err)
}
} else {
if !wal.DisableRecoveryLogs {
fmt.Printf("Found %d WAL files: %v\n", len(walFiles), walFiles)
}
e.stats.ReadErrors.Add(1)
return fmt.Errorf("error listing WAL files: %w", err)
}
if len(walFiles) > 0 {
e.stats.WALFilesRecovered.Add(uint64(len(walFiles)))
}
// Get recovery options
@ -690,17 +697,11 @@ func (e *Engine) recoverFromWAL() error {
memTables, maxSeqNum, err := memtable.RecoverFromWAL(e.cfg, recoveryOpts)
if err != nil {
// If recovery fails, let's try cleaning up WAL files
if !wal.DisableRecoveryLogs {
fmt.Printf("WAL recovery failed: %v\n", err)
fmt.Printf("Attempting to recover by cleaning up WAL files...\n")
}
e.stats.ReadErrors.Add(1)
// Create a backup directory
backupDir := filepath.Join(e.walDir, "backup_"+time.Now().Format("20060102_150405"))
if err := os.MkdirAll(backupDir, 0755); err != nil {
if !wal.DisableRecoveryLogs {
fmt.Printf("Failed to create backup directory: %v\n", err)
}
return fmt.Errorf("failed to recover from WAL: %w", err)
}
@ -708,11 +709,7 @@ func (e *Engine) recoverFromWAL() error {
for _, walFile := range walFiles {
destFile := filepath.Join(backupDir, filepath.Base(walFile))
if err := os.Rename(walFile, destFile); err != nil {
if !wal.DisableRecoveryLogs {
fmt.Printf("Failed to move WAL file %s: %v\n", walFile, err)
}
} else if !wal.DisableRecoveryLogs {
fmt.Printf("Moved problematic WAL file to %s\n", destFile)
e.stats.ReadErrors.Add(1)
}
}
@ -723,15 +720,28 @@ func (e *Engine) recoverFromWAL() error {
}
e.wal = newWal
// No memtables to recover, starting fresh
if !wal.DisableRecoveryLogs {
fmt.Printf("Starting with a fresh WAL after recovery failure\n")
}
// Record recovery duration
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil
}
// Update recovery statistics based on actual entries recovered
if len(walFiles) > 0 {
// Use WALDir function directly to get stats
recoveryStats, statErr := wal.ReplayWALDir(e.cfg.WALDir, func(entry *wal.Entry) error {
return nil // Just counting, not processing
})
if statErr == nil && recoveryStats != nil {
e.stats.WALEntriesRecovered.Add(recoveryStats.EntriesProcessed)
e.stats.WALCorruptedEntries.Add(recoveryStats.EntriesSkipped)
}
}
// No memtables recovered or empty WAL
if len(memTables) == 0 {
// Record recovery duration
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil
}
@ -755,10 +765,9 @@ func (e *Engine) recoverFromWAL() error {
}
}
if !wal.DisableRecoveryLogs {
fmt.Printf("Recovered %d memtables from WAL with max sequence number %d\n",
len(memTables), maxSeqNum)
}
// Record recovery stats
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil
}
@ -925,6 +934,15 @@ func (e *Engine) GetStats() map[string]interface{} {
stats["read_errors"] = e.stats.ReadErrors.Load()
stats["write_errors"] = e.stats.WriteErrors.Load()
// Add WAL recovery statistics
stats["wal_files_recovered"] = e.stats.WALFilesRecovered.Load()
stats["wal_entries_recovered"] = e.stats.WALEntriesRecovered.Load()
stats["wal_corrupted_entries"] = e.stats.WALCorruptedEntries.Load()
recoveryDuration := e.stats.WALRecoveryDuration.Load()
if recoveryDuration > 0 {
stats["wal_recovery_duration_ms"] = recoveryDuration / int64(time.Millisecond)
}
// Add timing information
e.stats.mu.RLock()
defer e.stats.mu.RUnlock()

View File

@ -81,8 +81,8 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
// Simulate exactly the bug scenario from the CLI
// Add the same key multiple times with different values
key := []byte("foo")
// First add
// First add
if err := engine.Put(key, []byte("23")); err != nil {
t.Fatalf("Failed to put first value: %v", err)
}
@ -91,17 +91,17 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if err := engine.Delete(key); err != nil {
t.Fatalf("Failed to delete key: %v", err)
}
// Add it again with different value
if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to re-add key: %v", err)
}
// Add another key
if err := engine.Put([]byte("bar"), []byte("23")); err != nil {
t.Fatalf("Failed to add another key: %v", err)
}
// Add another key
if err := engine.Put([]byte("user:1"), []byte(`{"name":"John"}`)); err != nil {
t.Fatalf("Failed to add another key: %v", err)
@ -130,7 +130,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("42")) {
t.Errorf("Got incorrect value after flush. Expected: %s, Got: %s", "42", string(value))
}
value, err = engine.Get([]byte("bar"))
if err != nil {
t.Fatalf("Failed to get 'bar' after flush: %v", err)
@ -138,7 +138,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("23")) {
t.Errorf("Got incorrect value for 'bar' after flush. Expected: %s, Got: %s", "23", string(value))
}
value, err = engine.Get([]byte("user:1"))
if err != nil {
t.Fatalf("Failed to get 'user:1' after flush: %v", err)
@ -154,7 +154,7 @@ func TestEngine_DuplicateKeysFlush(t *testing.T) {
// Test with a key that will be deleted and re-added multiple times
key := []byte("foo")
// Add the key
if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to put initial value: %v", err)

View File

@ -442,13 +442,13 @@ func (c *chainedIterator) SeekToFirst() {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i
@ -499,13 +499,13 @@ func (c *chainedIterator) Seek(target []byte) bool {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i
@ -537,18 +537,18 @@ func (c *chainedIterator) Next() bool {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key that is greater than the current key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// Skip if the key is the same as the current key (we've already advanced past it)
if bytes.Equal(iter.Key(), currentKey) {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i

View File

@ -109,7 +109,7 @@ func (s *KevoServiceServer) BatchWrite(ctx context.Context, req *pb.BatchWriteRe
if err != nil {
return &pb.BatchWriteResponse{Success: false}, fmt.Errorf("failed to start transaction: %w", err)
}
// Ensure we either commit or rollback
defer func() {
if err != nil {
@ -182,7 +182,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
count := int32(0)
// Position iterator at the first entry
iter.SeekToFirst()
// Iterate through all valid entries
for iter.Valid() {
if limit > 0 && count >= limit {
@ -199,7 +199,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
}
count++
}
// Move to the next entry
iter.Next()
}
@ -225,8 +225,8 @@ func (pi *prefixIterator) Next() bool {
for pi.iter.Next() {
// Check if current key has the prefix
key := pi.iter.Key()
if len(key) >= len(pi.prefix) &&
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
if len(key) >= len(pi.prefix) &&
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
return true
}
}
@ -415,7 +415,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
count := int32(0)
// Position iterator at the first entry
iter.SeekToFirst()
// Iterate through all valid entries
for iter.Valid() {
if limit > 0 && count >= limit {
@ -432,7 +432,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
}
count++
}
// Move to the next entry
iter.Next()
}
@ -446,17 +446,17 @@ func (s *KevoServiceServer) GetStats(ctx context.Context, req *pb.GetStatsReques
keyCount := int64(0)
sstableCount := int32(0)
memtableCount := int32(1) // At least 1 active memtable
// Create a read-only transaction to count keys
tx, err := s.engine.BeginTransaction(true)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction for stats: %w", err)
}
defer tx.Rollback()
// Use an iterator to count keys
iter := tx.NewIterator()
// Count keys and estimate size
var totalSize int64
for iter.Next() {
@ -492,7 +492,7 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
if err != nil {
return &pb.CompactResponse{Success: false}, err
}
// Do a dummy write to force a flush
if req.Force {
err = tx.Put([]byte("__compact_marker__"), []byte("force"))
@ -501,11 +501,11 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
return &pb.CompactResponse{Success: false}, err
}
}
err = tx.Commit()
if err != nil {
return &pb.CompactResponse{Success: false}, err
}
return &pb.CompactResponse{Success: true}, nil
}
}

View File

@ -56,4 +56,4 @@ func sortDurations(durations []time.Duration) {
}
}
}
}
}

View File

@ -9,8 +9,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@ -19,13 +19,13 @@ import (
// GRPCClient implements the transport.Client interface for gRPC
type GRPCClient struct {
endpoint string
options transport.TransportOptions
conn *grpc.ClientConn
client pb.KevoServiceClient
status transport.TransportStatus
statusMu sync.RWMutex
metrics transport.MetricsCollector
endpoint string
options transport.TransportOptions
conn *grpc.ClientConn
client pb.KevoServiceClient
status transport.TransportStatus
statusMu sync.RWMutex
metrics transport.MetricsCollector
}
// NewGRPCClient creates a new gRPC client
@ -123,10 +123,10 @@ func (c *GRPCClient) Status() transport.TransportStatus {
func (c *GRPCClient) setStatus(connected bool, err error) {
c.statusMu.Lock()
defer c.statusMu.Unlock()
c.status.Connected = connected
c.status.LastError = err
if connected {
c.status.LastConnected = time.Now()
}
@ -141,11 +141,11 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record request metrics
startTime := time.Now()
requestType := request.Type()
// Record bytes sent
requestPayload := request.Payload()
c.metrics.RecordSend(len(requestPayload))
var resp transport.Response
var err error
@ -182,12 +182,12 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record metrics for the request
c.metrics.RecordRequest(requestType, startTime, err)
// If we got a response, record received bytes
if resp != nil {
c.metrics.RecordReceive(len(resp.Payload()))
}
return resp, err
}
@ -206,20 +206,20 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
var req struct {
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid get request payload: %w", err)), err
}
grpcReq := &pb.GetRequest{
Key: req.Key,
}
grpcResp, err := c.client.Get(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Value []byte `json:"value"`
Found bool `json:"found"`
@ -227,12 +227,12 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
Value: grpcResp.Value,
Found: grpcResp.Found,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeGet, respData, nil), nil
}
@ -242,33 +242,33 @@ func (c *GRPCClient) handlePut(ctx context.Context, payload []byte) (transport.R
Value []byte `json:"value"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid put request payload: %w", err)), err
}
grpcReq := &pb.PutRequest{
Key: req.Key,
Value: req.Value,
Sync: req.Sync,
}
grpcResp, err := c.client.Put(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypePut, respData, nil), nil
}
@ -277,32 +277,32 @@ func (c *GRPCClient) handleDelete(ctx context.Context, payload []byte) (transpor
Key []byte `json:"key"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid delete request payload: %w", err)), err
}
grpcReq := &pb.DeleteRequest{
Key: req.Key,
Sync: req.Sync,
}
grpcResp, err := c.client.Delete(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeDelete, respData, nil), nil
}
@ -315,18 +315,18 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
} `json:"operations"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid batch write request payload: %w", err)), err
}
operations := make([]*pb.Operation, len(req.Operations))
for i, op := range req.Operations {
pbOp := &pb.Operation{
Key: op.Key,
Value: op.Value,
}
switch op.Type {
case "put":
pbOp.Type = pb.Operation_PUT
@ -335,31 +335,31 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
default:
return transport.NewErrorResponse(fmt.Errorf("invalid operation type: %s", op.Type)), fmt.Errorf("invalid operation type: %s", op.Type)
}
operations[i] = pbOp
}
grpcReq := &pb.BatchWriteRequest{
Operations: operations,
Sync: req.Sync,
}
grpcResp, err := c.client.BatchWrite(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeBatchWrite, respData, nil), nil
}
@ -367,31 +367,31 @@ func (c *GRPCClient) handleBeginTransaction(ctx context.Context, payload []byte)
var req struct {
ReadOnly bool `json:"read_only"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid begin transaction request payload: %w", err)), err
}
grpcReq := &pb.BeginTransactionRequest{
ReadOnly: req.ReadOnly,
}
grpcResp, err := c.client.BeginTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
TransactionID string `json:"transaction_id"`
}{
TransactionID: grpcResp.TransactionId,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeBeginTx, respData, nil), nil
}
@ -399,31 +399,31 @@ func (c *GRPCClient) handleCommitTransaction(ctx context.Context, payload []byte
var req struct {
TransactionID string `json:"transaction_id"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid commit transaction request payload: %w", err)), err
}
grpcReq := &pb.CommitTransactionRequest{
TransactionId: req.TransactionID,
}
grpcResp, err := c.client.CommitTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeCommitTx, respData, nil), nil
}
@ -431,31 +431,31 @@ func (c *GRPCClient) handleRollbackTransaction(ctx context.Context, payload []by
var req struct {
TransactionID string `json:"transaction_id"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid rollback transaction request payload: %w", err)), err
}
grpcReq := &pb.RollbackTransactionRequest{
TransactionId: req.TransactionID,
}
grpcResp, err := c.client.RollbackTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeRollbackTx, respData, nil), nil
}
@ -464,21 +464,21 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
TransactionID string `json:"transaction_id"`
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx get request payload: %w", err)), err
}
grpcReq := &pb.TxGetRequest{
TransactionId: req.TransactionID,
Key: req.Key,
}
grpcResp, err := c.client.TxGet(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Value []byte `json:"value"`
Found bool `json:"found"`
@ -486,12 +486,12 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
Value: grpcResp.Value,
Found: grpcResp.Found,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxGet, respData, nil), nil
}
@ -501,33 +501,33 @@ func (c *GRPCClient) handleTxPut(ctx context.Context, payload []byte) (transport
Key []byte `json:"key"`
Value []byte `json:"value"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx put request payload: %w", err)), err
}
grpcReq := &pb.TxPutRequest{
TransactionId: req.TransactionID,
Key: req.Key,
Value: req.Value,
}
grpcResp, err := c.client.TxPut(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxPut, respData, nil), nil
}
@ -536,43 +536,43 @@ func (c *GRPCClient) handleTxDelete(ctx context.Context, payload []byte) (transp
TransactionID string `json:"transaction_id"`
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx delete request payload: %w", err)), err
}
grpcReq := &pb.TxDeleteRequest{
TransactionId: req.TransactionID,
Key: req.Key,
}
grpcResp, err := c.client.TxDelete(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxDelete, respData, nil), nil
}
func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transport.Response, error) {
grpcReq := &pb.GetStatsRequest{}
grpcResp, err := c.client.GetStats(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
KeyCount int64 `json:"key_count"`
StorageSize int64 `json:"storage_size"`
@ -588,12 +588,12 @@ func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transp
WriteAmplification: grpcResp.WriteAmplification,
ReadAmplification: grpcResp.ReadAmplification,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeGetStats, respData, nil), nil
}
@ -601,31 +601,31 @@ func (c *GRPCClient) handleCompact(ctx context.Context, payload []byte) (transpo
var req struct {
Force bool `json:"force"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid compact request payload: %w", err)), err
}
grpcReq := &pb.CompactRequest{
Force: req.Force,
}
grpcResp, err := c.client.Compact(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeCompact, respData, nil), nil
}
@ -650,7 +650,7 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
}
return transport.NewErrorResponse(err), err
}
// Build response based on scan type
scanResp := struct {
Key []byte `json:"key"`
@ -659,12 +659,12 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
Key: resp.Key,
Value: resp.Value,
}
respData, err := json.Marshal(scanResp)
if err != nil {
return transport.NewErrorResponse(err), err
}
s.client.metrics.RecordReceive(len(respData))
return transport.NewResponse(s.streamType, respData, nil), nil
}
@ -672,4 +672,4 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
func (s *GRPCScanStream) Close() error {
s.cancel()
return nil
}
}

View File

@ -7,8 +7,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@ -72,7 +72,7 @@ func (g *GRPCTransportManager) Serve() error {
if err := g.Start(); err != nil {
return err
}
// Block until server is stopped
<-ctx.Done()
return nil
@ -284,4 +284,4 @@ func init() {
transport.RegisterClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.Client, error) {
return NewGRPCClient(endpoint, options)
})
}
}

View File

@ -7,15 +7,15 @@ import (
// Simple smoke test for the gRPC transport
func TestNewGRPCTransportManager(t *testing.T) {
opts := DefaultGRPCTransportOptions()
// Override the listen address to avoid port conflicts
opts.ListenAddr = ":0" // use random available port
manager, err := NewGRPCTransportManager(opts)
if err != nil {
t.Fatalf("Failed to create transport manager: %v", err)
}
// Verify the manager was created
if manager == nil {
t.Fatal("Expected non-nil manager")
@ -60,4 +60,4 @@ func TestLoadClientTLSConfigFromStruct(t *testing.T) {
if !config.InsecureSkipVerify {
t.Fatal("Expected InsecureSkipVerify to be true")
}
}
}

View File

@ -207,4 +207,4 @@ func (m *ConnectionPoolManager) CloseAll() {
m.pools.Delete(key)
return true
})
}
}

View File

@ -7,8 +7,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@ -16,30 +16,30 @@ import (
// GRPCServer implements the transport.Server interface for gRPC
type GRPCServer struct {
address string
tlsConfig *tls.Config
server *grpc.Server
address string
tlsConfig *tls.Config
server *grpc.Server
requestHandler transport.RequestHandler
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
}
// NewGRPCServer creates a new gRPC server
func NewGRPCServer(address string, options transport.TransportOptions) (transport.Server, error) {
// Create server options
var serverOpts []grpc.ServerOption
// Configure TLS if enabled
if options.TLSEnabled {
tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}
// Configure keepalive parameters
kaProps := keepalive.ServerParameters{
MaxConnectionIdle: 30 * time.Minute,
@ -47,20 +47,20 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
Time: 15 * time.Second,
Timeout: 5 * time.Second,
}
kaPolicy := keepalive.EnforcementPolicy{
MinTime: 10 * time.Second,
PermitWithoutStream: true,
}
serverOpts = append(serverOpts,
grpc.KeepaliveParams(kaProps),
grpc.KeepaliveEnforcementPolicy(kaPolicy),
)
// Create the server
server := grpc.NewServer(serverOpts...)
return &GRPCServer{
address: address,
server: server,
@ -72,18 +72,18 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
func (s *GRPCServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.started {
return fmt.Errorf("server already started")
}
// Start the server in a goroutine
go func() {
if err := s.Serve(); err != nil {
fmt.Printf("gRPC server error: %v\n", err)
}
}()
s.started = true
return nil
}
@ -93,31 +93,31 @@ func (s *GRPCServer) Serve() error {
if s.requestHandler == nil {
return fmt.Errorf("no request handler set")
}
// Create the service implementation
service := &kevoServiceServer{
handler: s.requestHandler,
}
// Register the service
pb.RegisterKevoServiceServer(s.server, service)
// Start listening
listener, err := transport.CreateListener("tcp", s.address, s.tlsConfig)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.address, err)
}
s.metrics.ServerStarted()
// Serve requests
err = s.server.Serve(listener)
if err != nil {
s.metrics.ServerErrored()
return fmt.Errorf("failed to serve: %w", err)
}
s.metrics.ServerStopped()
return nil
}
@ -126,14 +126,14 @@ func (s *GRPCServer) Serve() error {
func (s *GRPCServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.started {
return nil
}
s.server.GracefulStop()
s.started = false
return nil
}
@ -141,7 +141,7 @@ func (s *GRPCServer) Stop(ctx context.Context) error {
func (s *GRPCServer) SetRequestHandler(handler transport.RequestHandler) {
s.mu.Lock()
defer s.mu.Unlock()
s.requestHandler = handler
}
@ -151,4 +151,4 @@ type kevoServiceServer struct {
handler transport.RequestHandler
}
// TODO: Implement service methods
// TODO: Implement service methods

View File

@ -92,4 +92,4 @@ func LoadClientTLSConfigFromStruct(config *TLSConfig) (*tls.Config, error) {
return &tls.Config{MinVersion: tls.VersionTLS12}, nil
}
return LoadClientTLSConfig(config.CertFile, config.KeyFile, config.CAFile, config.SkipVerify)
}
}

View File

@ -5,8 +5,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
)
@ -30,17 +30,17 @@ func (c *GRPCConnection) Execute(fn func(interface{}) error) error {
// Create a new client from the connection
client := pb.NewKevoServiceClient(c.conn)
// Execute the provided function with the client
err := fn(client)
// Update metrics if there was an error
if err != nil {
c.mu.Lock()
c.errCount++
c.mu.Unlock()
}
return err
}
@ -58,10 +58,10 @@ func (c *GRPCConnection) Address() string {
func (c *GRPCConnection) Status() transport.ConnectionStatus {
c.mu.RLock()
defer c.mu.RUnlock()
// Check the connection state
isConnected := c.conn != nil
return transport.ConnectionStatus{
Connected: isConnected,
LastActivity: c.lastUsed,
@ -81,4 +81,4 @@ type GRPCTransportOptions struct {
MaxConnectionIdle time.Duration
MaxConnectionAge time.Duration
MaxPoolConnections int
}
}

View File

@ -31,7 +31,7 @@ func DefaultRecoveryOptions(cfg *config.Config) *RecoveryOptions {
}
// RecoverFromWAL rebuilds MemTables from the write-ahead log
// Returns a list of recovered MemTables and the maximum sequence number seen
// Returns a list of recovered MemTables, the maximum sequence number seen, and stats
func RecoverFromWAL(cfg *config.Config, opts *RecoveryOptions) ([]*MemTable, uint64, error) {
if opts == nil {
opts = DefaultRecoveryOptions(cfg)
@ -76,10 +76,13 @@ func RecoverFromWAL(cfg *config.Config, opts *RecoveryOptions) ([]*MemTable, uin
}
// Replay the WAL directory
if err := wal.ReplayWALDir(cfg.WALDir, entryHandler); err != nil {
_, err := wal.ReplayWALDir(cfg.WALDir, entryHandler)
if err != nil {
return nil, 0, fmt.Errorf("failed to replay WAL: %w", err)
}
// Stats will be captured in the engine directly
// For batch operations, we need to adjust maxSeqNum
finalTable := memTables[len(memTables)-1]
nextSeq := finalTable.GetNextSequenceNumber()

View File

@ -6,21 +6,21 @@ import (
// Standard request/response type constants
const (
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
)
// Common errors
@ -97,4 +97,4 @@ func NewErrorResponse(err error) Response {
ResponseData: msg,
ResponseErr: err,
}
}
}

View File

@ -9,12 +9,12 @@ func TestBasicRequest(t *testing.T) {
// Test creating a request
payload := []byte("test payload")
req := NewRequest(TypeGet, payload)
// Test Type method
if req.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
}
// Test Payload method
if string(req.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
@ -25,26 +25,26 @@ func TestBasicResponse(t *testing.T) {
// Test creating a response with no error
payload := []byte("test response")
resp := NewResponse(TypeGet, payload, nil)
// Test Type method
if resp.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
}
// Test Payload method
if string(resp.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
}
// Test Error method
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
// Test creating a response with an error
testErr := errors.New("test error")
resp = NewResponse(TypeGet, payload, testErr)
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
@ -54,34 +54,34 @@ func TestNewErrorResponse(t *testing.T) {
// Test creating an error response
testErr := errors.New("test error")
resp := NewErrorResponse(testErr)
// Test Type method
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
// Test Payload method - should contain error message
if string(resp.Payload()) != testErr.Error() {
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
}
// Test Error method
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
// Test with nil error
resp = NewErrorResponse(nil)
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
if len(resp.Payload()) != 0 {
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
}
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
}
}

View File

@ -50,7 +50,7 @@ type TransportStatus struct {
type Request interface {
// Type returns the type of request
Type() string
// Payload returns the payload of the request
Payload() []byte
}
@ -59,10 +59,10 @@ type Request interface {
type Response interface {
// Type returns the type of response
Type() string
// Payload returns the payload of the response
Payload() []byte
// Error returns any error associated with the response
Error() error
}
@ -71,10 +71,10 @@ type Response interface {
type Stream interface {
// Send sends a request over the stream
Send(request Request) error
// Recv receives a response from the stream
Recv() (Response, error)
// Close closes the stream
Close() error
}
@ -83,19 +83,19 @@ type Stream interface {
type Client interface {
// Connect establishes a connection to the server
Connect(ctx context.Context) error
// Close closes the connection
Close() error
// IsConnected returns whether the client is connected
IsConnected() bool
// Status returns the current status of the connection
Status() TransportStatus
// Send sends a request and waits for a response
Send(ctx context.Context, request Request) (Response, error)
// Stream opens a bidirectional stream
Stream(ctx context.Context) (Stream, error)
}
@ -104,7 +104,7 @@ type Client interface {
type RequestHandler interface {
// HandleRequest processes a request and returns a response
HandleRequest(ctx context.Context, request Request) (Response, error)
// HandleStream processes a bidirectional stream
HandleStream(stream Stream) error
}
@ -113,13 +113,13 @@ type RequestHandler interface {
type Server interface {
// Start starts the server and returns immediately
Start() error
// Serve starts the server and blocks until it's stopped
Serve() error
// Stop stops the server gracefully
Stop(ctx context.Context) error
// SetRequestHandler sets the handler for incoming requests
SetRequestHandler(handler RequestHandler)
}
@ -134,16 +134,16 @@ type ServerFactory func(address string, options TransportOptions) (Server, error
type Registry interface {
// RegisterClient adds a new client implementation to the registry
RegisterClient(name string, factory ClientFactory)
// RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory)
// CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
// CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error)
// ListTransports returns all available transport names
ListTransports() []string
}
}

View File

@ -10,16 +10,16 @@ import (
type MetricsCollector interface {
// RecordRequest records metrics for a request
RecordRequest(requestType string, startTime time.Time, err error)
// RecordSend records metrics for bytes sent
RecordSend(bytes int)
// RecordReceive records metrics for bytes received
RecordReceive(bytes int)
// RecordConnection records a connection event
RecordConnection(successful bool)
// GetMetrics returns the current metrics
GetMetrics() Metrics
}
@ -46,7 +46,7 @@ type BasicMetricsCollector struct {
bytesReceived uint64
connections uint64
connectionFailures uint64
// Track average latency and count for each request type
avgLatencyByType map[string]time.Duration
requestCountByType map[string]uint64
@ -63,26 +63,26 @@ func NewMetricsCollector() MetricsCollector {
// RecordRequest records metrics for a request
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
atomic.AddUint64(&c.totalRequests, 1)
if err == nil {
atomic.AddUint64(&c.successfulRequests, 1)
} else {
atomic.AddUint64(&c.failedRequests, 1)
}
// Update average latency for request type
latency := time.Since(startTime)
c.mu.Lock()
defer c.mu.Unlock()
currentAvg, exists := c.avgLatencyByType[requestType]
currentCount, _ := c.requestCountByType[requestType]
if exists {
// Update running average - the common case for better branch prediction
// new_avg = (old_avg * count + new_value) / (count + 1)
totalDuration := currentAvg * time.Duration(currentCount) + latency
totalDuration := currentAvg*time.Duration(currentCount) + latency
newCount := currentCount + 1
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
c.requestCountByType[requestType] = newCount
@ -116,13 +116,13 @@ func (c *BasicMetricsCollector) RecordConnection(successful bool) {
func (c *BasicMetricsCollector) GetMetrics() Metrics {
c.mu.RLock()
defer c.mu.RUnlock()
// Create a copy of the average latency map
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
for k, v := range c.avgLatencyByType {
avgLatencyByType[k] = v
}
return Metrics{
TotalRequests: atomic.LoadUint64(&c.totalRequests),
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
@ -133,4 +133,4 @@ func (c *BasicMetricsCollector) GetMetrics() Metrics {
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
AvgLatencyByType: avgLatencyByType,
}
}
}

View File

@ -9,9 +9,9 @@ import (
// Metrics struct extensions for server metrics
type ServerMetrics struct {
Metrics
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
}
// Connection represents a connection to a remote endpoint
@ -31,11 +31,11 @@ type Connection interface {
// ConnectionStatus represents the status of a connection
type ConnectionStatus struct {
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
}
// TransportManager is an interface for managing transport layer operations

View File

@ -8,7 +8,7 @@ import (
func TestBasicMetricsCollector(t *testing.T) {
collector := NewMetricsCollector()
// Test initial state
metrics := collector.GetMetrics()
if metrics.TotalRequests != 0 ||
@ -21,11 +21,11 @@ func TestBasicMetricsCollector(t *testing.T) {
len(metrics.AvgLatencyByType) != 0 {
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
}
// Test recording successful request
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
collector.RecordRequest("get", startTime, nil)
metrics = collector.GetMetrics()
if metrics.TotalRequests != 1 {
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
@ -36,18 +36,18 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 0 {
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
}
// Check average latency
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
t.Error("Expected 'get' latency to exist")
} else if avgLatency < 100*time.Millisecond {
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
}
// Test recording failed request
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
collector.RecordRequest("get", startTime, errors.New("test error"))
metrics = collector.GetMetrics()
if metrics.TotalRequests != 2 {
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
@ -58,26 +58,26 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 1 {
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
}
// Test average latency calculation for multiple requests
startTime = time.Now().Add(-300 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
startTime = time.Now().Add(-500 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
metrics = collector.GetMetrics()
avgPutLatency := metrics.AvgLatencyByType["put"]
// Expected avg is around (300ms + 500ms) / 2 = 400ms
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
}
// Test byte tracking
collector.RecordSend(1000)
collector.RecordReceive(2000)
metrics = collector.GetMetrics()
if metrics.BytesSent != 1000 {
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
@ -85,12 +85,12 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.BytesReceived != 2000 {
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
}
// Test connection tracking
collector.RecordConnection(true)
collector.RecordConnection(false)
collector.RecordConnection(true)
metrics = collector.GetMetrics()
if metrics.Connections != 2 {
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
@ -98,4 +98,4 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.ConnectionFailures != 1 {
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
}
}
}

View File

@ -12,11 +12,11 @@ func CreateListener(network, address string, tlsConfig *tls.Config) (net.Listene
if err != nil {
return nil, err
}
// If TLS is configured, wrap the listener
if tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
return listener, nil
}
}

View File

@ -7,7 +7,7 @@ import (
// registry implements the Registry interface
type registry struct {
mu sync.RWMutex
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
}
@ -111,4 +111,4 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
// AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string {
return DefaultRegistry.ListTransports()
}
}

View File

@ -97,17 +97,17 @@ func mockServerFactory(address string, options TransportOptions) (Server, error)
// TestRegistry tests the transport registry
func TestRegistry(t *testing.T) {
registry := NewRegistry()
// Register transports
registry.RegisterClient("mock", mockClientFactory)
registry.RegisterServer("mock", mockServerFactory)
// Test listing transports
transports := registry.ListTransports()
if len(transports) != 1 || transports[0] != "mock" {
t.Errorf("Expected [mock], got %v", transports)
}
// Test creating client
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -115,21 +115,21 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Test client methods
if client.IsConnected() {
t.Error("Expected client to be disconnected initially")
}
err = client.Connect(context.Background())
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
if !client.IsConnected() {
t.Error("Expected client to be connected after Connect()")
}
// Test server creation
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -137,26 +137,26 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// Test server methods
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
mockServer := server.(*mockServer)
if !mockServer.started {
t.Error("Expected server to be started")
}
// Test non-existent transport
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent client")
}
_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent server")
}
}
}

View File

@ -75,7 +75,7 @@ func TestBatchEncoding(t *testing.T) {
// Replay and decode
var decodedBatch *Batch
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypeBatch {
var err error
decodedBatch, err = DecodeBatch(entry)

View File

@ -229,6 +229,17 @@ func (r *Reader) Close() error {
// EntryHandler is a function that processes WAL entries during replay
type EntryHandler func(*Entry) error
// RecoveryStats tracks statistics about WAL recovery
type RecoveryStats struct {
EntriesProcessed uint64
EntriesSkipped uint64
}
// NewRecoveryStats creates a new RecoveryStats instance
func NewRecoveryStats() *RecoveryStats {
return &RecoveryStats{}
}
// FindWALFiles returns a list of WAL files in the given directory
func FindWALFiles(dir string) ([]string, error) {
pattern := filepath.Join(dir, "*.wal")
@ -267,16 +278,15 @@ func getEntryCount(path string) int {
return count
}
func ReplayWALFile(path string, handler EntryHandler) error {
func ReplayWALFile(path string, handler EntryHandler) (*RecoveryStats, error) {
reader, err := OpenReader(path)
if err != nil {
return err
return nil, err
}
defer reader.Close()
// Track statistics for reporting
entriesProcessed := 0
entriesSkipped := 0
// Track statistics
stats := NewRecoveryStats()
for {
entry, err := reader.ReadEntry()
@ -290,14 +300,11 @@ func ReplayWALFile(path string, handler EntryHandler) error {
if strings.Contains(err.Error(), "corrupt") ||
strings.Contains(err.Error(), "invalid") {
// Skip this corrupted entry
if !DisableRecoveryLogs {
fmt.Printf("Skipping corrupted entry in %s: %v\n", path, err)
}
entriesSkipped++
stats.EntriesSkipped++
// If we've seen too many corrupted entries in a row, give up on this file
if entriesSkipped > 5 && entriesProcessed == 0 {
return fmt.Errorf("too many corrupted entries at start of file %s", path)
if stats.EntriesSkipped > 5 && stats.EntriesProcessed == 0 {
return stats, fmt.Errorf("too many corrupted entries at start of file %s", path)
}
// Try to recover by scanning ahead
@ -310,7 +317,7 @@ func ReplayWALFile(path string, handler EntryHandler) error {
break
}
// Couldn't recover
return fmt.Errorf("failed to recover from corruption in %s: %w", path, recoverErr)
return stats, fmt.Errorf("failed to recover from corruption in %s: %w", path, recoverErr)
}
// Successfully recovered, continue to the next entry
@ -318,23 +325,18 @@ func ReplayWALFile(path string, handler EntryHandler) error {
}
// For other errors, fail the replay
return fmt.Errorf("error reading entry from %s: %w", path, err)
return stats, fmt.Errorf("error reading entry from %s: %w", path, err)
}
// Process the entry
if err := handler(entry); err != nil {
return fmt.Errorf("error handling entry: %w", err)
return stats, fmt.Errorf("error handling entry: %w", err)
}
entriesProcessed++
stats.EntriesProcessed++
}
if !DisableRecoveryLogs {
fmt.Printf("Processed %d entries from %s (skipped %d corrupted entries)\n",
entriesProcessed, path, entriesSkipped)
}
return nil
return stats, nil
}
// recoverFromCorruption attempts to recover from a corrupted record by scanning ahead
@ -356,54 +358,58 @@ func recoverFromCorruption(reader *Reader) error {
}
// ReplayWALDir replays all WAL files in the given directory in order
func ReplayWALDir(dir string, handler EntryHandler) error {
func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
files, err := FindWALFiles(dir)
if err != nil {
return err
return nil, err
}
// Track overall recovery stats
totalStats := NewRecoveryStats()
// Track number of files processed successfully
successfulFiles := 0
var lastErr error
// Try to process each file, but continue on recoverable errors
for _, file := range files {
err := ReplayWALFile(file, handler)
fileStats, err := ReplayWALFile(file, handler)
if err != nil {
if !DisableRecoveryLogs {
fmt.Printf("Error processing WAL file %s: %v\n", file, err)
}
// Record the error, but continue
lastErr = err
// If we got some stats from the file before the error, add them to our totals
if fileStats != nil {
totalStats.EntriesProcessed += fileStats.EntriesProcessed
totalStats.EntriesSkipped += fileStats.EntriesSkipped
}
// Check if this is a file-level error or just a corrupt record
if !strings.Contains(err.Error(), "corrupt") &&
!strings.Contains(err.Error(), "invalid") {
return fmt.Errorf("fatal error replaying WAL file %s: %w", file, err)
return totalStats, fmt.Errorf("fatal error replaying WAL file %s: %w", file, err)
}
// Continue to the next file for corrupt/invalid errors
continue
}
if !DisableRecoveryLogs {
fmt.Printf("Processed %d entries from %s (skipped 0 corrupted entries)\n",
getEntryCount(file), file)
}
// Add stats from this file to our totals
totalStats.EntriesProcessed += fileStats.EntriesProcessed
totalStats.EntriesSkipped += fileStats.EntriesSkipped
successfulFiles++
}
// If we processed at least one file successfully, the WAL recovery is considered successful
if successfulFiles > 0 {
return nil
return totalStats, nil
}
// If no files were processed successfully and we had errors, return the last error
if lastErr != nil {
return fmt.Errorf("failed to process any WAL files: %w", lastErr)
return totalStats, fmt.Errorf("failed to process any WAL files: %w", lastErr)
}
return nil
return totalStats, nil
}

View File

@ -56,7 +56,7 @@ func TestWALWrite(t *testing.T) {
// Verify entries by replaying
entries := make(map[string]string)
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut {
entries[string(entry.Key)] = string(entry.Value)
} else if entry.Type == OpTypeDelete {
@ -115,7 +115,7 @@ func TestWALDelete(t *testing.T) {
// Verify entries by replaying
var deleted bool
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut && bytes.Equal(entry.Key, key) {
if deleted {
deleted = false // Key was re-added
@ -171,7 +171,7 @@ func TestWALLargeEntry(t *testing.T) {
// Verify by replaying
var foundLargeEntry bool
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut && len(entry.Key) == len(key) && len(entry.Value) == len(value) {
// Verify key
for i := range key {
@ -240,7 +240,7 @@ func TestWALBatch(t *testing.T) {
entries := make(map[string]string)
batchCount := 0
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypeBatch {
batchCount++
@ -336,7 +336,7 @@ func TestWALRecovery(t *testing.T) {
// Verify entries by replaying all WAL files in order
entries := make(map[string]string)
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut {
entries[string(entry.Key)] = string(entry.Value)
} else if entry.Type == OpTypeDelete {
@ -410,7 +410,7 @@ func TestWALSyncModes(t *testing.T) {
// Verify entries by replaying
count := 0
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut {
count++
}
@ -471,7 +471,7 @@ func TestWALFragmentation(t *testing.T) {
var reconstructedValue []byte
var foundPut bool
err = ReplayWALDir(dir, func(entry *Entry) error {
_, err = ReplayWALDir(dir, func(entry *Entry) error {
if entry.Type == OpTypePut {
foundPut = true
reconstructedKey = entry.Key
@ -580,7 +580,7 @@ func TestWALErrorHandling(t *testing.T) {
// Try to replay a non-existent file
nonExistentPath := filepath.Join(dir, "nonexistent.wal")
err = ReplayWALFile(nonExistentPath, func(entry *Entry) error {
_, err = ReplayWALFile(nonExistentPath, func(entry *Entry) error {
return nil
})