Compare commits

..

1 Commits

Author SHA1 Message Date
5a926633bb
feat: enhance wal recover statistics
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 9m48s
2025-04-22 13:25:26 -06:00
35 changed files with 449 additions and 449 deletions

View File

@ -81,14 +81,14 @@ Commands (interactive mode only):
// Config holds the application configuration // Config holds the application configuration
type Config struct { type Config struct {
ServerMode bool ServerMode bool
DaemonMode bool DaemonMode bool
ListenAddr string ListenAddr string
DBPath string DBPath string
TLSEnabled bool TLSEnabled bool
TLSCertFile string TLSCertFile string
TLSKeyFile string TLSKeyFile string
TLSCAFile string TLSCAFile string
} }
func main() { func main() {
@ -98,7 +98,7 @@ func main() {
// Open database if path provided // Open database if path provided
var eng *engine.Engine var eng *engine.Engine
var err error var err error
if config.DBPath != "" { if config.DBPath != "" {
fmt.Printf("Opening database at %s\n", config.DBPath) fmt.Printf("Opening database at %s\n", config.DBPath)
eng, err = engine.NewEngine(config.DBPath) eng, err = engine.NewEngine(config.DBPath)
@ -108,18 +108,18 @@ func main() {
} }
defer eng.Close() defer eng.Close()
} }
// Check if we should run in server mode // Check if we should run in server mode
if config.ServerMode { if config.ServerMode {
if eng == nil { if eng == nil {
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n") fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
os.Exit(1) os.Exit(1)
} }
runServer(eng, config) runServer(eng, config)
return return
} }
// Run in interactive mode // Run in interactive mode
runInteractive(eng, config.DBPath) runInteractive(eng, config.DBPath)
} }
@ -151,31 +151,31 @@ func parseFlags() Config {
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API") serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)") daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode") listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
// TLS options // TLS options
tlsEnabled := flag.Bool("tls", false, "Enable TLS for secure connections") tlsEnabled := flag.Bool("tls", false, "Enable TLS for secure connections")
tlsCertFile := flag.String("cert", "", "TLS certificate file path") tlsCertFile := flag.String("cert", "", "TLS certificate file path")
tlsKeyFile := flag.String("key", "", "TLS private key file path") tlsKeyFile := flag.String("key", "", "TLS private key file path")
tlsCAFile := flag.String("ca", "", "TLS CA certificate file for client verification") tlsCAFile := flag.String("ca", "", "TLS CA certificate file for client verification")
// Parse flags // Parse flags
flag.Parse() flag.Parse()
// Get database path from remaining arguments // Get database path from remaining arguments
var dbPath string var dbPath string
if flag.NArg() > 0 { if flag.NArg() > 0 {
dbPath = flag.Arg(0) dbPath = flag.Arg(0)
} }
return Config{ return Config{
ServerMode: *serverMode, ServerMode: *serverMode,
DaemonMode: *daemonMode, DaemonMode: *daemonMode,
ListenAddr: *listenAddr, ListenAddr: *listenAddr,
DBPath: dbPath, DBPath: dbPath,
TLSEnabled: *tlsEnabled, TLSEnabled: *tlsEnabled,
TLSCertFile: *tlsCertFile, TLSCertFile: *tlsCertFile,
TLSKeyFile: *tlsKeyFile, TLSKeyFile: *tlsKeyFile,
TLSCAFile: *tlsCAFile, TLSCAFile: *tlsCAFile,
} }
} }
@ -185,10 +185,10 @@ func runServer(eng *engine.Engine, config Config) {
if config.DaemonMode { if config.DaemonMode {
setupDaemonMode() setupDaemonMode()
} }
// Create and start the server // Create and start the server
server := NewServer(eng, config) server := NewServer(eng, config)
// Start the server (non-blocking) // Start the server (non-blocking)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err) fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
@ -196,10 +196,10 @@ func runServer(eng *engine.Engine, config Config) {
} }
fmt.Printf("Kevo server started on %s\n", config.ListenAddr) fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
// Set up signal handling for graceful shutdown // Set up signal handling for graceful shutdown
setupGracefulShutdown(server, eng) setupGracefulShutdown(server, eng)
// Start serving (blocking) // Start serving (blocking)
if err := server.Serve(); err != nil { if err := server.Serve(); err != nil {
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err) fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
@ -214,29 +214,29 @@ func setupDaemonMode() {
if err != nil { if err != nil {
log.Fatalf("Failed to open /dev/null: %v", err) log.Fatalf("Failed to open /dev/null: %v", err)
} }
// Redirect standard file descriptors to /dev/null // Redirect standard file descriptors to /dev/null
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd())) err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
if err != nil { if err != nil {
log.Fatalf("Failed to redirect stdin: %v", err) log.Fatalf("Failed to redirect stdin: %v", err)
} }
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd())) err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
if err != nil { if err != nil {
log.Fatalf("Failed to redirect stdout: %v", err) log.Fatalf("Failed to redirect stdout: %v", err)
} }
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd())) err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
if err != nil { if err != nil {
log.Fatalf("Failed to redirect stderr: %v", err) log.Fatalf("Failed to redirect stderr: %v", err)
} }
// Create a new process group // Create a new process group
_, err = syscall.Setsid() _, err = syscall.Setsid()
if err != nil { if err != nil {
log.Fatalf("Failed to create new session: %v", err) log.Fatalf("Failed to create new session: %v", err)
} }
fmt.Println("Daemon mode enabled, detaching from terminal...") fmt.Println("Daemon mode enabled, detaching from terminal...")
} }
@ -244,22 +244,22 @@ func setupDaemonMode() {
func setupGracefulShutdown(server *Server, eng *engine.Engine) { func setupGracefulShutdown(server *Server, eng *engine.Engine) {
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() { go func() {
sig := <-sigChan sig := <-sigChan
fmt.Printf("\nReceived signal %v, shutting down...\n", sig) fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
// Graceful shutdown logic // Graceful shutdown logic
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
// Shut down the server // Shut down the server
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err) fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
} }
// The engine will be closed by the defer in main() // The engine will be closed by the defer in main()
fmt.Println("Shutdown complete") fmt.Println("Shutdown complete")
os.Exit(0) os.Exit(0)
}() }()
@ -269,7 +269,7 @@ func setupGracefulShutdown(server *Server, eng *engine.Engine) {
func runInteractive(eng *engine.Engine, dbPath string) { func runInteractive(eng *engine.Engine, dbPath string) {
fmt.Println("Kevo (kevo) version 1.0.2") fmt.Println("Kevo (kevo) version 1.0.2")
fmt.Println("Enter .help for usage hints.") fmt.Println("Enter .help for usage hints.")
var tx engine.Transaction var tx engine.Transaction
var err error var err error
@ -412,7 +412,7 @@ func runInteractive(eng *engine.Engine, dbPath string) {
// Print statistics // Print statistics
stats := eng.GetStats() stats := eng.GetStats()
// Format human-readable time for the last operation timestamps // Format human-readable time for the last operation timestamps
var lastPutTime, lastGetTime, lastDeleteTime time.Time var lastPutTime, lastGetTime, lastDeleteTime time.Time
if putTime, ok := stats["last_put_time"].(int64); ok && putTime > 0 { if putTime, ok := stats["last_put_time"].(int64); ok && putTime > 0 {
@ -424,13 +424,13 @@ func runInteractive(eng *engine.Engine, dbPath string) {
if deleteTime, ok := stats["last_delete_time"].(int64); ok && deleteTime > 0 { if deleteTime, ok := stats["last_delete_time"].(int64); ok && deleteTime > 0 {
lastDeleteTime = time.Unix(0, deleteTime) lastDeleteTime = time.Unix(0, deleteTime)
} }
// Operations section // Operations section
fmt.Println("📊 Operations:") fmt.Println("📊 Operations:")
fmt.Printf(" • Puts: %d\n", stats["put_ops"]) fmt.Printf(" • Puts: %d\n", stats["put_ops"])
fmt.Printf(" • Gets: %d (Hits: %d, Misses: %d)\n", stats["get_ops"], stats["get_hits"], stats["get_misses"]) fmt.Printf(" • Gets: %d (Hits: %d, Misses: %d)\n", stats["get_ops"], stats["get_hits"], stats["get_misses"])
fmt.Printf(" • Deletes: %d\n", stats["delete_ops"]) fmt.Printf(" • Deletes: %d\n", stats["delete_ops"])
// Last Operation Times // Last Operation Times
fmt.Println("\n⏱ Last Operation Times:") fmt.Println("\n⏱ Last Operation Times:")
if !lastPutTime.IsZero() { if !lastPutTime.IsZero() {
@ -448,25 +448,25 @@ func runInteractive(eng *engine.Engine, dbPath string) {
} else { } else {
fmt.Printf(" • Last Delete: Never\n") fmt.Printf(" • Last Delete: Never\n")
} }
// Transactions // Transactions
fmt.Println("\n💼 Transactions:") fmt.Println("\n💼 Transactions:")
fmt.Printf(" • Started: %d\n", stats["tx_started"]) fmt.Printf(" • Started: %d\n", stats["tx_started"])
fmt.Printf(" • Completed: %d\n", stats["tx_completed"]) fmt.Printf(" • Completed: %d\n", stats["tx_completed"])
fmt.Printf(" • Aborted: %d\n", stats["tx_aborted"]) fmt.Printf(" • Aborted: %d\n", stats["tx_aborted"])
// Storage metrics // Storage metrics
fmt.Println("\n💾 Storage:") fmt.Println("\n💾 Storage:")
fmt.Printf(" • Total Bytes Read: %d\n", stats["total_bytes_read"]) fmt.Printf(" • Total Bytes Read: %d\n", stats["total_bytes_read"])
fmt.Printf(" • Total Bytes Written: %d\n", stats["total_bytes_written"]) fmt.Printf(" • Total Bytes Written: %d\n", stats["total_bytes_written"])
fmt.Printf(" • Flush Count: %d\n", stats["flush_count"]) fmt.Printf(" • Flush Count: %d\n", stats["flush_count"])
// Table stats // Table stats
fmt.Println("\n📋 Tables:") fmt.Println("\n📋 Tables:")
fmt.Printf(" • SSTable Count: %d\n", stats["sstable_count"]) fmt.Printf(" • SSTable Count: %d\n", stats["sstable_count"])
fmt.Printf(" • Immutable MemTable Count: %d\n", stats["immutable_memtable_count"]) fmt.Printf(" • Immutable MemTable Count: %d\n", stats["immutable_memtable_count"])
fmt.Printf(" • Current MemTable Size: %d bytes\n", stats["memtable_size"]) fmt.Printf(" • Current MemTable Size: %d bytes\n", stats["memtable_size"])
// WAL recovery stats // WAL recovery stats
fmt.Println("\n🔄 WAL Recovery:") fmt.Println("\n🔄 WAL Recovery:")
fmt.Printf(" • Files Recovered: %d\n", stats["wal_files_recovered"]) fmt.Printf(" • Files Recovered: %d\n", stats["wal_files_recovered"])
@ -475,17 +475,17 @@ func runInteractive(eng *engine.Engine, dbPath string) {
if recoveryDuration, ok := stats["wal_recovery_duration_ms"]; ok { if recoveryDuration, ok := stats["wal_recovery_duration_ms"]; ok {
fmt.Printf(" • Recovery Duration: %d ms\n", recoveryDuration) fmt.Printf(" • Recovery Duration: %d ms\n", recoveryDuration)
} }
// Error counts // Error counts
fmt.Println("\n⚠ Errors:") fmt.Println("\n⚠ Errors:")
fmt.Printf(" • Read Errors: %d\n", stats["read_errors"]) fmt.Printf(" • Read Errors: %d\n", stats["read_errors"])
fmt.Printf(" • Write Errors: %d\n", stats["write_errors"]) fmt.Printf(" • Write Errors: %d\n", stats["write_errors"])
// Compaction stats (if available) // Compaction stats (if available)
if compactionOutputCount, ok := stats["compaction_last_outputs_count"]; ok { if compactionOutputCount, ok := stats["compaction_last_outputs_count"]; ok {
fmt.Println("\n🧹 Compaction:") fmt.Println("\n🧹 Compaction:")
fmt.Printf(" • Last Output Files Count: %d\n", compactionOutputCount) fmt.Printf(" • Last Output Files Count: %d\n", compactionOutputCount)
// Display other compaction stats as available // Display other compaction stats as available
for key, value := range stats { for key, value := range stats {
if strings.HasPrefix(key, "compaction_") && key != "compaction_last_outputs_count" && key != "compaction_last_outputs" { if strings.HasPrefix(key, "compaction_") && key != "compaction_last_outputs_count" && key != "compaction_last_outputs" {
@ -825,4 +825,4 @@ func toTitle(s string) string {
return r return r
}, },
s) s)
} }

View File

@ -35,14 +35,14 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
// Create context with timeout to prevent potential hangs // Create context with timeout to prevent potential hangs
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
// Create a channel to receive the transaction result // Create a channel to receive the transaction result
type txResult struct { type txResult struct {
tx engine.Transaction tx engine.Transaction
err error err error
} }
resultCh := make(chan txResult, 1) resultCh := make(chan txResult, 1)
// Start transaction in a goroutine to prevent potential blocking // Start transaction in a goroutine to prevent potential blocking
go func() { go func() {
tx, err := eng.BeginTransaction(readOnly) tx, err := eng.BeginTransaction(readOnly)
@ -56,26 +56,26 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
} }
} }
}() }()
// Wait for result or timeout // Wait for result or timeout
select { select {
case result := <-resultCh: case result := <-resultCh:
if result.err != nil { if result.err != nil {
return "", fmt.Errorf("failed to begin transaction: %w", result.err) return "", fmt.Errorf("failed to begin transaction: %w", result.err)
} }
tr.mu.Lock() tr.mu.Lock()
defer tr.mu.Unlock() defer tr.mu.Unlock()
// Generate a transaction ID // Generate a transaction ID
tr.nextID++ tr.nextID++
txID := fmt.Sprintf("tx-%d", tr.nextID) txID := fmt.Sprintf("tx-%d", tr.nextID)
// Register the transaction // Register the transaction
tr.transactions[txID] = result.tx tr.transactions[txID] = result.tx
return txID, nil return txID, nil
case <-timeoutCtx.Done(): case <-timeoutCtx.Done():
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err()) return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
} }
@ -104,31 +104,31 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
defer tr.mu.Unlock() defer tr.mu.Unlock()
var lastErr error var lastErr error
// Copy transaction IDs to avoid modifying the map during iteration // Copy transaction IDs to avoid modifying the map during iteration
ids := make([]string, 0, len(tr.transactions)) ids := make([]string, 0, len(tr.transactions))
for id := range tr.transactions { for id := range tr.transactions {
ids = append(ids, id) ids = append(ids, id)
} }
// Rollback each transaction with a timeout // Rollback each transaction with a timeout
for _, id := range ids { for _, id := range ids {
tx, exists := tr.transactions[id] tx, exists := tr.transactions[id]
if !exists { if !exists {
continue continue
} }
// Use a timeout for each rollback operation // Use a timeout for each rollback operation
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second) rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
// Create a channel for the rollback result // Create a channel for the rollback result
doneCh := make(chan error, 1) doneCh := make(chan error, 1)
// Execute rollback in goroutine // Execute rollback in goroutine
go func(t engine.Transaction) { go func(t engine.Transaction) {
doneCh <- t.Rollback() doneCh <- t.Rollback()
}(tx) }(tx)
// Wait for rollback or timeout // Wait for rollback or timeout
var err error var err error
select { select {
@ -137,14 +137,14 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
case <-rollbackCtx.Done(): case <-rollbackCtx.Done():
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err()) err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
} }
cancel() // Clean up context cancel() // Clean up context
// Record error if any // Record error if any
if err != nil { if err != nil {
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err) lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
} }
// Always remove transaction from map // Always remove transaction from map
delete(tr.transactions, id) delete(tr.transactions, id)
} }
@ -154,12 +154,12 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
// Server represents the Kevo server // Server represents the Kevo server
type Server struct { type Server struct {
eng *engine.Engine eng *engine.Engine
txRegistry *TransactionRegistry txRegistry *TransactionRegistry
listener net.Listener listener net.Listener
grpcServer *grpc.Server grpcServer *grpc.Server
kevoService *grpcservice.KevoServiceServer kevoService *grpcservice.KevoServiceServer
config Config config Config
} }
// NewServer creates a new server instance // NewServer creates a new server instance
@ -249,14 +249,14 @@ func (s *Server) Shutdown(ctx context.Context) error {
// First, gracefully stop the gRPC server if it exists // First, gracefully stop the gRPC server if it exists
if s.grpcServer != nil { if s.grpcServer != nil {
fmt.Println("Gracefully stopping gRPC server...") fmt.Println("Gracefully stopping gRPC server...")
// Create a channel to signal when the server has stopped // Create a channel to signal when the server has stopped
stopped := make(chan struct{}) stopped := make(chan struct{})
go func() { go func() {
s.grpcServer.GracefulStop() s.grpcServer.GracefulStop()
close(stopped) close(stopped)
}() }()
// Wait for graceful stop or context deadline // Wait for graceful stop or context deadline
select { select {
case <-stopped: case <-stopped:
@ -266,7 +266,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
s.grpcServer.Stop() s.grpcServer.Stop()
} }
} }
// Shut down the listener if it's still open // Shut down the listener if it's still open
if s.listener != nil { if s.listener != nil {
if err := s.listener.Close(); err != nil { if err := s.listener.Close(); err != nil {
@ -280,4 +280,4 @@ func (s *Server) Shutdown(ctx context.Context) error {
} }
return nil return nil
} }

View File

@ -14,7 +14,7 @@ func TestTransactionRegistry(t *testing.T) {
// Create a timeout context for the whole test // Create a timeout context for the whole test
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
// Set up temporary directory for test // Set up temporary directory for test
tmpDir, err := os.MkdirTemp("", "kevo_test") tmpDir, err := os.MkdirTemp("", "kevo_test")
if err != nil { if err != nil {
@ -153,47 +153,47 @@ func TestGRPCServer(t *testing.T) {
t.Fatalf("Failed to create temporary directory: %v", err) t.Fatalf("Failed to create temporary directory: %v", err)
} }
defer os.RemoveAll(tempDBPath) defer os.RemoveAll(tempDBPath)
// Create engine // Create engine
eng, err := engine.NewEngine(tempDBPath) eng, err := engine.NewEngine(tempDBPath)
if err != nil { if err != nil {
t.Fatalf("Failed to create engine: %v", err) t.Fatalf("Failed to create engine: %v", err)
} }
defer eng.Close() defer eng.Close()
// Create server configuration // Create server configuration
config := Config{ config := Config{
ServerMode: true, ServerMode: true,
ListenAddr: "localhost:50052", // Use a different port for tests ListenAddr: "localhost:50052", // Use a different port for tests
DBPath: tempDBPath, DBPath: tempDBPath,
} }
// Create and start the server // Create and start the server
server := NewServer(eng, config) server := NewServer(eng, config)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err) t.Fatalf("Failed to start server: %v", err)
} }
// Run server in a goroutine // Run server in a goroutine
go func() { go func() {
if err := server.Serve(); err != nil { if err := server.Serve(); err != nil {
t.Logf("Server stopped: %v", err) t.Logf("Server stopped: %v", err)
} }
}() }()
// Give the server a moment to start // Give the server a moment to start
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
// Clean up at the end // Clean up at the end
defer func() { defer func() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel() defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil { if err := server.Shutdown(shutdownCtx); err != nil {
t.Logf("Failed to shut down server: %v", err) t.Logf("Failed to shut down server: %v", err)
} }
}() }()
// TODO: Add gRPC client tests here when client implementation is complete // TODO: Add gRPC client tests here when client implementation is complete
t.Log("gRPC server integration test scaffolding added") t.Log("gRPC server integration test scaffolding added")
} }

6
go.mod
View File

@ -10,8 +10,8 @@ require (
) )
require ( require (
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.22.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
) )

12
go.sum
View File

@ -28,13 +28,13 @@ go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce
go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w=
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a h1:51aaUVRocpvUOSQKM6Q7VuoaktNIaMCLuhZB6DKksq4= google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a h1:51aaUVRocpvUOSQKM6Q7VuoaktNIaMCLuhZB6DKksq4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a/go.mod h1:uRxBH1mhmO8PGhU89cMcHaXKZqO+OfakD8QQO0oYwlQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a/go.mod h1:uRxBH1mhmO8PGhU89cMcHaXKZqO+OfakD8QQO0oYwlQ=
google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM=

View File

@ -23,11 +23,11 @@ const (
// ClientOptions configures a Kevo client // ClientOptions configures a Kevo client
type ClientOptions struct { type ClientOptions struct {
// Connection options // Connection options
Endpoint string // Server address Endpoint string // Server address
ConnectTimeout time.Duration // Timeout for connection attempts ConnectTimeout time.Duration // Timeout for connection attempts
RequestTimeout time.Duration // Default timeout for requests RequestTimeout time.Duration // Default timeout for requests
TransportType string // Transport type (e.g. "grpc") TransportType string // Transport type (e.g. "grpc")
PoolSize int // Connection pool size PoolSize int // Connection pool size
// Security options // Security options
TLSEnabled bool // Enable TLS TLSEnabled bool // Enable TLS
@ -50,19 +50,19 @@ type ClientOptions struct {
// DefaultClientOptions returns sensible default client options // DefaultClientOptions returns sensible default client options
func DefaultClientOptions() ClientOptions { func DefaultClientOptions() ClientOptions {
return ClientOptions{ return ClientOptions{
Endpoint: "localhost:50051", Endpoint: "localhost:50051",
ConnectTimeout: time.Second * 5, ConnectTimeout: time.Second * 5,
RequestTimeout: time.Second * 10, RequestTimeout: time.Second * 10,
TransportType: "grpc", TransportType: "grpc",
PoolSize: 5, PoolSize: 5,
TLSEnabled: false, TLSEnabled: false,
MaxRetries: 3, MaxRetries: 3,
InitialBackoff: time.Millisecond * 100, InitialBackoff: time.Millisecond * 100,
MaxBackoff: time.Second * 2, MaxBackoff: time.Second * 2,
BackoffFactor: 1.5, BackoffFactor: 1.5,
RetryJitter: 0.2, RetryJitter: 0.2,
Compression: CompressionNone, Compression: CompressionNone,
MaxMessageSize: 16 * 1024 * 1024, // 16MB MaxMessageSize: 16 * 1024 * 1024, // 16MB
} }
} }
@ -378,4 +378,4 @@ type Stats struct {
SstableCount int32 SstableCount int32
WriteAmplification float64 WriteAmplification float64
ReadAmplification float64 ReadAmplification float64
} }

View File

@ -105,13 +105,13 @@ func TestClientConnect(t *testing.T) {
// Modify default options to use mock transport // Modify default options to use mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
// Create a client with the mock transport // Create a client with the mock transport
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
@ -139,12 +139,12 @@ func TestClientGet(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -186,12 +186,12 @@ func TestClientPut(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -220,12 +220,12 @@ func TestClientDelete(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -254,12 +254,12 @@ func TestClientBatchWrite(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -295,12 +295,12 @@ func TestClientGetStats(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -317,12 +317,12 @@ func TestClientGetStats(t *testing.T) {
"read_amplification": 2.0 "read_amplification": 2.0
}` }`
mock.setResponse(transport.TypeGetStats, []byte(statsJSON)) mock.setResponse(transport.TypeGetStats, []byte(statsJSON))
stats, err := client.GetStats(ctx) stats, err := client.GetStats(ctx)
if err != nil { if err != nil {
t.Errorf("Expected successful get stats, got error: %v", err) t.Errorf("Expected successful get stats, got error: %v", err)
} }
if stats.KeyCount != 1000 { if stats.KeyCount != 1000 {
t.Errorf("Expected KeyCount 1000, got %d", stats.KeyCount) t.Errorf("Expected KeyCount 1000, got %d", stats.KeyCount)
} }
@ -354,12 +354,12 @@ func TestClientCompact(t *testing.T) {
// Create a client with the mock transport // Create a client with the mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock" options.TransportType = "mock"
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Get the underlying mock client for test assertions // Get the underlying mock client for test assertions
mock := client.client.(*mockClient) mock := client.client.(*mockClient)
mock.connected = true mock.connected = true
@ -386,7 +386,7 @@ func TestClientCompact(t *testing.T) {
func TestRetryWithBackoff(t *testing.T) { func TestRetryWithBackoff(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Test successful retry // Test successful retry
attempts := 0 attempts := 0
err := RetryWithBackoff( err := RetryWithBackoff(
@ -404,14 +404,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor 2.0, // backoffFactor
0.1, // jitter 0.1, // jitter
) )
if err != nil { if err != nil {
t.Errorf("Expected successful retry, got error: %v", err) t.Errorf("Expected successful retry, got error: %v", err)
} }
if attempts != 3 { if attempts != 3 {
t.Errorf("Expected 3 attempts, got %d", attempts) t.Errorf("Expected 3 attempts, got %d", attempts)
} }
// Test max retries exceeded // Test max retries exceeded
attempts = 0 attempts = 0
err = RetryWithBackoff( err = RetryWithBackoff(
@ -426,14 +426,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor 2.0, // backoffFactor
0.1, // jitter 0.1, // jitter
) )
if err == nil { if err == nil {
t.Error("Expected error after max retries, got nil") t.Error("Expected error after max retries, got nil")
} }
if attempts != 4 { // Initial + 3 retries if attempts != 4 { // Initial + 3 retries
t.Errorf("Expected 4 attempts, got %d", attempts) t.Errorf("Expected 4 attempts, got %d", attempts)
} }
// Test non-retryable error // Test non-retryable error
attempts = 0 attempts = 0
err = RetryWithBackoff( err = RetryWithBackoff(
@ -448,14 +448,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor 2.0, // backoffFactor
0.1, // jitter 0.1, // jitter
) )
if err == nil { if err == nil {
t.Error("Expected non-retryable error to be returned, got nil") t.Error("Expected non-retryable error to be returned, got nil")
} }
if attempts != 1 { if attempts != 1 {
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts) t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
} }
// Test context cancellation // Test context cancellation
attempts = 0 attempts = 0
cancelCtx, cancel := context.WithCancel(ctx) cancelCtx, cancel := context.WithCancel(ctx)
@ -463,7 +463,7 @@ func TestRetryWithBackoff(t *testing.T) {
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
cancel() cancel()
}() }()
err = RetryWithBackoff( err = RetryWithBackoff(
cancelCtx, cancelCtx,
func() error { func() error {
@ -476,8 +476,8 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor 2.0, // backoffFactor
0.1, // jitter 0.1, // jitter
) )
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got: %v", err) t.Errorf("Expected context.Canceled error, got: %v", err)
} }
} }

View File

@ -304,4 +304,4 @@ func (s *transactionScanIterator) Close() error {
s.closed = true s.closed = true
s.cancelFunc() s.cancelFunc()
return s.stream.Close() return s.stream.Close()
} }

View File

@ -7,33 +7,33 @@ import (
func TestDefaultClientOptions(t *testing.T) { func TestDefaultClientOptions(t *testing.T) {
options := DefaultClientOptions() options := DefaultClientOptions()
// Verify the default options have sensible values // Verify the default options have sensible values
if options.Endpoint != "localhost:50051" { if options.Endpoint != "localhost:50051" {
t.Errorf("Expected default endpoint to be localhost:50051, got %s", options.Endpoint) t.Errorf("Expected default endpoint to be localhost:50051, got %s", options.Endpoint)
} }
if options.ConnectTimeout != 5*time.Second { if options.ConnectTimeout != 5*time.Second {
t.Errorf("Expected default connect timeout to be 5s, got %s", options.ConnectTimeout) t.Errorf("Expected default connect timeout to be 5s, got %s", options.ConnectTimeout)
} }
if options.RequestTimeout != 10*time.Second { if options.RequestTimeout != 10*time.Second {
t.Errorf("Expected default request timeout to be 10s, got %s", options.RequestTimeout) t.Errorf("Expected default request timeout to be 10s, got %s", options.RequestTimeout)
} }
if options.TransportType != "grpc" { if options.TransportType != "grpc" {
t.Errorf("Expected default transport type to be grpc, got %s", options.TransportType) t.Errorf("Expected default transport type to be grpc, got %s", options.TransportType)
} }
if options.PoolSize != 5 { if options.PoolSize != 5 {
t.Errorf("Expected default pool size to be 5, got %d", options.PoolSize) t.Errorf("Expected default pool size to be 5, got %d", options.PoolSize)
} }
if options.TLSEnabled != false { if options.TLSEnabled != false {
t.Errorf("Expected default TLS enabled to be false") t.Errorf("Expected default TLS enabled to be false")
} }
if options.MaxRetries != 3 { if options.MaxRetries != 3 {
t.Errorf("Expected default max retries to be 3, got %d", options.MaxRetries) t.Errorf("Expected default max retries to be 3, got %d", options.MaxRetries)
} }
} }

View File

@ -17,19 +17,19 @@ func mockClientFactory(endpoint string, options transport.TransportOptions) (tra
func TestClientCreation(t *testing.T) { func TestClientCreation(t *testing.T) {
// First, register our mock transport // First, register our mock transport
transport.RegisterClientTransport("mock_test", mockClientFactory) transport.RegisterClientTransport("mock_test", mockClientFactory)
// Create client options using our mock transport // Create client options using our mock transport
options := DefaultClientOptions() options := DefaultClientOptions()
options.TransportType = "mock_test" options.TransportType = "mock_test"
// Create a client // Create a client
client, err := NewClient(options) client, err := NewClient(options)
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Verify the client was created // Verify the client was created
if client == nil { if client == nil {
t.Fatal("Client is nil") t.Fatal("Client is nil")
} }
} }

View File

@ -12,11 +12,11 @@ import (
// Transaction represents a database transaction // Transaction represents a database transaction
type Transaction struct { type Transaction struct {
client *Client client *Client
id string id string
readOnly bool readOnly bool
closed bool closed bool
mu sync.RWMutex mu sync.RWMutex
} }
// ErrTransactionClosed is returned when attempting to use a closed transaction // ErrTransactionClosed is returned when attempting to use a closed transaction
@ -285,4 +285,4 @@ func (tx *Transaction) Delete(ctx context.Context, key []byte) (bool, error) {
} }
return deleteResp.Success, nil return deleteResp.Success, nil
} }

View File

@ -117,4 +117,4 @@ func CalculateExponentialBackoff(
} }
return backoff return backoff
} }

View File

@ -130,14 +130,14 @@ func (h *HierarchicalIterator) Seek(target []byte) bool {
if !iter.Valid() { if !iter.Valid() {
continue continue
} }
// If a newer iterator has the same key, use its value // If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) { if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value() bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match break // Since iterators are in newest-to-oldest order, we can stop at the first match
} }
} }
// Set the found key/value // Set the found key/value
h.key = bestKey h.key = bestKey
h.value = bestValue h.value = bestValue
@ -253,7 +253,7 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
// Get the current key // Get the current key
key := iter.Key() key := iter.Key()
// If we haven't found a valid key yet, or this key is smaller than the current best key // If we haven't found a valid key yet, or this key is smaller than the current best key
if bestIterIdx == -1 || bytes.Compare(key, bestKey) < 0 { if bestIterIdx == -1 || bytes.Compare(key, bestKey) < 0 {
// This becomes our best candidate so far // This becomes our best candidate so far
@ -271,14 +271,14 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
if !iter.Valid() { if !iter.Valid() {
continue continue
} }
// If a newer iterator has the same key, use its value // If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) { if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value() bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match break // Since iterators are in newest-to-oldest order, we can stop at the first match
} }
} }
// Set the found key/value // Set the found key/value
h.key = bestKey h.key = bestKey
h.value = bestValue h.value = bestValue

View File

@ -62,7 +62,7 @@ type EngineStats struct {
TxAborted atomic.Uint64 TxAborted atomic.Uint64
// Recovery stats // Recovery stats
WALFilesRecovered atomic.Uint64 WALFilesRecovered atomic.Uint64
WALEntriesRecovered atomic.Uint64 WALEntriesRecovered atomic.Uint64
WALCorruptedEntries atomic.Uint64 WALCorruptedEntries atomic.Uint64
WALRecoveryDuration atomic.Int64 // nanoseconds WALRecoveryDuration atomic.Int64 // nanoseconds
@ -524,15 +524,15 @@ func (e *Engine) flushMemTable(mem *memtable.MemTable) error {
for iter.SeekToFirst(); iter.Valid(); iter.Next() { for iter.SeekToFirst(); iter.Valid(); iter.Next() {
key := iter.Key() key := iter.Key()
keyStr := string(key) // Use as map key keyStr := string(key) // Use as map key
// Skip keys we've already processed (including tombstones) // Skip keys we've already processed (including tombstones)
if _, seen := processedKeys[keyStr]; seen { if _, seen := processedKeys[keyStr]; seen {
continue continue
} }
// Mark this key as processed regardless of whether it's a value or tombstone // Mark this key as processed regardless of whether it's a value or tombstone
processedKeys[keyStr] = struct{}{} processedKeys[keyStr] = struct{}{}
// Only write non-tombstone entries to the SSTable // Only write non-tombstone entries to the SSTable
if value := iter.Value(); value != nil { if value := iter.Value(); value != nil {
bytesWritten += uint64(len(key) + len(value)) bytesWritten += uint64(len(key) + len(value))
@ -673,7 +673,7 @@ func (e *Engine) loadSSTables() error {
// recoverFromWAL recovers memtables from existing WAL files // recoverFromWAL recovers memtables from existing WAL files
func (e *Engine) recoverFromWAL() error { func (e *Engine) recoverFromWAL() error {
startTime := time.Now() startTime := time.Now()
// Check if WAL directory exists // Check if WAL directory exists
if _, err := os.Stat(e.walDir); os.IsNotExist(err) { if _, err := os.Stat(e.walDir); os.IsNotExist(err) {
return nil // No WAL directory, nothing to recover return nil // No WAL directory, nothing to recover
@ -685,7 +685,7 @@ func (e *Engine) recoverFromWAL() error {
e.stats.ReadErrors.Add(1) e.stats.ReadErrors.Add(1)
return fmt.Errorf("error listing WAL files: %w", err) return fmt.Errorf("error listing WAL files: %w", err)
} }
if len(walFiles) > 0 { if len(walFiles) > 0 {
e.stats.WALFilesRecovered.Add(uint64(len(walFiles))) e.stats.WALFilesRecovered.Add(uint64(len(walFiles)))
} }
@ -698,7 +698,7 @@ func (e *Engine) recoverFromWAL() error {
if err != nil { if err != nil {
// If recovery fails, let's try cleaning up WAL files // If recovery fails, let's try cleaning up WAL files
e.stats.ReadErrors.Add(1) e.stats.ReadErrors.Add(1)
// Create a backup directory // Create a backup directory
backupDir := filepath.Join(e.walDir, "backup_"+time.Now().Format("20060102_150405")) backupDir := filepath.Join(e.walDir, "backup_"+time.Now().Format("20060102_150405"))
if err := os.MkdirAll(backupDir, 0755); err != nil { if err := os.MkdirAll(backupDir, 0755); err != nil {
@ -724,14 +724,14 @@ func (e *Engine) recoverFromWAL() error {
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds()) e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil return nil
} }
// Update recovery statistics based on actual entries recovered // Update recovery statistics based on actual entries recovered
if len(walFiles) > 0 { if len(walFiles) > 0 {
// Use WALDir function directly to get stats // Use WALDir function directly to get stats
recoveryStats, statErr := wal.ReplayWALDir(e.cfg.WALDir, func(entry *wal.Entry) error { recoveryStats, statErr := wal.ReplayWALDir(e.cfg.WALDir, func(entry *wal.Entry) error {
return nil // Just counting, not processing return nil // Just counting, not processing
}) })
if statErr == nil && recoveryStats != nil { if statErr == nil && recoveryStats != nil {
e.stats.WALEntriesRecovered.Add(recoveryStats.EntriesProcessed) e.stats.WALEntriesRecovered.Add(recoveryStats.EntriesProcessed)
e.stats.WALCorruptedEntries.Add(recoveryStats.EntriesSkipped) e.stats.WALCorruptedEntries.Add(recoveryStats.EntriesSkipped)
@ -767,7 +767,7 @@ func (e *Engine) recoverFromWAL() error {
// Record recovery stats // Record recovery stats
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds()) e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil return nil
} }
@ -933,7 +933,7 @@ func (e *Engine) GetStats() map[string]interface{} {
// Add error statistics // Add error statistics
stats["read_errors"] = e.stats.ReadErrors.Load() stats["read_errors"] = e.stats.ReadErrors.Load()
stats["write_errors"] = e.stats.WriteErrors.Load() stats["write_errors"] = e.stats.WriteErrors.Load()
// Add WAL recovery statistics // Add WAL recovery statistics
stats["wal_files_recovered"] = e.stats.WALFilesRecovered.Load() stats["wal_files_recovered"] = e.stats.WALFilesRecovered.Load()
stats["wal_entries_recovered"] = e.stats.WALEntriesRecovered.Load() stats["wal_entries_recovered"] = e.stats.WALEntriesRecovered.Load()

View File

@ -81,8 +81,8 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
// Simulate exactly the bug scenario from the CLI // Simulate exactly the bug scenario from the CLI
// Add the same key multiple times with different values // Add the same key multiple times with different values
key := []byte("foo") key := []byte("foo")
// First add // First add
if err := engine.Put(key, []byte("23")); err != nil { if err := engine.Put(key, []byte("23")); err != nil {
t.Fatalf("Failed to put first value: %v", err) t.Fatalf("Failed to put first value: %v", err)
} }
@ -91,17 +91,17 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if err := engine.Delete(key); err != nil { if err := engine.Delete(key); err != nil {
t.Fatalf("Failed to delete key: %v", err) t.Fatalf("Failed to delete key: %v", err)
} }
// Add it again with different value // Add it again with different value
if err := engine.Put(key, []byte("42")); err != nil { if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to re-add key: %v", err) t.Fatalf("Failed to re-add key: %v", err)
} }
// Add another key // Add another key
if err := engine.Put([]byte("bar"), []byte("23")); err != nil { if err := engine.Put([]byte("bar"), []byte("23")); err != nil {
t.Fatalf("Failed to add another key: %v", err) t.Fatalf("Failed to add another key: %v", err)
} }
// Add another key // Add another key
if err := engine.Put([]byte("user:1"), []byte(`{"name":"John"}`)); err != nil { if err := engine.Put([]byte("user:1"), []byte(`{"name":"John"}`)); err != nil {
t.Fatalf("Failed to add another key: %v", err) t.Fatalf("Failed to add another key: %v", err)
@ -130,7 +130,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("42")) { if !bytes.Equal(value, []byte("42")) {
t.Errorf("Got incorrect value after flush. Expected: %s, Got: %s", "42", string(value)) t.Errorf("Got incorrect value after flush. Expected: %s, Got: %s", "42", string(value))
} }
value, err = engine.Get([]byte("bar")) value, err = engine.Get([]byte("bar"))
if err != nil { if err != nil {
t.Fatalf("Failed to get 'bar' after flush: %v", err) t.Fatalf("Failed to get 'bar' after flush: %v", err)
@ -138,7 +138,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("23")) { if !bytes.Equal(value, []byte("23")) {
t.Errorf("Got incorrect value for 'bar' after flush. Expected: %s, Got: %s", "23", string(value)) t.Errorf("Got incorrect value for 'bar' after flush. Expected: %s, Got: %s", "23", string(value))
} }
value, err = engine.Get([]byte("user:1")) value, err = engine.Get([]byte("user:1"))
if err != nil { if err != nil {
t.Fatalf("Failed to get 'user:1' after flush: %v", err) t.Fatalf("Failed to get 'user:1' after flush: %v", err)
@ -154,7 +154,7 @@ func TestEngine_DuplicateKeysFlush(t *testing.T) {
// Test with a key that will be deleted and re-added multiple times // Test with a key that will be deleted and re-added multiple times
key := []byte("foo") key := []byte("foo")
// Add the key // Add the key
if err := engine.Put(key, []byte("42")); err != nil { if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to put initial value: %v", err) t.Fatalf("Failed to put initial value: %v", err)

View File

@ -442,13 +442,13 @@ func (c *chainedIterator) SeekToFirst() {
// Find the iterator with the smallest key from the newest source // Find the iterator with the smallest key from the newest source
c.current = -1 c.current = -1
// Find the smallest valid key // Find the smallest valid key
for i, iter := range c.iterators { for i, iter := range c.iterators {
if !iter.Valid() { if !iter.Valid() {
continue continue
} }
// If we haven't found a key yet, or this key is smaller than the current smallest // If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 { if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i c.current = i
@ -499,13 +499,13 @@ func (c *chainedIterator) Seek(target []byte) bool {
// Find the iterator with the smallest key from the newest source // Find the iterator with the smallest key from the newest source
c.current = -1 c.current = -1
// Find the smallest valid key // Find the smallest valid key
for i, iter := range c.iterators { for i, iter := range c.iterators {
if !iter.Valid() { if !iter.Valid() {
continue continue
} }
// If we haven't found a key yet, or this key is smaller than the current smallest // If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 { if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i c.current = i
@ -537,18 +537,18 @@ func (c *chainedIterator) Next() bool {
// Find the iterator with the smallest key from the newest source // Find the iterator with the smallest key from the newest source
c.current = -1 c.current = -1
// Find the smallest valid key that is greater than the current key // Find the smallest valid key that is greater than the current key
for i, iter := range c.iterators { for i, iter := range c.iterators {
if !iter.Valid() { if !iter.Valid() {
continue continue
} }
// Skip if the key is the same as the current key (we've already advanced past it) // Skip if the key is the same as the current key (we've already advanced past it)
if bytes.Equal(iter.Key(), currentKey) { if bytes.Equal(iter.Key(), currentKey) {
continue continue
} }
// If we haven't found a key yet, or this key is smaller than the current smallest // If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 { if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i c.current = i

View File

@ -109,7 +109,7 @@ func (s *KevoServiceServer) BatchWrite(ctx context.Context, req *pb.BatchWriteRe
if err != nil { if err != nil {
return &pb.BatchWriteResponse{Success: false}, fmt.Errorf("failed to start transaction: %w", err) return &pb.BatchWriteResponse{Success: false}, fmt.Errorf("failed to start transaction: %w", err)
} }
// Ensure we either commit or rollback // Ensure we either commit or rollback
defer func() { defer func() {
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
count := int32(0) count := int32(0)
// Position iterator at the first entry // Position iterator at the first entry
iter.SeekToFirst() iter.SeekToFirst()
// Iterate through all valid entries // Iterate through all valid entries
for iter.Valid() { for iter.Valid() {
if limit > 0 && count >= limit { if limit > 0 && count >= limit {
@ -199,7 +199,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
} }
count++ count++
} }
// Move to the next entry // Move to the next entry
iter.Next() iter.Next()
} }
@ -225,8 +225,8 @@ func (pi *prefixIterator) Next() bool {
for pi.iter.Next() { for pi.iter.Next() {
// Check if current key has the prefix // Check if current key has the prefix
key := pi.iter.Key() key := pi.iter.Key()
if len(key) >= len(pi.prefix) && if len(key) >= len(pi.prefix) &&
equalByteSlice(key[:len(pi.prefix)], pi.prefix) { equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
return true return true
} }
} }
@ -415,7 +415,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
count := int32(0) count := int32(0)
// Position iterator at the first entry // Position iterator at the first entry
iter.SeekToFirst() iter.SeekToFirst()
// Iterate through all valid entries // Iterate through all valid entries
for iter.Valid() { for iter.Valid() {
if limit > 0 && count >= limit { if limit > 0 && count >= limit {
@ -432,7 +432,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
} }
count++ count++
} }
// Move to the next entry // Move to the next entry
iter.Next() iter.Next()
} }
@ -446,17 +446,17 @@ func (s *KevoServiceServer) GetStats(ctx context.Context, req *pb.GetStatsReques
keyCount := int64(0) keyCount := int64(0)
sstableCount := int32(0) sstableCount := int32(0)
memtableCount := int32(1) // At least 1 active memtable memtableCount := int32(1) // At least 1 active memtable
// Create a read-only transaction to count keys // Create a read-only transaction to count keys
tx, err := s.engine.BeginTransaction(true) tx, err := s.engine.BeginTransaction(true)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to begin transaction for stats: %w", err) return nil, fmt.Errorf("failed to begin transaction for stats: %w", err)
} }
defer tx.Rollback() defer tx.Rollback()
// Use an iterator to count keys // Use an iterator to count keys
iter := tx.NewIterator() iter := tx.NewIterator()
// Count keys and estimate size // Count keys and estimate size
var totalSize int64 var totalSize int64
for iter.Next() { for iter.Next() {
@ -492,7 +492,7 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
if err != nil { if err != nil {
return &pb.CompactResponse{Success: false}, err return &pb.CompactResponse{Success: false}, err
} }
// Do a dummy write to force a flush // Do a dummy write to force a flush
if req.Force { if req.Force {
err = tx.Put([]byte("__compact_marker__"), []byte("force")) err = tx.Put([]byte("__compact_marker__"), []byte("force"))
@ -501,11 +501,11 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
return &pb.CompactResponse{Success: false}, err return &pb.CompactResponse{Success: false}, err
} }
} }
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
return &pb.CompactResponse{Success: false}, err return &pb.CompactResponse{Success: false}, err
} }
return &pb.CompactResponse{Success: true}, nil return &pb.CompactResponse{Success: true}, nil
} }

View File

@ -56,4 +56,4 @@ func sortDurations(durations []time.Duration) {
} }
} }
} }
} }

View File

@ -9,8 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo" pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -19,13 +19,13 @@ import (
// GRPCClient implements the transport.Client interface for gRPC // GRPCClient implements the transport.Client interface for gRPC
type GRPCClient struct { type GRPCClient struct {
endpoint string endpoint string
options transport.TransportOptions options transport.TransportOptions
conn *grpc.ClientConn conn *grpc.ClientConn
client pb.KevoServiceClient client pb.KevoServiceClient
status transport.TransportStatus status transport.TransportStatus
statusMu sync.RWMutex statusMu sync.RWMutex
metrics transport.MetricsCollector metrics transport.MetricsCollector
} }
// NewGRPCClient creates a new gRPC client // NewGRPCClient creates a new gRPC client
@ -123,10 +123,10 @@ func (c *GRPCClient) Status() transport.TransportStatus {
func (c *GRPCClient) setStatus(connected bool, err error) { func (c *GRPCClient) setStatus(connected bool, err error) {
c.statusMu.Lock() c.statusMu.Lock()
defer c.statusMu.Unlock() defer c.statusMu.Unlock()
c.status.Connected = connected c.status.Connected = connected
c.status.LastError = err c.status.LastError = err
if connected { if connected {
c.status.LastConnected = time.Now() c.status.LastConnected = time.Now()
} }
@ -141,11 +141,11 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record request metrics // Record request metrics
startTime := time.Now() startTime := time.Now()
requestType := request.Type() requestType := request.Type()
// Record bytes sent // Record bytes sent
requestPayload := request.Payload() requestPayload := request.Payload()
c.metrics.RecordSend(len(requestPayload)) c.metrics.RecordSend(len(requestPayload))
var resp transport.Response var resp transport.Response
var err error var err error
@ -182,12 +182,12 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record metrics for the request // Record metrics for the request
c.metrics.RecordRequest(requestType, startTime, err) c.metrics.RecordRequest(requestType, startTime, err)
// If we got a response, record received bytes // If we got a response, record received bytes
if resp != nil { if resp != nil {
c.metrics.RecordReceive(len(resp.Payload())) c.metrics.RecordReceive(len(resp.Payload()))
} }
return resp, err return resp, err
} }
@ -206,20 +206,20 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
var req struct { var req struct {
Key []byte `json:"key"` Key []byte `json:"key"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid get request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid get request payload: %w", err)), err
} }
grpcReq := &pb.GetRequest{ grpcReq := &pb.GetRequest{
Key: req.Key, Key: req.Key,
} }
grpcResp, err := c.client.Get(ctx, grpcReq) grpcResp, err := c.client.Get(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Value []byte `json:"value"` Value []byte `json:"value"`
Found bool `json:"found"` Found bool `json:"found"`
@ -227,12 +227,12 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
Value: grpcResp.Value, Value: grpcResp.Value,
Found: grpcResp.Found, Found: grpcResp.Found,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeGet, respData, nil), nil return transport.NewResponse(transport.TypeGet, respData, nil), nil
} }
@ -242,33 +242,33 @@ func (c *GRPCClient) handlePut(ctx context.Context, payload []byte) (transport.R
Value []byte `json:"value"` Value []byte `json:"value"`
Sync bool `json:"sync"` Sync bool `json:"sync"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid put request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid put request payload: %w", err)), err
} }
grpcReq := &pb.PutRequest{ grpcReq := &pb.PutRequest{
Key: req.Key, Key: req.Key,
Value: req.Value, Value: req.Value,
Sync: req.Sync, Sync: req.Sync,
} }
grpcResp, err := c.client.Put(ctx, grpcReq) grpcResp, err := c.client.Put(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypePut, respData, nil), nil return transport.NewResponse(transport.TypePut, respData, nil), nil
} }
@ -277,32 +277,32 @@ func (c *GRPCClient) handleDelete(ctx context.Context, payload []byte) (transpor
Key []byte `json:"key"` Key []byte `json:"key"`
Sync bool `json:"sync"` Sync bool `json:"sync"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid delete request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid delete request payload: %w", err)), err
} }
grpcReq := &pb.DeleteRequest{ grpcReq := &pb.DeleteRequest{
Key: req.Key, Key: req.Key,
Sync: req.Sync, Sync: req.Sync,
} }
grpcResp, err := c.client.Delete(ctx, grpcReq) grpcResp, err := c.client.Delete(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeDelete, respData, nil), nil return transport.NewResponse(transport.TypeDelete, respData, nil), nil
} }
@ -315,18 +315,18 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
} `json:"operations"` } `json:"operations"`
Sync bool `json:"sync"` Sync bool `json:"sync"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid batch write request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid batch write request payload: %w", err)), err
} }
operations := make([]*pb.Operation, len(req.Operations)) operations := make([]*pb.Operation, len(req.Operations))
for i, op := range req.Operations { for i, op := range req.Operations {
pbOp := &pb.Operation{ pbOp := &pb.Operation{
Key: op.Key, Key: op.Key,
Value: op.Value, Value: op.Value,
} }
switch op.Type { switch op.Type {
case "put": case "put":
pbOp.Type = pb.Operation_PUT pbOp.Type = pb.Operation_PUT
@ -335,31 +335,31 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
default: default:
return transport.NewErrorResponse(fmt.Errorf("invalid operation type: %s", op.Type)), fmt.Errorf("invalid operation type: %s", op.Type) return transport.NewErrorResponse(fmt.Errorf("invalid operation type: %s", op.Type)), fmt.Errorf("invalid operation type: %s", op.Type)
} }
operations[i] = pbOp operations[i] = pbOp
} }
grpcReq := &pb.BatchWriteRequest{ grpcReq := &pb.BatchWriteRequest{
Operations: operations, Operations: operations,
Sync: req.Sync, Sync: req.Sync,
} }
grpcResp, err := c.client.BatchWrite(ctx, grpcReq) grpcResp, err := c.client.BatchWrite(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeBatchWrite, respData, nil), nil return transport.NewResponse(transport.TypeBatchWrite, respData, nil), nil
} }
@ -367,31 +367,31 @@ func (c *GRPCClient) handleBeginTransaction(ctx context.Context, payload []byte)
var req struct { var req struct {
ReadOnly bool `json:"read_only"` ReadOnly bool `json:"read_only"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid begin transaction request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid begin transaction request payload: %w", err)), err
} }
grpcReq := &pb.BeginTransactionRequest{ grpcReq := &pb.BeginTransactionRequest{
ReadOnly: req.ReadOnly, ReadOnly: req.ReadOnly,
} }
grpcResp, err := c.client.BeginTransaction(ctx, grpcReq) grpcResp, err := c.client.BeginTransaction(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
TransactionID string `json:"transaction_id"` TransactionID string `json:"transaction_id"`
}{ }{
TransactionID: grpcResp.TransactionId, TransactionID: grpcResp.TransactionId,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeBeginTx, respData, nil), nil return transport.NewResponse(transport.TypeBeginTx, respData, nil), nil
} }
@ -399,31 +399,31 @@ func (c *GRPCClient) handleCommitTransaction(ctx context.Context, payload []byte
var req struct { var req struct {
TransactionID string `json:"transaction_id"` TransactionID string `json:"transaction_id"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid commit transaction request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid commit transaction request payload: %w", err)), err
} }
grpcReq := &pb.CommitTransactionRequest{ grpcReq := &pb.CommitTransactionRequest{
TransactionId: req.TransactionID, TransactionId: req.TransactionID,
} }
grpcResp, err := c.client.CommitTransaction(ctx, grpcReq) grpcResp, err := c.client.CommitTransaction(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeCommitTx, respData, nil), nil return transport.NewResponse(transport.TypeCommitTx, respData, nil), nil
} }
@ -431,31 +431,31 @@ func (c *GRPCClient) handleRollbackTransaction(ctx context.Context, payload []by
var req struct { var req struct {
TransactionID string `json:"transaction_id"` TransactionID string `json:"transaction_id"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid rollback transaction request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid rollback transaction request payload: %w", err)), err
} }
grpcReq := &pb.RollbackTransactionRequest{ grpcReq := &pb.RollbackTransactionRequest{
TransactionId: req.TransactionID, TransactionId: req.TransactionID,
} }
grpcResp, err := c.client.RollbackTransaction(ctx, grpcReq) grpcResp, err := c.client.RollbackTransaction(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeRollbackTx, respData, nil), nil return transport.NewResponse(transport.TypeRollbackTx, respData, nil), nil
} }
@ -464,21 +464,21 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
TransactionID string `json:"transaction_id"` TransactionID string `json:"transaction_id"`
Key []byte `json:"key"` Key []byte `json:"key"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx get request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid tx get request payload: %w", err)), err
} }
grpcReq := &pb.TxGetRequest{ grpcReq := &pb.TxGetRequest{
TransactionId: req.TransactionID, TransactionId: req.TransactionID,
Key: req.Key, Key: req.Key,
} }
grpcResp, err := c.client.TxGet(ctx, grpcReq) grpcResp, err := c.client.TxGet(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Value []byte `json:"value"` Value []byte `json:"value"`
Found bool `json:"found"` Found bool `json:"found"`
@ -486,12 +486,12 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
Value: grpcResp.Value, Value: grpcResp.Value,
Found: grpcResp.Found, Found: grpcResp.Found,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeTxGet, respData, nil), nil return transport.NewResponse(transport.TypeTxGet, respData, nil), nil
} }
@ -501,33 +501,33 @@ func (c *GRPCClient) handleTxPut(ctx context.Context, payload []byte) (transport
Key []byte `json:"key"` Key []byte `json:"key"`
Value []byte `json:"value"` Value []byte `json:"value"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx put request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid tx put request payload: %w", err)), err
} }
grpcReq := &pb.TxPutRequest{ grpcReq := &pb.TxPutRequest{
TransactionId: req.TransactionID, TransactionId: req.TransactionID,
Key: req.Key, Key: req.Key,
Value: req.Value, Value: req.Value,
} }
grpcResp, err := c.client.TxPut(ctx, grpcReq) grpcResp, err := c.client.TxPut(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeTxPut, respData, nil), nil return transport.NewResponse(transport.TypeTxPut, respData, nil), nil
} }
@ -536,43 +536,43 @@ func (c *GRPCClient) handleTxDelete(ctx context.Context, payload []byte) (transp
TransactionID string `json:"transaction_id"` TransactionID string `json:"transaction_id"`
Key []byte `json:"key"` Key []byte `json:"key"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx delete request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid tx delete request payload: %w", err)), err
} }
grpcReq := &pb.TxDeleteRequest{ grpcReq := &pb.TxDeleteRequest{
TransactionId: req.TransactionID, TransactionId: req.TransactionID,
Key: req.Key, Key: req.Key,
} }
grpcResp, err := c.client.TxDelete(ctx, grpcReq) grpcResp, err := c.client.TxDelete(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeTxDelete, respData, nil), nil return transport.NewResponse(transport.TypeTxDelete, respData, nil), nil
} }
func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transport.Response, error) { func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transport.Response, error) {
grpcReq := &pb.GetStatsRequest{} grpcReq := &pb.GetStatsRequest{}
grpcResp, err := c.client.GetStats(ctx, grpcReq) grpcResp, err := c.client.GetStats(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
KeyCount int64 `json:"key_count"` KeyCount int64 `json:"key_count"`
StorageSize int64 `json:"storage_size"` StorageSize int64 `json:"storage_size"`
@ -588,12 +588,12 @@ func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transp
WriteAmplification: grpcResp.WriteAmplification, WriteAmplification: grpcResp.WriteAmplification,
ReadAmplification: grpcResp.ReadAmplification, ReadAmplification: grpcResp.ReadAmplification,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeGetStats, respData, nil), nil return transport.NewResponse(transport.TypeGetStats, respData, nil), nil
} }
@ -601,31 +601,31 @@ func (c *GRPCClient) handleCompact(ctx context.Context, payload []byte) (transpo
var req struct { var req struct {
Force bool `json:"force"` Force bool `json:"force"`
} }
if err := json.Unmarshal(payload, &req); err != nil { if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid compact request payload: %w", err)), err return transport.NewErrorResponse(fmt.Errorf("invalid compact request payload: %w", err)), err
} }
grpcReq := &pb.CompactRequest{ grpcReq := &pb.CompactRequest{
Force: req.Force, Force: req.Force,
} }
grpcResp, err := c.client.Compact(ctx, grpcReq) grpcResp, err := c.client.Compact(ctx, grpcReq)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
resp := struct { resp := struct {
Success bool `json:"success"` Success bool `json:"success"`
}{ }{
Success: grpcResp.Success, Success: grpcResp.Success,
} }
respData, err := json.Marshal(resp) respData, err := json.Marshal(resp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
return transport.NewResponse(transport.TypeCompact, respData, nil), nil return transport.NewResponse(transport.TypeCompact, respData, nil), nil
} }
@ -650,7 +650,7 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
} }
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
// Build response based on scan type // Build response based on scan type
scanResp := struct { scanResp := struct {
Key []byte `json:"key"` Key []byte `json:"key"`
@ -659,12 +659,12 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
Key: resp.Key, Key: resp.Key,
Value: resp.Value, Value: resp.Value,
} }
respData, err := json.Marshal(scanResp) respData, err := json.Marshal(scanResp)
if err != nil { if err != nil {
return transport.NewErrorResponse(err), err return transport.NewErrorResponse(err), err
} }
s.client.metrics.RecordReceive(len(respData)) s.client.metrics.RecordReceive(len(respData))
return transport.NewResponse(s.streamType, respData, nil), nil return transport.NewResponse(s.streamType, respData, nil), nil
} }
@ -672,4 +672,4 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
func (s *GRPCScanStream) Close() error { func (s *GRPCScanStream) Close() error {
s.cancel() s.cancel()
return nil return nil
} }

View File

@ -7,8 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo" pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@ -72,7 +72,7 @@ func (g *GRPCTransportManager) Serve() error {
if err := g.Start(); err != nil { if err := g.Start(); err != nil {
return err return err
} }
// Block until server is stopped // Block until server is stopped
<-ctx.Done() <-ctx.Done()
return nil return nil
@ -284,4 +284,4 @@ func init() {
transport.RegisterClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.Client, error) { transport.RegisterClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.Client, error) {
return NewGRPCClient(endpoint, options) return NewGRPCClient(endpoint, options)
}) })
} }

View File

@ -7,15 +7,15 @@ import (
// Simple smoke test for the gRPC transport // Simple smoke test for the gRPC transport
func TestNewGRPCTransportManager(t *testing.T) { func TestNewGRPCTransportManager(t *testing.T) {
opts := DefaultGRPCTransportOptions() opts := DefaultGRPCTransportOptions()
// Override the listen address to avoid port conflicts // Override the listen address to avoid port conflicts
opts.ListenAddr = ":0" // use random available port opts.ListenAddr = ":0" // use random available port
manager, err := NewGRPCTransportManager(opts) manager, err := NewGRPCTransportManager(opts)
if err != nil { if err != nil {
t.Fatalf("Failed to create transport manager: %v", err) t.Fatalf("Failed to create transport manager: %v", err)
} }
// Verify the manager was created // Verify the manager was created
if manager == nil { if manager == nil {
t.Fatal("Expected non-nil manager") t.Fatal("Expected non-nil manager")
@ -60,4 +60,4 @@ func TestLoadClientTLSConfigFromStruct(t *testing.T) {
if !config.InsecureSkipVerify { if !config.InsecureSkipVerify {
t.Fatal("Expected InsecureSkipVerify to be true") t.Fatal("Expected InsecureSkipVerify to be true")
} }
} }

View File

@ -207,4 +207,4 @@ func (m *ConnectionPoolManager) CloseAll() {
m.pools.Delete(key) m.pools.Delete(key)
return true return true
}) })
} }

View File

@ -7,8 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo" pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@ -16,30 +16,30 @@ import (
// GRPCServer implements the transport.Server interface for gRPC // GRPCServer implements the transport.Server interface for gRPC
type GRPCServer struct { type GRPCServer struct {
address string address string
tlsConfig *tls.Config tlsConfig *tls.Config
server *grpc.Server server *grpc.Server
requestHandler transport.RequestHandler requestHandler transport.RequestHandler
started bool started bool
mu sync.Mutex mu sync.Mutex
metrics *transport.ExtendedMetricsCollector metrics *transport.ExtendedMetricsCollector
} }
// NewGRPCServer creates a new gRPC server // NewGRPCServer creates a new gRPC server
func NewGRPCServer(address string, options transport.TransportOptions) (transport.Server, error) { func NewGRPCServer(address string, options transport.TransportOptions) (transport.Server, error) {
// Create server options // Create server options
var serverOpts []grpc.ServerOption var serverOpts []grpc.ServerOption
// Configure TLS if enabled // Configure TLS if enabled
if options.TLSEnabled { if options.TLSEnabled {
tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile) tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err) return nil, fmt.Errorf("failed to load TLS config: %w", err)
} }
serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig))) serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
// Configure keepalive parameters // Configure keepalive parameters
kaProps := keepalive.ServerParameters{ kaProps := keepalive.ServerParameters{
MaxConnectionIdle: 30 * time.Minute, MaxConnectionIdle: 30 * time.Minute,
@ -47,20 +47,20 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
Time: 15 * time.Second, Time: 15 * time.Second,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
kaPolicy := keepalive.EnforcementPolicy{ kaPolicy := keepalive.EnforcementPolicy{
MinTime: 10 * time.Second, MinTime: 10 * time.Second,
PermitWithoutStream: true, PermitWithoutStream: true,
} }
serverOpts = append(serverOpts, serverOpts = append(serverOpts,
grpc.KeepaliveParams(kaProps), grpc.KeepaliveParams(kaProps),
grpc.KeepaliveEnforcementPolicy(kaPolicy), grpc.KeepaliveEnforcementPolicy(kaPolicy),
) )
// Create the server // Create the server
server := grpc.NewServer(serverOpts...) server := grpc.NewServer(serverOpts...)
return &GRPCServer{ return &GRPCServer{
address: address, address: address,
server: server, server: server,
@ -72,18 +72,18 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
func (s *GRPCServer) Start() error { func (s *GRPCServer) Start() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.started { if s.started {
return fmt.Errorf("server already started") return fmt.Errorf("server already started")
} }
// Start the server in a goroutine // Start the server in a goroutine
go func() { go func() {
if err := s.Serve(); err != nil { if err := s.Serve(); err != nil {
fmt.Printf("gRPC server error: %v\n", err) fmt.Printf("gRPC server error: %v\n", err)
} }
}() }()
s.started = true s.started = true
return nil return nil
} }
@ -93,31 +93,31 @@ func (s *GRPCServer) Serve() error {
if s.requestHandler == nil { if s.requestHandler == nil {
return fmt.Errorf("no request handler set") return fmt.Errorf("no request handler set")
} }
// Create the service implementation // Create the service implementation
service := &kevoServiceServer{ service := &kevoServiceServer{
handler: s.requestHandler, handler: s.requestHandler,
} }
// Register the service // Register the service
pb.RegisterKevoServiceServer(s.server, service) pb.RegisterKevoServiceServer(s.server, service)
// Start listening // Start listening
listener, err := transport.CreateListener("tcp", s.address, s.tlsConfig) listener, err := transport.CreateListener("tcp", s.address, s.tlsConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.address, err) return fmt.Errorf("failed to listen on %s: %w", s.address, err)
} }
s.metrics.ServerStarted() s.metrics.ServerStarted()
// Serve requests // Serve requests
err = s.server.Serve(listener) err = s.server.Serve(listener)
if err != nil { if err != nil {
s.metrics.ServerErrored() s.metrics.ServerErrored()
return fmt.Errorf("failed to serve: %w", err) return fmt.Errorf("failed to serve: %w", err)
} }
s.metrics.ServerStopped() s.metrics.ServerStopped()
return nil return nil
} }
@ -126,14 +126,14 @@ func (s *GRPCServer) Serve() error {
func (s *GRPCServer) Stop(ctx context.Context) error { func (s *GRPCServer) Stop(ctx context.Context) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if !s.started { if !s.started {
return nil return nil
} }
s.server.GracefulStop() s.server.GracefulStop()
s.started = false s.started = false
return nil return nil
} }
@ -141,7 +141,7 @@ func (s *GRPCServer) Stop(ctx context.Context) error {
func (s *GRPCServer) SetRequestHandler(handler transport.RequestHandler) { func (s *GRPCServer) SetRequestHandler(handler transport.RequestHandler) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.requestHandler = handler s.requestHandler = handler
} }
@ -151,4 +151,4 @@ type kevoServiceServer struct {
handler transport.RequestHandler handler transport.RequestHandler
} }
// TODO: Implement service methods // TODO: Implement service methods

View File

@ -92,4 +92,4 @@ func LoadClientTLSConfigFromStruct(config *TLSConfig) (*tls.Config, error) {
return &tls.Config{MinVersion: tls.VersionTLS12}, nil return &tls.Config{MinVersion: tls.VersionTLS12}, nil
} }
return LoadClientTLSConfig(config.CertFile, config.KeyFile, config.CAFile, config.SkipVerify) return LoadClientTLSConfig(config.CertFile, config.KeyFile, config.CAFile, config.SkipVerify)
} }

View File

@ -5,8 +5,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo" pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -30,17 +30,17 @@ func (c *GRPCConnection) Execute(fn func(interface{}) error) error {
// Create a new client from the connection // Create a new client from the connection
client := pb.NewKevoServiceClient(c.conn) client := pb.NewKevoServiceClient(c.conn)
// Execute the provided function with the client // Execute the provided function with the client
err := fn(client) err := fn(client)
// Update metrics if there was an error // Update metrics if there was an error
if err != nil { if err != nil {
c.mu.Lock() c.mu.Lock()
c.errCount++ c.errCount++
c.mu.Unlock() c.mu.Unlock()
} }
return err return err
} }
@ -58,10 +58,10 @@ func (c *GRPCConnection) Address() string {
func (c *GRPCConnection) Status() transport.ConnectionStatus { func (c *GRPCConnection) Status() transport.ConnectionStatus {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
// Check the connection state // Check the connection state
isConnected := c.conn != nil isConnected := c.conn != nil
return transport.ConnectionStatus{ return transport.ConnectionStatus{
Connected: isConnected, Connected: isConnected,
LastActivity: c.lastUsed, LastActivity: c.lastUsed,
@ -81,4 +81,4 @@ type GRPCTransportOptions struct {
MaxConnectionIdle time.Duration MaxConnectionIdle time.Duration
MaxConnectionAge time.Duration MaxConnectionAge time.Duration
MaxPoolConnections int MaxPoolConnections int
} }

View File

@ -6,21 +6,21 @@ import (
// Standard request/response type constants // Standard request/response type constants
const ( const (
TypeGet = "get" TypeGet = "get"
TypePut = "put" TypePut = "put"
TypeDelete = "delete" TypeDelete = "delete"
TypeBatchWrite = "batch_write" TypeBatchWrite = "batch_write"
TypeScan = "scan" TypeScan = "scan"
TypeBeginTx = "begin_tx" TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx" TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx" TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get" TypeTxGet = "tx_get"
TypeTxPut = "tx_put" TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete" TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan" TypeTxScan = "tx_scan"
TypeGetStats = "get_stats" TypeGetStats = "get_stats"
TypeCompact = "compact" TypeCompact = "compact"
TypeError = "error" TypeError = "error"
) )
// Common errors // Common errors
@ -97,4 +97,4 @@ func NewErrorResponse(err error) Response {
ResponseData: msg, ResponseData: msg,
ResponseErr: err, ResponseErr: err,
} }
} }

View File

@ -9,12 +9,12 @@ func TestBasicRequest(t *testing.T) {
// Test creating a request // Test creating a request
payload := []byte("test payload") payload := []byte("test payload")
req := NewRequest(TypeGet, payload) req := NewRequest(TypeGet, payload)
// Test Type method // Test Type method
if req.Type() != TypeGet { if req.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, req.Type()) t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
} }
// Test Payload method // Test Payload method
if string(req.Payload()) != string(payload) { if string(req.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload())) t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
@ -25,26 +25,26 @@ func TestBasicResponse(t *testing.T) {
// Test creating a response with no error // Test creating a response with no error
payload := []byte("test response") payload := []byte("test response")
resp := NewResponse(TypeGet, payload, nil) resp := NewResponse(TypeGet, payload, nil)
// Test Type method // Test Type method
if resp.Type() != TypeGet { if resp.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type()) t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
} }
// Test Payload method // Test Payload method
if string(resp.Payload()) != string(payload) { if string(resp.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload())) t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
} }
// Test Error method // Test Error method
if resp.Error() != nil { if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error()) t.Errorf("Expected nil error, got %v", resp.Error())
} }
// Test creating a response with an error // Test creating a response with an error
testErr := errors.New("test error") testErr := errors.New("test error")
resp = NewResponse(TypeGet, payload, testErr) resp = NewResponse(TypeGet, payload, testErr)
if resp.Error() != testErr { if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error()) t.Errorf("Expected error %v, got %v", testErr, resp.Error())
} }
@ -54,34 +54,34 @@ func TestNewErrorResponse(t *testing.T) {
// Test creating an error response // Test creating an error response
testErr := errors.New("test error") testErr := errors.New("test error")
resp := NewErrorResponse(testErr) resp := NewErrorResponse(testErr)
// Test Type method // Test Type method
if resp.Type() != TypeError { if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type()) t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
} }
// Test Payload method - should contain error message // Test Payload method - should contain error message
if string(resp.Payload()) != testErr.Error() { if string(resp.Payload()) != testErr.Error() {
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload())) t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
} }
// Test Error method // Test Error method
if resp.Error() != testErr { if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error()) t.Errorf("Expected error %v, got %v", testErr, resp.Error())
} }
// Test with nil error // Test with nil error
resp = NewErrorResponse(nil) resp = NewErrorResponse(nil)
if resp.Type() != TypeError { if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type()) t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
} }
if len(resp.Payload()) != 0 { if len(resp.Payload()) != 0 {
t.Errorf("Expected empty payload, got %s", string(resp.Payload())) t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
} }
if resp.Error() != nil { if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error()) t.Errorf("Expected nil error, got %v", resp.Error())
} }
} }

View File

@ -50,7 +50,7 @@ type TransportStatus struct {
type Request interface { type Request interface {
// Type returns the type of request // Type returns the type of request
Type() string Type() string
// Payload returns the payload of the request // Payload returns the payload of the request
Payload() []byte Payload() []byte
} }
@ -59,10 +59,10 @@ type Request interface {
type Response interface { type Response interface {
// Type returns the type of response // Type returns the type of response
Type() string Type() string
// Payload returns the payload of the response // Payload returns the payload of the response
Payload() []byte Payload() []byte
// Error returns any error associated with the response // Error returns any error associated with the response
Error() error Error() error
} }
@ -71,10 +71,10 @@ type Response interface {
type Stream interface { type Stream interface {
// Send sends a request over the stream // Send sends a request over the stream
Send(request Request) error Send(request Request) error
// Recv receives a response from the stream // Recv receives a response from the stream
Recv() (Response, error) Recv() (Response, error)
// Close closes the stream // Close closes the stream
Close() error Close() error
} }
@ -83,19 +83,19 @@ type Stream interface {
type Client interface { type Client interface {
// Connect establishes a connection to the server // Connect establishes a connection to the server
Connect(ctx context.Context) error Connect(ctx context.Context) error
// Close closes the connection // Close closes the connection
Close() error Close() error
// IsConnected returns whether the client is connected // IsConnected returns whether the client is connected
IsConnected() bool IsConnected() bool
// Status returns the current status of the connection // Status returns the current status of the connection
Status() TransportStatus Status() TransportStatus
// Send sends a request and waits for a response // Send sends a request and waits for a response
Send(ctx context.Context, request Request) (Response, error) Send(ctx context.Context, request Request) (Response, error)
// Stream opens a bidirectional stream // Stream opens a bidirectional stream
Stream(ctx context.Context) (Stream, error) Stream(ctx context.Context) (Stream, error)
} }
@ -104,7 +104,7 @@ type Client interface {
type RequestHandler interface { type RequestHandler interface {
// HandleRequest processes a request and returns a response // HandleRequest processes a request and returns a response
HandleRequest(ctx context.Context, request Request) (Response, error) HandleRequest(ctx context.Context, request Request) (Response, error)
// HandleStream processes a bidirectional stream // HandleStream processes a bidirectional stream
HandleStream(stream Stream) error HandleStream(stream Stream) error
} }
@ -113,13 +113,13 @@ type RequestHandler interface {
type Server interface { type Server interface {
// Start starts the server and returns immediately // Start starts the server and returns immediately
Start() error Start() error
// Serve starts the server and blocks until it's stopped // Serve starts the server and blocks until it's stopped
Serve() error Serve() error
// Stop stops the server gracefully // Stop stops the server gracefully
Stop(ctx context.Context) error Stop(ctx context.Context) error
// SetRequestHandler sets the handler for incoming requests // SetRequestHandler sets the handler for incoming requests
SetRequestHandler(handler RequestHandler) SetRequestHandler(handler RequestHandler)
} }
@ -134,16 +134,16 @@ type ServerFactory func(address string, options TransportOptions) (Server, error
type Registry interface { type Registry interface {
// RegisterClient adds a new client implementation to the registry // RegisterClient adds a new client implementation to the registry
RegisterClient(name string, factory ClientFactory) RegisterClient(name string, factory ClientFactory)
// RegisterServer adds a new server implementation to the registry // RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory) RegisterServer(name string, factory ServerFactory)
// CreateClient instantiates a client by name // CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error) CreateClient(name, endpoint string, options TransportOptions) (Client, error)
// CreateServer instantiates a server by name // CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error) CreateServer(name, address string, options TransportOptions) (Server, error)
// ListTransports returns all available transport names // ListTransports returns all available transport names
ListTransports() []string ListTransports() []string
} }

View File

@ -10,16 +10,16 @@ import (
type MetricsCollector interface { type MetricsCollector interface {
// RecordRequest records metrics for a request // RecordRequest records metrics for a request
RecordRequest(requestType string, startTime time.Time, err error) RecordRequest(requestType string, startTime time.Time, err error)
// RecordSend records metrics for bytes sent // RecordSend records metrics for bytes sent
RecordSend(bytes int) RecordSend(bytes int)
// RecordReceive records metrics for bytes received // RecordReceive records metrics for bytes received
RecordReceive(bytes int) RecordReceive(bytes int)
// RecordConnection records a connection event // RecordConnection records a connection event
RecordConnection(successful bool) RecordConnection(successful bool)
// GetMetrics returns the current metrics // GetMetrics returns the current metrics
GetMetrics() Metrics GetMetrics() Metrics
} }
@ -46,7 +46,7 @@ type BasicMetricsCollector struct {
bytesReceived uint64 bytesReceived uint64
connections uint64 connections uint64
connectionFailures uint64 connectionFailures uint64
// Track average latency and count for each request type // Track average latency and count for each request type
avgLatencyByType map[string]time.Duration avgLatencyByType map[string]time.Duration
requestCountByType map[string]uint64 requestCountByType map[string]uint64
@ -63,26 +63,26 @@ func NewMetricsCollector() MetricsCollector {
// RecordRequest records metrics for a request // RecordRequest records metrics for a request
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) { func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
atomic.AddUint64(&c.totalRequests, 1) atomic.AddUint64(&c.totalRequests, 1)
if err == nil { if err == nil {
atomic.AddUint64(&c.successfulRequests, 1) atomic.AddUint64(&c.successfulRequests, 1)
} else { } else {
atomic.AddUint64(&c.failedRequests, 1) atomic.AddUint64(&c.failedRequests, 1)
} }
// Update average latency for request type // Update average latency for request type
latency := time.Since(startTime) latency := time.Since(startTime)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
currentAvg, exists := c.avgLatencyByType[requestType] currentAvg, exists := c.avgLatencyByType[requestType]
currentCount, _ := c.requestCountByType[requestType] currentCount, _ := c.requestCountByType[requestType]
if exists { if exists {
// Update running average - the common case for better branch prediction // Update running average - the common case for better branch prediction
// new_avg = (old_avg * count + new_value) / (count + 1) // new_avg = (old_avg * count + new_value) / (count + 1)
totalDuration := currentAvg*time.Duration(currentCount) + latency totalDuration := currentAvg * time.Duration(currentCount) + latency
newCount := currentCount + 1 newCount := currentCount + 1
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount) c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
c.requestCountByType[requestType] = newCount c.requestCountByType[requestType] = newCount
@ -116,13 +116,13 @@ func (c *BasicMetricsCollector) RecordConnection(successful bool) {
func (c *BasicMetricsCollector) GetMetrics() Metrics { func (c *BasicMetricsCollector) GetMetrics() Metrics {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
// Create a copy of the average latency map // Create a copy of the average latency map
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType)) avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
for k, v := range c.avgLatencyByType { for k, v := range c.avgLatencyByType {
avgLatencyByType[k] = v avgLatencyByType[k] = v
} }
return Metrics{ return Metrics{
TotalRequests: atomic.LoadUint64(&c.totalRequests), TotalRequests: atomic.LoadUint64(&c.totalRequests),
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests), SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
@ -133,4 +133,4 @@ func (c *BasicMetricsCollector) GetMetrics() Metrics {
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures), ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
AvgLatencyByType: avgLatencyByType, AvgLatencyByType: avgLatencyByType,
} }
} }

View File

@ -9,9 +9,9 @@ import (
// Metrics struct extensions for server metrics // Metrics struct extensions for server metrics
type ServerMetrics struct { type ServerMetrics struct {
Metrics Metrics
ServerStarted uint64 ServerStarted uint64
ServerErrored uint64 ServerErrored uint64
ServerStopped uint64 ServerStopped uint64
} }
// Connection represents a connection to a remote endpoint // Connection represents a connection to a remote endpoint
@ -31,11 +31,11 @@ type Connection interface {
// ConnectionStatus represents the status of a connection // ConnectionStatus represents the status of a connection
type ConnectionStatus struct { type ConnectionStatus struct {
Connected bool Connected bool
LastActivity time.Time LastActivity time.Time
ErrorCount int ErrorCount int
RequestCount int RequestCount int
LatencyAvg time.Duration LatencyAvg time.Duration
} }
// TransportManager is an interface for managing transport layer operations // TransportManager is an interface for managing transport layer operations

View File

@ -8,7 +8,7 @@ import (
func TestBasicMetricsCollector(t *testing.T) { func TestBasicMetricsCollector(t *testing.T) {
collector := NewMetricsCollector() collector := NewMetricsCollector()
// Test initial state // Test initial state
metrics := collector.GetMetrics() metrics := collector.GetMetrics()
if metrics.TotalRequests != 0 || if metrics.TotalRequests != 0 ||
@ -21,11 +21,11 @@ func TestBasicMetricsCollector(t *testing.T) {
len(metrics.AvgLatencyByType) != 0 { len(metrics.AvgLatencyByType) != 0 {
t.Errorf("Initial metrics not initialized correctly: %+v", metrics) t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
} }
// Test recording successful request // Test recording successful request
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
collector.RecordRequest("get", startTime, nil) collector.RecordRequest("get", startTime, nil)
metrics = collector.GetMetrics() metrics = collector.GetMetrics()
if metrics.TotalRequests != 1 { if metrics.TotalRequests != 1 {
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests) t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
@ -36,18 +36,18 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 0 { if metrics.FailedRequests != 0 {
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests) t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
} }
// Check average latency // Check average latency
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists { if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
t.Error("Expected 'get' latency to exist") t.Error("Expected 'get' latency to exist")
} else if avgLatency < 100*time.Millisecond { } else if avgLatency < 100*time.Millisecond {
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency) t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
} }
// Test recording failed request // Test recording failed request
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
collector.RecordRequest("get", startTime, errors.New("test error")) collector.RecordRequest("get", startTime, errors.New("test error"))
metrics = collector.GetMetrics() metrics = collector.GetMetrics()
if metrics.TotalRequests != 2 { if metrics.TotalRequests != 2 {
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests) t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
@ -58,26 +58,26 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 1 { if metrics.FailedRequests != 1 {
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests) t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
} }
// Test average latency calculation for multiple requests // Test average latency calculation for multiple requests
startTime = time.Now().Add(-300 * time.Millisecond) startTime = time.Now().Add(-300 * time.Millisecond)
collector.RecordRequest("put", startTime, nil) collector.RecordRequest("put", startTime, nil)
startTime = time.Now().Add(-500 * time.Millisecond) startTime = time.Now().Add(-500 * time.Millisecond)
collector.RecordRequest("put", startTime, nil) collector.RecordRequest("put", startTime, nil)
metrics = collector.GetMetrics() metrics = collector.GetMetrics()
avgPutLatency := metrics.AvgLatencyByType["put"] avgPutLatency := metrics.AvgLatencyByType["put"]
// Expected avg is around (300ms + 500ms) / 2 = 400ms // Expected avg is around (300ms + 500ms) / 2 = 400ms
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond { if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency) t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
} }
// Test byte tracking // Test byte tracking
collector.RecordSend(1000) collector.RecordSend(1000)
collector.RecordReceive(2000) collector.RecordReceive(2000)
metrics = collector.GetMetrics() metrics = collector.GetMetrics()
if metrics.BytesSent != 1000 { if metrics.BytesSent != 1000 {
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent) t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
@ -85,12 +85,12 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.BytesReceived != 2000 { if metrics.BytesReceived != 2000 {
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived) t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
} }
// Test connection tracking // Test connection tracking
collector.RecordConnection(true) collector.RecordConnection(true)
collector.RecordConnection(false) collector.RecordConnection(false)
collector.RecordConnection(true) collector.RecordConnection(true)
metrics = collector.GetMetrics() metrics = collector.GetMetrics()
if metrics.Connections != 2 { if metrics.Connections != 2 {
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections) t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
@ -98,4 +98,4 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.ConnectionFailures != 1 { if metrics.ConnectionFailures != 1 {
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures) t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
} }
} }

View File

@ -12,11 +12,11 @@ func CreateListener(network, address string, tlsConfig *tls.Config) (net.Listene
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If TLS is configured, wrap the listener // If TLS is configured, wrap the listener
if tlsConfig != nil { if tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig) listener = tls.NewListener(listener, tlsConfig)
} }
return listener, nil return listener, nil
} }

View File

@ -7,7 +7,7 @@ import (
// registry implements the Registry interface // registry implements the Registry interface
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
clientFactories map[string]ClientFactory clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory serverFactories map[string]ServerFactory
} }
@ -111,4 +111,4 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
// AvailableTransports lists all available transports in the default registry // AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string { func AvailableTransports() []string {
return DefaultRegistry.ListTransports() return DefaultRegistry.ListTransports()
} }

View File

@ -97,17 +97,17 @@ func mockServerFactory(address string, options TransportOptions) (Server, error)
// TestRegistry tests the transport registry // TestRegistry tests the transport registry
func TestRegistry(t *testing.T) { func TestRegistry(t *testing.T) {
registry := NewRegistry() registry := NewRegistry()
// Register transports // Register transports
registry.RegisterClient("mock", mockClientFactory) registry.RegisterClient("mock", mockClientFactory)
registry.RegisterServer("mock", mockServerFactory) registry.RegisterServer("mock", mockServerFactory)
// Test listing transports // Test listing transports
transports := registry.ListTransports() transports := registry.ListTransports()
if len(transports) != 1 || transports[0] != "mock" { if len(transports) != 1 || transports[0] != "mock" {
t.Errorf("Expected [mock], got %v", transports) t.Errorf("Expected [mock], got %v", transports)
} }
// Test creating client // Test creating client
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{ client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
@ -115,21 +115,21 @@ func TestRegistry(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create client: %v", err) t.Fatalf("Failed to create client: %v", err)
} }
// Test client methods // Test client methods
if client.IsConnected() { if client.IsConnected() {
t.Error("Expected client to be disconnected initially") t.Error("Expected client to be disconnected initially")
} }
err = client.Connect(context.Background()) err = client.Connect(context.Background())
if err != nil { if err != nil {
t.Fatalf("Failed to connect: %v", err) t.Fatalf("Failed to connect: %v", err)
} }
if !client.IsConnected() { if !client.IsConnected() {
t.Error("Expected client to be connected after Connect()") t.Error("Expected client to be connected after Connect()")
} }
// Test server creation // Test server creation
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{ server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
@ -137,26 +137,26 @@ func TestRegistry(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create server: %v", err) t.Fatalf("Failed to create server: %v", err)
} }
// Test server methods // Test server methods
err = server.Start() err = server.Start()
if err != nil { if err != nil {
t.Fatalf("Failed to start server: %v", err) t.Fatalf("Failed to start server: %v", err)
} }
mockServer := server.(*mockServer) mockServer := server.(*mockServer)
if !mockServer.started { if !mockServer.started {
t.Error("Expected server to be started") t.Error("Expected server to be started")
} }
// Test non-existent transport // Test non-existent transport
_, err = registry.CreateClient("nonexistent", "", TransportOptions{}) _, err = registry.CreateClient("nonexistent", "", TransportOptions{})
if err == nil { if err == nil {
t.Error("Expected error creating non-existent client") t.Error("Expected error creating non-existent client")
} }
_, err = registry.CreateServer("nonexistent", "", TransportOptions{}) _, err = registry.CreateServer("nonexistent", "", TransportOptions{})
if err == nil { if err == nil {
t.Error("Expected error creating non-existent server") t.Error("Expected error creating non-existent server")
} }
} }

View File

@ -366,7 +366,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
// Track overall recovery stats // Track overall recovery stats
totalStats := NewRecoveryStats() totalStats := NewRecoveryStats()
// Track number of files processed successfully // Track number of files processed successfully
successfulFiles := 0 successfulFiles := 0
var lastErr error var lastErr error
@ -397,7 +397,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
// Add stats from this file to our totals // Add stats from this file to our totals
totalStats.EntriesProcessed += fileStats.EntriesProcessed totalStats.EntriesProcessed += fileStats.EntriesProcessed
totalStats.EntriesSkipped += fileStats.EntriesSkipped totalStats.EntriesSkipped += fileStats.EntriesSkipped
successfulFiles++ successfulFiles++
} }