chore: formatting
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 9m50s
This commit is contained in:
parent e7974e008d
commit a0a1c0512f
cmd/kevo/main.go
@@ -81,14 +81,14 @@ Commands (interactive mode only):

// Config holds the application configuration
type Config struct {
	ServerMode  bool
	DaemonMode  bool
	ListenAddr  string
	DBPath      string
	TLSEnabled  bool
	TLSCertFile string
	TLSKeyFile  string
	TLSCAFile   string
}

func main() {

@@ -98,7 +98,7 @@ func main() {
	// Open database if path provided
	var eng *engine.Engine
	var err error

	if config.DBPath != "" {
		fmt.Printf("Opening database at %s\n", config.DBPath)
		eng, err = engine.NewEngine(config.DBPath)

@@ -108,18 +108,18 @@ func main() {
		}
		defer eng.Close()
	}

	// Check if we should run in server mode
	if config.ServerMode {
		if eng == nil {
			fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
			os.Exit(1)
		}

		runServer(eng, config)
		return
	}

	// Run in interactive mode
	runInteractive(eng, config.DBPath)
}

@@ -151,31 +151,31 @@ func parseFlags() Config {
	serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
	daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
	listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")

	// TLS options
	tlsEnabled := flag.Bool("tls", false, "Enable TLS for secure connections")
	tlsCertFile := flag.String("cert", "", "TLS certificate file path")
	tlsKeyFile := flag.String("key", "", "TLS private key file path")
	tlsCAFile := flag.String("ca", "", "TLS CA certificate file for client verification")

	// Parse flags
	flag.Parse()

	// Get database path from remaining arguments
	var dbPath string
	if flag.NArg() > 0 {
		dbPath = flag.Arg(0)
	}

	return Config{
		ServerMode:  *serverMode,
		DaemonMode:  *daemonMode,
		ListenAddr:  *listenAddr,
		DBPath:      dbPath,
		TLSEnabled:  *tlsEnabled,
		TLSCertFile: *tlsCertFile,
		TLSKeyFile:  *tlsKeyFile,
		TLSCAFile:   *tlsCAFile,
	}
}

@@ -185,10 +185,10 @@ func runServer(eng *engine.Engine, config Config) {
	if config.DaemonMode {
		setupDaemonMode()
	}

	// Create and start the server
	server := NewServer(eng, config)

	// Start the server (non-blocking)
	if err := server.Start(); err != nil {
		fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)

@@ -196,10 +196,10 @@ func runServer(eng *engine.Engine, config Config) {
	}

	fmt.Printf("Kevo server started on %s\n", config.ListenAddr)

	// Set up signal handling for graceful shutdown
	setupGracefulShutdown(server, eng)

	// Start serving (blocking)
	if err := server.Serve(); err != nil {
		fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)

@@ -214,29 +214,29 @@ func setupDaemonMode() {
	if err != nil {
		log.Fatalf("Failed to open /dev/null: %v", err)
	}

	// Redirect standard file descriptors to /dev/null
	err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
	if err != nil {
		log.Fatalf("Failed to redirect stdin: %v", err)
	}

	err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
	if err != nil {
		log.Fatalf("Failed to redirect stdout: %v", err)
	}

	err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
	if err != nil {
		log.Fatalf("Failed to redirect stderr: %v", err)
	}

	// Create a new process group
	_, err = syscall.Setsid()
	if err != nil {
		log.Fatalf("Failed to create new session: %v", err)
	}

	fmt.Println("Daemon mode enabled, detaching from terminal...")
}

@@ -244,22 +244,22 @@ func setupDaemonMode() {
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		sig := <-sigChan
		fmt.Printf("\nReceived signal %v, shutting down...\n", sig)

		// Graceful shutdown logic
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		// Shut down the server
		if err := server.Shutdown(ctx); err != nil {
			fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
		}

		// The engine will be closed by the defer in main()

		fmt.Println("Shutdown complete")
		os.Exit(0)
	}()

@@ -269,7 +269,7 @@ func setupGracefulShutdown(server *Server, eng *engine.Engine) {
func runInteractive(eng *engine.Engine, dbPath string) {
	fmt.Println("Kevo (kevo) version 1.0.2")
	fmt.Println("Enter .help for usage hints.")

	var tx engine.Transaction
	var err error

@@ -412,7 +412,7 @@ func runInteractive(eng *engine.Engine, dbPath string) {
	// Print statistics
	stats := eng.GetStats()

	// Format human-readable time for the last operation timestamps
	var lastPutTime, lastGetTime, lastDeleteTime time.Time
	if putTime, ok := stats["last_put_time"].(int64); ok && putTime > 0 {

@@ -424,13 +424,13 @@ func runInteractive(eng *engine.Engine, dbPath string) {
	if deleteTime, ok := stats["last_delete_time"].(int64); ok && deleteTime > 0 {
		lastDeleteTime = time.Unix(0, deleteTime)
	}

	// Operations section
	fmt.Println("📊 Operations:")
	fmt.Printf(" • Puts: %d\n", stats["put_ops"])
	fmt.Printf(" • Gets: %d (Hits: %d, Misses: %d)\n", stats["get_ops"], stats["get_hits"], stats["get_misses"])
	fmt.Printf(" • Deletes: %d\n", stats["delete_ops"])

	// Last Operation Times
	fmt.Println("\n⏱️ Last Operation Times:")
	if !lastPutTime.IsZero() {

@@ -448,25 +448,25 @@ func runInteractive(eng *engine.Engine, dbPath string) {
	} else {
		fmt.Printf(" • Last Delete: Never\n")
	}

	// Transactions
	fmt.Println("\n💼 Transactions:")
	fmt.Printf(" • Started: %d\n", stats["tx_started"])
	fmt.Printf(" • Completed: %d\n", stats["tx_completed"])
	fmt.Printf(" • Aborted: %d\n", stats["tx_aborted"])

	// Storage metrics
	fmt.Println("\n💾 Storage:")
	fmt.Printf(" • Total Bytes Read: %d\n", stats["total_bytes_read"])
	fmt.Printf(" • Total Bytes Written: %d\n", stats["total_bytes_written"])
	fmt.Printf(" • Flush Count: %d\n", stats["flush_count"])

	// Table stats
	fmt.Println("\n📋 Tables:")
	fmt.Printf(" • SSTable Count: %d\n", stats["sstable_count"])
	fmt.Printf(" • Immutable MemTable Count: %d\n", stats["immutable_memtable_count"])
	fmt.Printf(" • Current MemTable Size: %d bytes\n", stats["memtable_size"])

	// WAL recovery stats
	fmt.Println("\n🔄 WAL Recovery:")
	fmt.Printf(" • Files Recovered: %d\n", stats["wal_files_recovered"])

@@ -475,17 +475,17 @@ func runInteractive(eng *engine.Engine, dbPath string) {
	if recoveryDuration, ok := stats["wal_recovery_duration_ms"]; ok {
		fmt.Printf(" • Recovery Duration: %d ms\n", recoveryDuration)
	}

	// Error counts
	fmt.Println("\n⚠️ Errors:")
	fmt.Printf(" • Read Errors: %d\n", stats["read_errors"])
	fmt.Printf(" • Write Errors: %d\n", stats["write_errors"])

	// Compaction stats (if available)
	if compactionOutputCount, ok := stats["compaction_last_outputs_count"]; ok {
		fmt.Println("\n🧹 Compaction:")
		fmt.Printf(" • Last Output Files Count: %d\n", compactionOutputCount)

		// Display other compaction stats as available
		for key, value := range stats {
			if strings.HasPrefix(key, "compaction_") && key != "compaction_last_outputs_count" && key != "compaction_last_outputs" {

@@ -825,4 +825,4 @@ func toTitle(s string) string {
		return r
	},
		s)
}
@@ -35,14 +35,14 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
	// Create context with timeout to prevent potential hangs
	timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	// Create a channel to receive the transaction result
	type txResult struct {
		tx  engine.Transaction
		err error
	}
	resultCh := make(chan txResult, 1)

	// Start transaction in a goroutine to prevent potential blocking
	go func() {
		tx, err := eng.BeginTransaction(readOnly)

@@ -56,26 +56,26 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
			}
		}
	}()

	// Wait for result or timeout
	select {
	case result := <-resultCh:
		if result.err != nil {
			return "", fmt.Errorf("failed to begin transaction: %w", result.err)
		}

		tr.mu.Lock()
		defer tr.mu.Unlock()

		// Generate a transaction ID
		tr.nextID++
		txID := fmt.Sprintf("tx-%d", tr.nextID)

		// Register the transaction
		tr.transactions[txID] = result.tx

		return txID, nil

	case <-timeoutCtx.Done():
		return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
	}

@@ -104,31 +104,31 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
	defer tr.mu.Unlock()

	var lastErr error

	// Copy transaction IDs to avoid modifying the map during iteration
	ids := make([]string, 0, len(tr.transactions))
	for id := range tr.transactions {
		ids = append(ids, id)
	}

	// Rollback each transaction with a timeout
	for _, id := range ids {
		tx, exists := tr.transactions[id]
		if !exists {
			continue
		}

		// Use a timeout for each rollback operation
		rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)

		// Create a channel for the rollback result
		doneCh := make(chan error, 1)

		// Execute rollback in goroutine
		go func(t engine.Transaction) {
			doneCh <- t.Rollback()
		}(tx)

		// Wait for rollback or timeout
		var err error
		select {

@@ -137,14 +137,14 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
		case <-rollbackCtx.Done():
			err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
		}

		cancel() // Clean up context

		// Record error if any
		if err != nil {
			lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
		}

		// Always remove transaction from map
		delete(tr.transactions, id)
	}

@@ -154,12 +154,12 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {

// Server represents the Kevo server
type Server struct {
	eng         *engine.Engine
	txRegistry  *TransactionRegistry
	listener    net.Listener
	grpcServer  *grpc.Server
	kevoService *grpcservice.KevoServiceServer
	config      Config
}

// NewServer creates a new server instance

@@ -249,14 +249,14 @@ func (s *Server) Shutdown(ctx context.Context) error {
	// First, gracefully stop the gRPC server if it exists
	if s.grpcServer != nil {
		fmt.Println("Gracefully stopping gRPC server...")

		// Create a channel to signal when the server has stopped
		stopped := make(chan struct{})
		go func() {
			s.grpcServer.GracefulStop()
			close(stopped)
		}()

		// Wait for graceful stop or context deadline
		select {
		case <-stopped:

@@ -266,7 +266,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
			s.grpcServer.Stop()
		}
	}

	// Shut down the listener if it's still open
	if s.listener != nil {
		if err := s.listener.Close(); err != nil {

@@ -280,4 +280,4 @@ func (s *Server) Shutdown(ctx context.Context) error {
	}

	return nil
}
@@ -14,7 +14,7 @@ func TestTransactionRegistry(t *testing.T) {
	// Create a timeout context for the whole test
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Set up temporary directory for test
	tmpDir, err := os.MkdirTemp("", "kevo_test")
	if err != nil {

@@ -153,47 +153,47 @@ func TestGRPCServer(t *testing.T) {
		t.Fatalf("Failed to create temporary directory: %v", err)
	}
	defer os.RemoveAll(tempDBPath)

	// Create engine
	eng, err := engine.NewEngine(tempDBPath)
	if err != nil {
		t.Fatalf("Failed to create engine: %v", err)
	}
	defer eng.Close()

	// Create server configuration
	config := Config{
		ServerMode: true,
		ListenAddr: "localhost:50052", // Use a different port for tests
		DBPath:     tempDBPath,
	}

	// Create and start the server
	server := NewServer(eng, config)
	if err := server.Start(); err != nil {
		t.Fatalf("Failed to start server: %v", err)
	}

	// Run server in a goroutine
	go func() {
		if err := server.Serve(); err != nil {
			t.Logf("Server stopped: %v", err)
		}
	}()

	// Give the server a moment to start
	time.Sleep(200 * time.Millisecond)

	// Clean up at the end
	defer func() {
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer shutdownCancel()

		if err := server.Shutdown(shutdownCtx); err != nil {
			t.Logf("Failed to shut down server: %v", err)
		}
	}()

	// TODO: Add gRPC client tests here when client implementation is complete
	t.Log("gRPC server integration test scaffolding added")
}
@@ -23,11 +23,11 @@ const (

// ClientOptions configures a Kevo client
type ClientOptions struct {
	// Connection options
	Endpoint       string        // Server address
	ConnectTimeout time.Duration // Timeout for connection attempts
	RequestTimeout time.Duration // Default timeout for requests
	TransportType  string        // Transport type (e.g. "grpc")
	PoolSize       int           // Connection pool size

	// Security options
	TLSEnabled bool // Enable TLS

@@ -50,19 +50,19 @@ type ClientOptions struct {
// DefaultClientOptions returns sensible default client options
func DefaultClientOptions() ClientOptions {
	return ClientOptions{
		Endpoint:       "localhost:50051",
		ConnectTimeout: time.Second * 5,
		RequestTimeout: time.Second * 10,
		TransportType:  "grpc",
		PoolSize:       5,
		TLSEnabled:     false,
		MaxRetries:     3,
		InitialBackoff: time.Millisecond * 100,
		MaxBackoff:     time.Second * 2,
		BackoffFactor:  1.5,
		RetryJitter:    0.2,
		Compression:    CompressionNone,
		MaxMessageSize: 16 * 1024 * 1024, // 16MB
	}
}

@@ -378,4 +378,4 @@ type Stats struct {
	SstableCount       int32
	WriteAmplification float64
	ReadAmplification  float64
}
@ -105,13 +105,13 @@ func TestClientConnect(t *testing.T) {
|
||||
// Modify default options to use mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
// Create a client with the mock transport
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
|
||||
@ -139,12 +139,12 @@ func TestClientGet(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -186,12 +186,12 @@ func TestClientPut(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -220,12 +220,12 @@ func TestClientDelete(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -254,12 +254,12 @@ func TestClientBatchWrite(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -295,12 +295,12 @@ func TestClientGetStats(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -317,12 +317,12 @@ func TestClientGetStats(t *testing.T) {
|
||||
"read_amplification": 2.0
|
||||
}`
|
||||
mock.setResponse(transport.TypeGetStats, []byte(statsJSON))
|
||||
|
||||
|
||||
stats, err := client.GetStats(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected successful get stats, got error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if stats.KeyCount != 1000 {
|
||||
t.Errorf("Expected KeyCount 1000, got %d", stats.KeyCount)
|
||||
}
|
||||
@ -354,12 +354,12 @@ func TestClientCompact(t *testing.T) {
|
||||
// Create a client with the mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock"
|
||||
|
||||
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Get the underlying mock client for test assertions
|
||||
mock := client.client.(*mockClient)
|
||||
mock.connected = true
|
||||
@ -386,7 +386,7 @@ func TestClientCompact(t *testing.T) {
|
||||
|
||||
func TestRetryWithBackoff(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
// Test successful retry
|
||||
attempts := 0
|
||||
err := RetryWithBackoff(
|
||||
@ -404,14 +404,14 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
2.0, // backoffFactor
|
||||
0.1, // jitter
|
||||
)
|
||||
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected successful retry, got error: %v", err)
|
||||
}
|
||||
if attempts != 3 {
|
||||
t.Errorf("Expected 3 attempts, got %d", attempts)
|
||||
}
|
||||
|
||||
|
||||
// Test max retries exceeded
|
||||
attempts = 0
|
||||
err = RetryWithBackoff(
|
||||
@ -426,14 +426,14 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
2.0, // backoffFactor
|
||||
0.1, // jitter
|
||||
)
|
||||
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error after max retries, got nil")
|
||||
}
|
||||
if attempts != 4 { // Initial + 3 retries
|
||||
t.Errorf("Expected 4 attempts, got %d", attempts)
|
||||
}
|
||||
|
||||
|
||||
// Test non-retryable error
|
||||
attempts = 0
|
||||
err = RetryWithBackoff(
|
||||
@ -448,14 +448,14 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
2.0, // backoffFactor
|
||||
0.1, // jitter
|
||||
)
|
||||
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected non-retryable error to be returned, got nil")
|
||||
}
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
|
||||
}
|
||||
|
||||
|
||||
// Test context cancellation
|
||||
attempts = 0
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
@ -463,7 +463,7 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
|
||||
err = RetryWithBackoff(
|
||||
cancelCtx,
|
||||
func() error {
|
||||
@ -476,8 +476,8 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
2.0, // backoffFactor
|
||||
0.1, // jitter
|
||||
)
|
||||
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("Expected context.Canceled error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -304,4 +304,4 @@ func (s *transactionScanIterator) Close() error {
|
||||
s.closed = true
|
||||
s.cancelFunc()
|
||||
return s.stream.Close()
|
||||
}
|
||||
}
|
||||
|
@ -7,33 +7,33 @@ import (
|
||||
|
||||
func TestDefaultClientOptions(t *testing.T) {
|
||||
options := DefaultClientOptions()
|
||||
|
||||
|
||||
// Verify the default options have sensible values
|
||||
if options.Endpoint != "localhost:50051" {
|
||||
t.Errorf("Expected default endpoint to be localhost:50051, got %s", options.Endpoint)
|
||||
}
|
||||
|
||||
|
||||
if options.ConnectTimeout != 5*time.Second {
|
||||
t.Errorf("Expected default connect timeout to be 5s, got %s", options.ConnectTimeout)
|
||||
}
|
||||
|
||||
|
||||
if options.RequestTimeout != 10*time.Second {
|
||||
t.Errorf("Expected default request timeout to be 10s, got %s", options.RequestTimeout)
|
||||
}
|
||||
|
||||
|
||||
if options.TransportType != "grpc" {
|
||||
t.Errorf("Expected default transport type to be grpc, got %s", options.TransportType)
|
||||
}
|
||||
|
||||
|
||||
if options.PoolSize != 5 {
|
||||
t.Errorf("Expected default pool size to be 5, got %d", options.PoolSize)
|
||||
}
|
||||
|
||||
|
||||
if options.TLSEnabled != false {
|
||||
t.Errorf("Expected default TLS enabled to be false")
|
||||
}
|
||||
|
||||
|
||||
if options.MaxRetries != 3 {
|
||||
t.Errorf("Expected default max retries to be 3, got %d", options.MaxRetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -17,19 +17,19 @@ func mockClientFactory(endpoint string, options transport.TransportOptions) (tra
|
||||
func TestClientCreation(t *testing.T) {
|
||||
// First, register our mock transport
|
||||
transport.RegisterClientTransport("mock_test", mockClientFactory)
|
||||
|
||||
|
||||
// Create client options using our mock transport
|
||||
options := DefaultClientOptions()
|
||||
options.TransportType = "mock_test"
|
||||
|
||||
|
||||
// Create a client
|
||||
client, err := NewClient(options)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Verify the client was created
|
||||
if client == nil {
|
||||
t.Fatal("Client is nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -12,11 +12,11 @@ import (

// Transaction represents a database transaction
type Transaction struct {
	client   *Client
	id       string
	readOnly bool
	closed   bool
	mu       sync.RWMutex
}

// ErrTransactionClosed is returned when attempting to use a closed transaction

@@ -285,4 +285,4 @@ func (tx *Transaction) Delete(ctx context.Context, key []byte) (bool, error) {
	}

	return deleteResp.Success, nil
}
@@ -117,4 +117,4 @@ func CalculateExponentialBackoff(
	}

	return backoff
}
@@ -130,14 +130,14 @@ func (h *HierarchicalIterator) Seek(target []byte) bool {
		if !iter.Valid() {
			continue
		}

		// If a newer iterator has the same key, use its value
		if bytes.Equal(iter.Key(), bestKey) {
			bestValue = iter.Value()
			break // Since iterators are in newest-to-oldest order, we can stop at the first match
		}
	}

	// Set the found key/value
	h.key = bestKey
	h.value = bestValue

@@ -253,7 +253,7 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
		// Get the current key
		key := iter.Key()

		// If we haven't found a valid key yet, or this key is smaller than the current best key
		if bestIterIdx == -1 || bytes.Compare(key, bestKey) < 0 {
			// This becomes our best candidate so far

@@ -271,14 +271,14 @@
		if !iter.Valid() {
			continue
		}

		// If a newer iterator has the same key, use its value
		if bytes.Equal(iter.Key(), bestKey) {
			bestValue = iter.Value()
			break // Since iterators are in newest-to-oldest order, we can stop at the first match
		}
	}

	// Set the found key/value
	h.key = bestKey
	h.value = bestValue
@ -62,7 +62,7 @@ type EngineStats struct {
|
||||
TxAborted atomic.Uint64
|
||||
|
||||
// Recovery stats
|
||||
WALFilesRecovered atomic.Uint64
|
||||
WALFilesRecovered atomic.Uint64
|
||||
WALEntriesRecovered atomic.Uint64
|
||||
WALCorruptedEntries atomic.Uint64
|
||||
WALRecoveryDuration atomic.Int64 // nanoseconds
|
||||
@ -524,15 +524,15 @@ func (e *Engine) flushMemTable(mem *memtable.MemTable) error {
|
||||
for iter.SeekToFirst(); iter.Valid(); iter.Next() {
|
||||
key := iter.Key()
|
||||
keyStr := string(key) // Use as map key
|
||||
|
||||
|
||||
// Skip keys we've already processed (including tombstones)
|
||||
if _, seen := processedKeys[keyStr]; seen {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Mark this key as processed regardless of whether it's a value or tombstone
|
||||
processedKeys[keyStr] = struct{}{}
|
||||
|
||||
|
||||
// Only write non-tombstone entries to the SSTable
|
||||
if value := iter.Value(); value != nil {
|
||||
bytesWritten += uint64(len(key) + len(value))
|
||||
@ -673,7 +673,7 @@ func (e *Engine) loadSSTables() error {
|
||||
// recoverFromWAL recovers memtables from existing WAL files
|
||||
func (e *Engine) recoverFromWAL() error {
|
||||
startTime := time.Now()
|
||||
|
||||
|
||||
// Check if WAL directory exists
|
||||
if _, err := os.Stat(e.walDir); os.IsNotExist(err) {
|
||||
return nil // No WAL directory, nothing to recover
|
||||
@ -685,7 +685,7 @@ func (e *Engine) recoverFromWAL() error {
|
||||
e.stats.ReadErrors.Add(1)
|
||||
return fmt.Errorf("error listing WAL files: %w", err)
|
||||
}
|
||||
|
||||
|
||||
if len(walFiles) > 0 {
|
||||
e.stats.WALFilesRecovered.Add(uint64(len(walFiles)))
|
||||
}
|
||||
@ -698,7 +698,7 @@ func (e *Engine) recoverFromWAL() error {
|
||||
if err != nil {
|
||||
// If recovery fails, let's try cleaning up WAL files
|
||||
e.stats.ReadErrors.Add(1)
|
||||
|
||||
|
||||
// Create a backup directory
|
||||
backupDir := filepath.Join(e.walDir, "backup_"+time.Now().Format("20060102_150405"))
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
@ -724,14 +724,14 @@ func (e *Engine) recoverFromWAL() error {
|
||||
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// Update recovery statistics based on actual entries recovered
|
||||
if len(walFiles) > 0 {
|
||||
// Use WALDir function directly to get stats
|
||||
recoveryStats, statErr := wal.ReplayWALDir(e.cfg.WALDir, func(entry *wal.Entry) error {
|
||||
return nil // Just counting, not processing
|
||||
})
|
||||
|
||||
|
||||
if statErr == nil && recoveryStats != nil {
|
||||
e.stats.WALEntriesRecovered.Add(recoveryStats.EntriesProcessed)
|
||||
e.stats.WALCorruptedEntries.Add(recoveryStats.EntriesSkipped)
|
||||
@ -767,7 +767,7 @@ func (e *Engine) recoverFromWAL() error {
|
||||
|
||||
// Record recovery stats
|
||||
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -933,7 +933,7 @@ func (e *Engine) GetStats() map[string]interface{} {
|
||||
// Add error statistics
|
||||
stats["read_errors"] = e.stats.ReadErrors.Load()
|
||||
stats["write_errors"] = e.stats.WriteErrors.Load()
|
||||
|
||||
|
||||
// Add WAL recovery statistics
|
||||
stats["wal_files_recovered"] = e.stats.WALFilesRecovered.Load()
|
||||
stats["wal_entries_recovered"] = e.stats.WALEntriesRecovered.Load()
|
||||
|
@ -81,8 +81,8 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
|
||||
// Simulate exactly the bug scenario from the CLI
|
||||
// Add the same key multiple times with different values
|
||||
key := []byte("foo")
|
||||
|
||||
// First add
|
||||
|
||||
// First add
|
||||
if err := engine.Put(key, []byte("23")); err != nil {
|
||||
t.Fatalf("Failed to put first value: %v", err)
|
||||
}
|
||||
@ -91,17 +91,17 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
|
||||
if err := engine.Delete(key); err != nil {
|
||||
t.Fatalf("Failed to delete key: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Add it again with different value
|
||||
if err := engine.Put(key, []byte("42")); err != nil {
|
||||
t.Fatalf("Failed to re-add key: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Add another key
|
||||
if err := engine.Put([]byte("bar"), []byte("23")); err != nil {
|
||||
t.Fatalf("Failed to add another key: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Add another key
|
||||
if err := engine.Put([]byte("user:1"), []byte(`{"name":"John"}`)); err != nil {
|
||||
t.Fatalf("Failed to add another key: %v", err)
|
||||
@ -130,7 +130,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
|
||||
if !bytes.Equal(value, []byte("42")) {
|
||||
t.Errorf("Got incorrect value after flush. Expected: %s, Got: %s", "42", string(value))
|
||||
}
|
||||
|
||||
|
||||
value, err = engine.Get([]byte("bar"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get 'bar' after flush: %v", err)
|
||||
@ -138,7 +138,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
|
||||
if !bytes.Equal(value, []byte("23")) {
|
||||
t.Errorf("Got incorrect value for 'bar' after flush. Expected: %s, Got: %s", "23", string(value))
|
||||
}
|
||||
|
||||
|
||||
value, err = engine.Get([]byte("user:1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get 'user:1' after flush: %v", err)
|
||||
@ -154,7 +154,7 @@ func TestEngine_DuplicateKeysFlush(t *testing.T) {
|
||||
|
||||
// Test with a key that will be deleted and re-added multiple times
|
||||
key := []byte("foo")
|
||||
|
||||
|
||||
// Add the key
|
||||
if err := engine.Put(key, []byte("42")); err != nil {
|
||||
t.Fatalf("Failed to put initial value: %v", err)
|
||||
|
@ -442,13 +442,13 @@ func (c *chainedIterator) SeekToFirst() {
|
||||
|
||||
// Find the iterator with the smallest key from the newest source
|
||||
c.current = -1
|
||||
|
||||
|
||||
// Find the smallest valid key
|
||||
for i, iter := range c.iterators {
|
||||
if !iter.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// If we haven't found a key yet, or this key is smaller than the current smallest
|
||||
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
|
||||
c.current = i
|
||||
@ -499,13 +499,13 @@ func (c *chainedIterator) Seek(target []byte) bool {
|
||||
|
||||
// Find the iterator with the smallest key from the newest source
|
||||
c.current = -1
|
||||
|
||||
|
||||
// Find the smallest valid key
|
||||
for i, iter := range c.iterators {
|
||||
if !iter.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// If we haven't found a key yet, or this key is smaller than the current smallest
|
||||
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
|
||||
c.current = i
|
||||
@ -537,18 +537,18 @@ func (c *chainedIterator) Next() bool {
|
||||
|
||||
// Find the iterator with the smallest key from the newest source
|
||||
c.current = -1
|
||||
|
||||
|
||||
// Find the smallest valid key that is greater than the current key
|
||||
for i, iter := range c.iterators {
|
||||
if !iter.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Skip if the key is the same as the current key (we've already advanced past it)
|
||||
if bytes.Equal(iter.Key(), currentKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// If we haven't found a key yet, or this key is smaller than the current smallest
|
||||
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
|
||||
c.current = i
|
||||
|
@ -109,7 +109,7 @@ func (s *KevoServiceServer) BatchWrite(ctx context.Context, req *pb.BatchWriteRe
|
||||
if err != nil {
|
||||
return &pb.BatchWriteResponse{Success: false}, fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Ensure we either commit or rollback
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@ -182,7 +182,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
|
||||
count := int32(0)
|
||||
// Position iterator at the first entry
|
||||
iter.SeekToFirst()
|
||||
|
||||
|
||||
// Iterate through all valid entries
|
||||
for iter.Valid() {
|
||||
if limit > 0 && count >= limit {
|
||||
@ -199,7 +199,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
|
||||
// Move to the next entry
|
||||
iter.Next()
|
||||
}
|
||||
@ -225,8 +225,8 @@ func (pi *prefixIterator) Next() bool {
|
||||
for pi.iter.Next() {
|
||||
// Check if current key has the prefix
|
||||
key := pi.iter.Key()
|
||||
if len(key) >= len(pi.prefix) &&
|
||||
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
|
||||
if len(key) >= len(pi.prefix) &&
|
||||
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -415,7 +415,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
|
||||
count := int32(0)
|
||||
// Position iterator at the first entry
|
||||
iter.SeekToFirst()
|
||||
|
||||
|
||||
// Iterate through all valid entries
|
||||
for iter.Valid() {
|
||||
if limit > 0 && count >= limit {
|
||||
@ -432,7 +432,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
|
||||
// Move to the next entry
|
||||
iter.Next()
|
||||
}
|
||||
@ -446,17 +446,17 @@ func (s *KevoServiceServer) GetStats(ctx context.Context, req *pb.GetStatsReques
|
||||
keyCount := int64(0)
|
||||
sstableCount := int32(0)
|
||||
memtableCount := int32(1) // At least 1 active memtable
|
||||
|
||||
|
||||
// Create a read-only transaction to count keys
|
||||
tx, err := s.engine.BeginTransaction(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to begin transaction for stats: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
|
||||
// Use an iterator to count keys
|
||||
iter := tx.NewIterator()
|
||||
|
||||
|
||||
// Count keys and estimate size
|
||||
var totalSize int64
|
||||
for iter.Next() {
|
||||
@ -492,7 +492,7 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
|
||||
if err != nil {
|
||||
return &pb.CompactResponse{Success: false}, err
|
||||
}
|
||||
|
||||
|
||||
// Do a dummy write to force a flush
|
||||
if req.Force {
|
||||
err = tx.Put([]byte("__compact_marker__"), []byte("force"))
|
||||
@ -501,11 +501,11 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
|
||||
return &pb.CompactResponse{Success: false}, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return &pb.CompactResponse{Success: false}, err
|
||||
}
|
||||
|
||||
return &pb.CompactResponse{Success: true}, nil
|
||||
}
|
||||
}
|
||||
|
@@ -56,4 +56,4 @@ func sortDurations(durations []time.Duration) {
			}
		}
	}
}
@ -9,8 +9,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/KevoDB/kevo/proto/kevo"
|
||||
"github.com/KevoDB/kevo/pkg/transport"
|
||||
pb "github.com/KevoDB/kevo/proto/kevo"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@ -19,13 +19,13 @@ import (
|
||||
|
||||
// GRPCClient implements the transport.Client interface for gRPC
|
||||
type GRPCClient struct {
|
||||
endpoint string
|
||||
options transport.TransportOptions
|
||||
conn *grpc.ClientConn
|
||||
client pb.KevoServiceClient
|
||||
status transport.TransportStatus
|
||||
statusMu sync.RWMutex
|
||||
metrics transport.MetricsCollector
|
||||
endpoint string
|
||||
options transport.TransportOptions
|
||||
conn *grpc.ClientConn
|
||||
client pb.KevoServiceClient
|
||||
status transport.TransportStatus
|
||||
statusMu sync.RWMutex
|
||||
metrics transport.MetricsCollector
|
||||
}
|
||||
|
||||
// NewGRPCClient creates a new gRPC client
|
||||
@ -123,10 +123,10 @@ func (c *GRPCClient) Status() transport.TransportStatus {
|
||||
func (c *GRPCClient) setStatus(connected bool, err error) {
|
||||
c.statusMu.Lock()
|
||||
defer c.statusMu.Unlock()
|
||||
|
||||
|
||||
c.status.Connected = connected
|
||||
c.status.LastError = err
|
||||
|
||||
|
||||
if connected {
|
||||
c.status.LastConnected = time.Now()
|
||||
}
|
||||
@ -141,11 +141,11 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
|
||||
// Record request metrics
|
||||
startTime := time.Now()
|
||||
requestType := request.Type()
|
||||
|
||||
|
||||
// Record bytes sent
|
||||
requestPayload := request.Payload()
|
||||
c.metrics.RecordSend(len(requestPayload))
|
||||
|
||||
|
||||
var resp transport.Response
|
||||
var err error
|
||||
|
||||
@ -182,12 +182,12 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
|
||||
|
||||
// Record metrics for the request
|
||||
c.metrics.RecordRequest(requestType, startTime, err)
|
||||
|
||||
|
||||
// If we got a response, record received bytes
|
||||
if resp != nil {
|
||||
c.metrics.RecordReceive(len(resp.Payload()))
|
||||
}
|
||||
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@ -206,20 +206,20 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
|
||||
var req struct {
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid get request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.GetRequest{
|
||||
Key: req.Key,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.Get(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Value []byte `json:"value"`
|
||||
Found bool `json:"found"`
|
||||
@ -227,12 +227,12 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
|
||||
Value: grpcResp.Value,
|
||||
Found: grpcResp.Found,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeGet, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -242,33 +242,33 @@ func (c *GRPCClient) handlePut(ctx context.Context, payload []byte) (transport.R
|
||||
Value []byte `json:"value"`
|
||||
Sync bool `json:"sync"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid put request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.PutRequest{
|
||||
Key: req.Key,
|
||||
Value: req.Value,
|
||||
Sync: req.Sync,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.Put(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypePut, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -277,32 +277,32 @@ func (c *GRPCClient) handleDelete(ctx context.Context, payload []byte) (transpor
|
||||
Key []byte `json:"key"`
|
||||
Sync bool `json:"sync"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid delete request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.DeleteRequest{
|
||||
Key: req.Key,
|
||||
Sync: req.Sync,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.Delete(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeDelete, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -315,18 +315,18 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
|
||||
} `json:"operations"`
|
||||
Sync bool `json:"sync"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid batch write request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
operations := make([]*pb.Operation, len(req.Operations))
|
||||
for i, op := range req.Operations {
|
||||
pbOp := &pb.Operation{
|
||||
Key: op.Key,
|
||||
Value: op.Value,
|
||||
}
|
||||
|
||||
|
||||
switch op.Type {
|
||||
case "put":
|
||||
pbOp.Type = pb.Operation_PUT
|
||||
@ -335,31 +335,31 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
|
||||
default:
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid operation type: %s", op.Type)), fmt.Errorf("invalid operation type: %s", op.Type)
|
||||
}
|
||||
|
||||
|
||||
operations[i] = pbOp
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.BatchWriteRequest{
|
||||
Operations: operations,
|
||||
Sync: req.Sync,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.BatchWrite(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeBatchWrite, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -367,31 +367,31 @@ func (c *GRPCClient) handleBeginTransaction(ctx context.Context, payload []byte)
|
||||
var req struct {
|
||||
ReadOnly bool `json:"read_only"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid begin transaction request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.BeginTransactionRequest{
|
||||
ReadOnly: req.ReadOnly,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.BeginTransaction(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
}{
|
||||
TransactionID: grpcResp.TransactionId,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeBeginTx, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -399,31 +399,31 @@ func (c *GRPCClient) handleCommitTransaction(ctx context.Context, payload []byte
|
||||
var req struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid commit transaction request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.CommitTransactionRequest{
|
||||
TransactionId: req.TransactionID,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.CommitTransaction(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeCommitTx, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -431,31 +431,31 @@ func (c *GRPCClient) handleRollbackTransaction(ctx context.Context, payload []by
|
||||
var req struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid rollback transaction request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.RollbackTransactionRequest{
|
||||
TransactionId: req.TransactionID,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.RollbackTransaction(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeRollbackTx, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -464,21 +464,21 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
|
||||
TransactionID string `json:"transaction_id"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid tx get request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.TxGetRequest{
|
||||
TransactionId: req.TransactionID,
|
||||
Key: req.Key,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.TxGet(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Value []byte `json:"value"`
|
||||
Found bool `json:"found"`
|
||||
@ -486,12 +486,12 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
|
||||
Value: grpcResp.Value,
|
||||
Found: grpcResp.Found,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeTxGet, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -501,33 +501,33 @@ func (c *GRPCClient) handleTxPut(ctx context.Context, payload []byte) (transport
|
||||
Key []byte `json:"key"`
|
||||
Value []byte `json:"value"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid tx put request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.TxPutRequest{
|
||||
TransactionId: req.TransactionID,
|
||||
Key: req.Key,
|
||||
Value: req.Value,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.TxPut(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeTxPut, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -536,43 +536,43 @@ func (c *GRPCClient) handleTxDelete(ctx context.Context, payload []byte) (transp
|
||||
TransactionID string `json:"transaction_id"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid tx delete request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.TxDeleteRequest{
|
||||
TransactionId: req.TransactionID,
|
||||
Key: req.Key,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.TxDelete(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeTxDelete, respData, nil), nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transport.Response, error) {
|
||||
grpcReq := &pb.GetStatsRequest{}
|
||||
|
||||
|
||||
grpcResp, err := c.client.GetStats(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
KeyCount int64 `json:"key_count"`
|
||||
StorageSize int64 `json:"storage_size"`
|
||||
@ -588,12 +588,12 @@ func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transp
|
||||
WriteAmplification: grpcResp.WriteAmplification,
|
||||
ReadAmplification: grpcResp.ReadAmplification,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeGetStats, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -601,31 +601,31 @@ func (c *GRPCClient) handleCompact(ctx context.Context, payload []byte) (transpo
|
||||
var req struct {
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
|
||||
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return transport.NewErrorResponse(fmt.Errorf("invalid compact request payload: %w", err)), err
|
||||
}
|
||||
|
||||
|
||||
grpcReq := &pb.CompactRequest{
|
||||
Force: req.Force,
|
||||
}
|
||||
|
||||
|
||||
grpcResp, err := c.client.Compact(ctx, grpcReq)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
resp := struct {
|
||||
Success bool `json:"success"`
|
||||
}{
|
||||
Success: grpcResp.Success,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
return transport.NewResponse(transport.TypeCompact, respData, nil), nil
|
||||
}
|
||||
|
||||
@ -650,7 +650,7 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
|
||||
}
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
// Build response based on scan type
|
||||
scanResp := struct {
|
||||
Key []byte `json:"key"`
|
||||
@ -659,12 +659,12 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
|
||||
Key: resp.Key,
|
||||
Value: resp.Value,
|
||||
}
|
||||
|
||||
|
||||
respData, err := json.Marshal(scanResp)
|
||||
if err != nil {
|
||||
return transport.NewErrorResponse(err), err
|
||||
}
|
||||
|
||||
|
||||
s.client.metrics.RecordReceive(len(respData))
|
||||
return transport.NewResponse(s.streamType, respData, nil), nil
|
||||
}
|
||||
@ -672,4 +672,4 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
|
||||
func (s *GRPCScanStream) Close() error {
|
||||
s.cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/KevoDB/kevo/proto/kevo"
|
||||
"github.com/KevoDB/kevo/pkg/transport"
|
||||
pb "github.com/KevoDB/kevo/proto/kevo"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@ -72,7 +72,7 @@ func (g *GRPCTransportManager) Serve() error {
|
||||
if err := g.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
// Block until server is stopped
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
@ -284,4 +284,4 @@ func init() {
|
||||
transport.RegisterClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.Client, error) {
|
||||
return NewGRPCClient(endpoint, options)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -7,15 +7,15 @@ import (
// Simple smoke test for the gRPC transport
func TestNewGRPCTransportManager(t *testing.T) {
opts := DefaultGRPCTransportOptions()

// Override the listen address to avoid port conflicts
opts.ListenAddr = ":0" // use random available port

manager, err := NewGRPCTransportManager(opts)
if err != nil {
t.Fatalf("Failed to create transport manager: %v", err)
}

// Verify the manager was created
if manager == nil {
t.Fatal("Expected non-nil manager")
@ -60,4 +60,4 @@ func TestLoadClientTLSConfigFromStruct(t *testing.T) {
if !config.InsecureSkipVerify {
t.Fatal("Expected InsecureSkipVerify to be true")
}
}
}

@ -207,4 +207,4 @@ func (m *ConnectionPoolManager) CloseAll() {
m.pools.Delete(key)
return true
})
}
}

@ -7,8 +7,8 @@ import (
"sync"
"time"

pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@ -16,30 +16,30 @@ import (

// GRPCServer implements the transport.Server interface for gRPC
type GRPCServer struct {
address string
tlsConfig *tls.Config
server *grpc.Server
address string
tlsConfig *tls.Config
server *grpc.Server
requestHandler transport.RequestHandler
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
}

// NewGRPCServer creates a new gRPC server
func NewGRPCServer(address string, options transport.TransportOptions) (transport.Server, error) {
// Create server options
var serverOpts []grpc.ServerOption

// Configure TLS if enabled
if options.TLSEnabled {
tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}

serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}

// Configure keepalive parameters
kaProps := keepalive.ServerParameters{
MaxConnectionIdle: 30 * time.Minute,
@ -47,20 +47,20 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
Time: 15 * time.Second,
Timeout: 5 * time.Second,
}

kaPolicy := keepalive.EnforcementPolicy{
MinTime: 10 * time.Second,
PermitWithoutStream: true,
}

serverOpts = append(serverOpts,
grpc.KeepaliveParams(kaProps),
grpc.KeepaliveEnforcementPolicy(kaPolicy),
)

// Create the server
server := grpc.NewServer(serverOpts...)

return &GRPCServer{
address: address,
server: server,
@ -72,18 +72,18 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
func (s *GRPCServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()

if s.started {
return fmt.Errorf("server already started")
}

// Start the server in a goroutine
go func() {
if err := s.Serve(); err != nil {
fmt.Printf("gRPC server error: %v\n", err)
}
}()

s.started = true
return nil
}
@ -93,31 +93,31 @@ func (s *GRPCServer) Serve() error {
if s.requestHandler == nil {
return fmt.Errorf("no request handler set")
}

// Create the service implementation
service := &kevoServiceServer{
handler: s.requestHandler,
}

// Register the service
pb.RegisterKevoServiceServer(s.server, service)

// Start listening
listener, err := transport.CreateListener("tcp", s.address, s.tlsConfig)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.address, err)
}

s.metrics.ServerStarted()

// Serve requests
err = s.server.Serve(listener)

if err != nil {
s.metrics.ServerErrored()
return fmt.Errorf("failed to serve: %w", err)
}

s.metrics.ServerStopped()
return nil
}
@ -126,14 +126,14 @@ func (s *GRPCServer) Serve() error {
func (s *GRPCServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

if !s.started {
return nil
}

s.server.GracefulStop()
s.started = false

return nil
}

@ -141,7 +141,7 @@ func (s *GRPCServer) Stop(ctx context.Context) error {
func (s *GRPCServer) SetRequestHandler(handler transport.RequestHandler) {
s.mu.Lock()
defer s.mu.Unlock()

s.requestHandler = handler
}

@ -151,4 +151,4 @@ type kevoServiceServer struct {
handler transport.RequestHandler
}

// TODO: Implement service methods
// TODO: Implement service methods

@ -92,4 +92,4 @@ func LoadClientTLSConfigFromStruct(config *TLSConfig) (*tls.Config, error) {
return &tls.Config{MinVersion: tls.VersionTLS12}, nil
}
return LoadClientTLSConfig(config.CertFile, config.KeyFile, config.CAFile, config.SkipVerify)
}
}

@ -5,8 +5,8 @@ import (
"sync"
"time"

pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
)

@ -30,17 +30,17 @@ func (c *GRPCConnection) Execute(fn func(interface{}) error) error {

// Create a new client from the connection
client := pb.NewKevoServiceClient(c.conn)

// Execute the provided function with the client
err := fn(client)

// Update metrics if there was an error
if err != nil {
c.mu.Lock()
c.errCount++
c.mu.Unlock()
}

return err
}

@ -58,10 +58,10 @@ func (c *GRPCConnection) Address() string {
func (c *GRPCConnection) Status() transport.ConnectionStatus {
c.mu.RLock()
defer c.mu.RUnlock()

// Check the connection state
isConnected := c.conn != nil

return transport.ConnectionStatus{
Connected: isConnected,
LastActivity: c.lastUsed,
@ -81,4 +81,4 @@ type GRPCTransportOptions struct {
MaxConnectionIdle time.Duration
MaxConnectionAge time.Duration
MaxPoolConnections int
}
}

@ -6,21 +6,21 @@ import (

// Standard request/response type constants
const (
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
)

// Common errors
@ -97,4 +97,4 @@ func NewErrorResponse(err error) Response {
ResponseData: msg,
ResponseErr: err,
}
}
}

@ -9,12 +9,12 @@ func TestBasicRequest(t *testing.T) {
// Test creating a request
payload := []byte("test payload")
req := NewRequest(TypeGet, payload)

// Test Type method
if req.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
}

// Test Payload method
if string(req.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
@ -25,26 +25,26 @@ func TestBasicResponse(t *testing.T) {
// Test creating a response with no error
payload := []byte("test response")
resp := NewResponse(TypeGet, payload, nil)

// Test Type method
if resp.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
}

// Test Payload method
if string(resp.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
}

// Test Error method
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}

// Test creating a response with an error
testErr := errors.New("test error")
resp = NewResponse(TypeGet, payload, testErr)

if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
@ -54,34 +54,34 @@ func TestNewErrorResponse(t *testing.T) {
// Test creating an error response
testErr := errors.New("test error")
resp := NewErrorResponse(testErr)

// Test Type method
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}

// Test Payload method - should contain error message
if string(resp.Payload()) != testErr.Error() {
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
}

// Test Error method
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}

// Test with nil error
resp = NewErrorResponse(nil)

if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}

if len(resp.Payload()) != 0 {
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
}

if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
}
}

@ -50,7 +50,7 @@ type TransportStatus struct {
type Request interface {
// Type returns the type of request
Type() string

// Payload returns the payload of the request
Payload() []byte
}

@ -59,10 +59,10 @@ type Request interface {
type Response interface {
// Type returns the type of response
Type() string

// Payload returns the payload of the response
Payload() []byte

// Error returns any error associated with the response
Error() error
}

@ -71,10 +71,10 @@ type Response interface {
type Stream interface {
// Send sends a request over the stream
Send(request Request) error

// Recv receives a response from the stream
Recv() (Response, error)

// Close closes the stream
Close() error
}

@ -83,19 +83,19 @@ type Stream interface {
type Client interface {
// Connect establishes a connection to the server
Connect(ctx context.Context) error

// Close closes the connection
Close() error

// IsConnected returns whether the client is connected
IsConnected() bool

// Status returns the current status of the connection
Status() TransportStatus

// Send sends a request and waits for a response
Send(ctx context.Context, request Request) (Response, error)

// Stream opens a bidirectional stream
Stream(ctx context.Context) (Stream, error)
}

@ -104,7 +104,7 @@ type Client interface {
type RequestHandler interface {
// HandleRequest processes a request and returns a response
HandleRequest(ctx context.Context, request Request) (Response, error)

// HandleStream processes a bidirectional stream
HandleStream(stream Stream) error
}

@ -113,13 +113,13 @@ type RequestHandler interface {
type Server interface {
// Start starts the server and returns immediately
Start() error

// Serve starts the server and blocks until it's stopped
Serve() error

// Stop stops the server gracefully
Stop(ctx context.Context) error

// SetRequestHandler sets the handler for incoming requests
SetRequestHandler(handler RequestHandler)
}

@ -134,16 +134,16 @@ type ServerFactory func(address string, options TransportOptions) (Server, error
type Registry interface {
// RegisterClient adds a new client implementation to the registry
RegisterClient(name string, factory ClientFactory)

// RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory)

// CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error)

// CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error)

// ListTransports returns all available transport names
ListTransports() []string
}
}
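
The interfaces above are easiest to read from a caller's point of view. The sketch below is not part of the commit; it only illustrates how a Client obtained from the default registry is meant to be driven (Connect, Send, Close), assuming DefaultRegistry is exported as shown, that the gRPC transport package has been imported somewhere so its init() registration of "grpc" has run, and using a hypothetical endpoint and payload.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/KevoDB/kevo/pkg/transport"
	// the gRPC transport package must also be imported (blank import) so its
	// init() registers the "grpc" client factory; its path is not shown here
)

func main() {
	// Options mirror the TransportOptions used in the registry tests.
	client, err := transport.DefaultRegistry.CreateClient("grpc", "localhost:50051",
		transport.TransportOptions{Timeout: 5 * time.Second})
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	if err := client.Connect(ctx); err != nil {
		panic(err)
	}
	defer client.Close()

	// Send a typed request and inspect the response; the payload format is
	// request-type specific and the bytes here are purely illustrative.
	resp, err := client.Send(ctx, transport.NewRequest(transport.TypeGet, []byte("example payload")))
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Type(), len(resp.Payload()), resp.Error())
}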
@ -10,16 +10,16 @@ import (
type MetricsCollector interface {
// RecordRequest records metrics for a request
RecordRequest(requestType string, startTime time.Time, err error)

// RecordSend records metrics for bytes sent
RecordSend(bytes int)

// RecordReceive records metrics for bytes received
RecordReceive(bytes int)

// RecordConnection records a connection event
RecordConnection(successful bool)

// GetMetrics returns the current metrics
GetMetrics() Metrics
}

@ -46,7 +46,7 @@ type BasicMetricsCollector struct {
bytesReceived uint64
connections uint64
connectionFailures uint64

// Track average latency and count for each request type
avgLatencyByType map[string]time.Duration
requestCountByType map[string]uint64
@ -63,26 +63,26 @@ func NewMetricsCollector() MetricsCollector {
// RecordRequest records metrics for a request
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
atomic.AddUint64(&c.totalRequests, 1)

if err == nil {
atomic.AddUint64(&c.successfulRequests, 1)
} else {
atomic.AddUint64(&c.failedRequests, 1)
}

// Update average latency for request type
latency := time.Since(startTime)

c.mu.Lock()
defer c.mu.Unlock()

currentAvg, exists := c.avgLatencyByType[requestType]
currentCount, _ := c.requestCountByType[requestType]

if exists {
// Update running average - the common case for better branch prediction
// new_avg = (old_avg * count + new_value) / (count + 1)
totalDuration := currentAvg * time.Duration(currentCount) + latency
totalDuration := currentAvg*time.Duration(currentCount) + latency
newCount := currentCount + 1
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
c.requestCountByType[requestType] = newCount
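
The incremental update in this hunk can be checked in isolation. The snippet below is not part of the commit; it is a self-contained demonstration that folding one sample at a time with new_avg = (old_avg*count + new_value) / (count + 1) reproduces the plain mean, up to integer-division truncation.

package main

import (
	"fmt"
	"time"
)

func main() {
	samples := []time.Duration{300 * time.Millisecond, 500 * time.Millisecond, 100 * time.Millisecond}

	var avg time.Duration
	var count uint64
	for _, latency := range samples {
		// old_avg * count + new_value, then divide by the new count
		total := avg*time.Duration(count) + latency
		count++
		avg = total / time.Duration(count)
	}

	fmt.Println(avg) // 300ms, the same as (300 + 500 + 100) / 3
}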
@ -116,13 +116,13 @@ func (c *BasicMetricsCollector) RecordConnection(successful bool) {
func (c *BasicMetricsCollector) GetMetrics() Metrics {
c.mu.RLock()
defer c.mu.RUnlock()

// Create a copy of the average latency map
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
for k, v := range c.avgLatencyByType {
avgLatencyByType[k] = v
}

return Metrics{
TotalRequests: atomic.LoadUint64(&c.totalRequests),
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
@ -133,4 +133,4 @@ func (c *BasicMetricsCollector) GetMetrics() Metrics {
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
AvgLatencyByType: avgLatencyByType,
}
}
}
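
For context, this is roughly how the collector is used from calling code; the sketch below is not part of the commit and mirrors what the test file later in this diff exercises, assuming the constructor and Metrics fields are exported from the transport package as shown.

package main

import (
	"fmt"
	"time"

	"github.com/KevoDB/kevo/pkg/transport"
)

func main() {
	collector := transport.NewMetricsCollector()

	start := time.Now()
	// ... perform the actual request here ...
	collector.RecordRequest("get", start, nil) // a nil error counts as a success
	collector.RecordSend(128)
	collector.RecordReceive(512)
	collector.RecordConnection(true)

	m := collector.GetMetrics()
	fmt.Println(m.TotalRequests, m.BytesSent, m.BytesReceived, m.AvgLatencyByType["get"])
}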
@ -9,9 +9,9 @@ import (
// Metrics struct extensions for server metrics
type ServerMetrics struct {
Metrics
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
}

// Connection represents a connection to a remote endpoint
@ -31,11 +31,11 @@ type Connection interface {

// ConnectionStatus represents the status of a connection
type ConnectionStatus struct {
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
}

// TransportManager is an interface for managing transport layer operations

@ -8,7 +8,7 @@ import (

func TestBasicMetricsCollector(t *testing.T) {
collector := NewMetricsCollector()

// Test initial state
metrics := collector.GetMetrics()
if metrics.TotalRequests != 0 ||
@ -21,11 +21,11 @@ func TestBasicMetricsCollector(t *testing.T) {
len(metrics.AvgLatencyByType) != 0 {
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
}

// Test recording successful request
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
collector.RecordRequest("get", startTime, nil)

metrics = collector.GetMetrics()
if metrics.TotalRequests != 1 {
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
@ -36,18 +36,18 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 0 {
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
}

// Check average latency
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
t.Error("Expected 'get' latency to exist")
} else if avgLatency < 100*time.Millisecond {
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
}

// Test recording failed request
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
collector.RecordRequest("get", startTime, errors.New("test error"))

metrics = collector.GetMetrics()
if metrics.TotalRequests != 2 {
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
@ -58,26 +58,26 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 1 {
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
}

// Test average latency calculation for multiple requests
startTime = time.Now().Add(-300 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)

startTime = time.Now().Add(-500 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)

metrics = collector.GetMetrics()
avgPutLatency := metrics.AvgLatencyByType["put"]

// Expected avg is around (300ms + 500ms) / 2 = 400ms
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
}

// Test byte tracking
collector.RecordSend(1000)
collector.RecordReceive(2000)

metrics = collector.GetMetrics()
if metrics.BytesSent != 1000 {
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
@ -85,12 +85,12 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.BytesReceived != 2000 {
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
}

// Test connection tracking
collector.RecordConnection(true)
collector.RecordConnection(false)
collector.RecordConnection(true)

metrics = collector.GetMetrics()
if metrics.Connections != 2 {
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
@ -98,4 +98,4 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.ConnectionFailures != 1 {
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
}
}
}

@ -12,11 +12,11 @@ func CreateListener(network, address string, tlsConfig *tls.Config) (net.Listene
if err != nil {
return nil, err
}

// If TLS is configured, wrap the listener
if tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}

return listener, nil
}
}

@ -7,7 +7,7 @@ import (

// registry implements the Registry interface
type registry struct {
mu sync.RWMutex
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
}
@ -111,4 +111,4 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
// AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string {
return DefaultRegistry.ListTransports()
}
}

@ -97,17 +97,17 @@ func mockServerFactory(address string, options TransportOptions) (Server, error)
// TestRegistry tests the transport registry
func TestRegistry(t *testing.T) {
registry := NewRegistry()

// Register transports
registry.RegisterClient("mock", mockClientFactory)
registry.RegisterServer("mock", mockServerFactory)

// Test listing transports
transports := registry.ListTransports()
if len(transports) != 1 || transports[0] != "mock" {
t.Errorf("Expected [mock], got %v", transports)
}

// Test creating client
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -115,21 +115,21 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

// Test client methods
if client.IsConnected() {
t.Error("Expected client to be disconnected initially")
}

err = client.Connect(context.Background())
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}

if !client.IsConnected() {
t.Error("Expected client to be connected after Connect()")
}

// Test server creation
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -137,26 +137,26 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}

// Test server methods
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}

mockServer := server.(*mockServer)
if !mockServer.started {
t.Error("Expected server to be started")
}

// Test non-existent transport
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent client")
}

_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent server")
}
}
}

@ -366,7 +366,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {

// Track overall recovery stats
totalStats := NewRecoveryStats()

// Track number of files processed successfully
successfulFiles := 0
var lastErr error
@ -397,7 +397,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
// Add stats from this file to our totals
totalStats.EntriesProcessed += fileStats.EntriesProcessed
totalStats.EntriesSkipped += fileStats.EntriesSkipped

successfulFiles++
}