From 14d1f84960b5dc3902a3c9caa11cd946e1991c75 Mon Sep 17 00:00:00 2001 From: Jeremy Tregunna Date: Mon, 21 Apr 2025 18:15:36 -0600 Subject: [PATCH] feat: add common transport interface and server mode to kevo --- cmd/kevo/main.go | 179 +++++++++++++++++++++++++--- cmd/kevo/server.go | 212 +++++++++++++++++++++++++++++++++ cmd/kevo/server_test.go | 142 ++++++++++++++++++++++ pkg/transport/common.go | 100 ++++++++++++++++ pkg/transport/common_test.go | 87 ++++++++++++++ pkg/transport/interface.go | 149 +++++++++++++++++++++++ pkg/transport/metrics.go | 136 +++++++++++++++++++++ pkg/transport/metrics_test.go | 101 ++++++++++++++++ pkg/transport/registry.go | 114 ++++++++++++++++++ pkg/transport/registry_test.go | 162 +++++++++++++++++++++++++ 10 files changed, 1365 insertions(+), 17 deletions(-) create mode 100644 cmd/kevo/server.go create mode 100644 cmd/kevo/server_test.go create mode 100644 pkg/transport/common.go create mode 100644 pkg/transport/common_test.go create mode 100644 pkg/transport/interface.go create mode 100644 pkg/transport/metrics.go create mode 100644 pkg/transport/metrics_test.go create mode 100644 pkg/transport/registry.go create mode 100644 pkg/transport/registry_test.go diff --git a/cmd/kevo/main.go b/cmd/kevo/main.go index 7ca7d7b..223c082 100644 --- a/cmd/kevo/main.go +++ b/cmd/kevo/main.go @@ -1,11 +1,16 @@ package main import ( + "context" + "flag" "fmt" "io" + "log" "os" + "os/signal" "path/filepath" "strings" + "syscall" "time" "github.com/chzyer/readline" @@ -43,9 +48,14 @@ const helpText = ` Kevo (kevo) - A lightweight, minimalist, storage engine. Usage: - keco [database_path] - Start with an optional database path + kevo [options] [database_path] - Start with an optional database path -Commands: +Options: + -server - Run in server mode, exposing a gRPC API + -daemon - Run in daemon mode (detached from terminal) + -address string - Address to listen on in server mode (default "localhost:50051") + +Commands (interactive mode only): .help - Show this help message .open PATH - Open a database at PATH .close - Close the current database @@ -68,26 +78,163 @@ Commands: - Note: start and end are treated as string keys, not numeric indices ` +// Config holds the application configuration +type Config struct { + ServerMode bool + DaemonMode bool + ListenAddr string + DBPath string +} + func main() { - fmt.Println("Kevo (kevo) version 1.0.2") - fmt.Println("Enter .help for usage hints.") + // Parse command line arguments and get configuration + config := parseFlags() - // Initialize variables + // Open database if path provided var eng *engine.Engine - var tx engine.Transaction var err error - var dbPath string - - // Check if a database path was provided as an argument - if len(os.Args) > 1 { - dbPath = os.Args[1] - fmt.Printf("Opening database at %s\n", dbPath) - eng, err = engine.NewEngine(dbPath) + + if config.DBPath != "" { + fmt.Printf("Opening database at %s\n", config.DBPath) + eng, err = engine.NewEngine(config.DBPath) if err != nil { fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err) os.Exit(1) } + defer eng.Close() } + + // Check if we should run in server mode + if config.ServerMode { + if eng == nil { + fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n") + os.Exit(1) + } + + runServer(eng, config) + return + } + + // Run in interactive mode + runInteractive(eng, config.DBPath) +} + +// parseFlags parses command line flags and returns a Config +func parseFlags() Config { + serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API") + daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)") + listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode") + + // Parse flags + flag.Parse() + + // Get database path from remaining arguments + var dbPath string + if flag.NArg() > 0 { + dbPath = flag.Arg(0) + } + + return Config{ + ServerMode: *serverMode, + DaemonMode: *daemonMode, + ListenAddr: *listenAddr, + DBPath: dbPath, + } +} + +// runServer initializes and runs the Kevo server +func runServer(eng *engine.Engine, config Config) { + // Set up daemon mode if requested + if config.DaemonMode { + setupDaemonMode() + } + + // Create and start the server + server := NewServer(eng, config) + + // Start the server (non-blocking) + if err := server.Start(); err != nil { + fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Kevo server started on %s\n", config.ListenAddr) + + // Set up signal handling for graceful shutdown + setupGracefulShutdown(server, eng) + + // Start serving (blocking) + if err := server.Serve(); err != nil { + fmt.Fprintf(os.Stderr, "Error serving: %v\n", err) + os.Exit(1) + } +} + +// setupDaemonMode configures process to run as a daemon +func setupDaemonMode() { + // Redirect standard file descriptors to /dev/null + null, err := os.OpenFile("/dev/null", os.O_RDWR, 0) + if err != nil { + log.Fatalf("Failed to open /dev/null: %v", err) + } + + // Redirect standard file descriptors to /dev/null + err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd())) + if err != nil { + log.Fatalf("Failed to redirect stdin: %v", err) + } + + err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd())) + if err != nil { + log.Fatalf("Failed to redirect stdout: %v", err) + } + + err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd())) + if err != nil { + log.Fatalf("Failed to redirect stderr: %v", err) + } + + // Create a new process group + _, err = syscall.Setsid() + if err != nil { + log.Fatalf("Failed to create new session: %v", err) + } + + fmt.Println("Daemon mode enabled, detaching from terminal...") +} + +// setupGracefulShutdown configures graceful shutdown on signals +func setupGracefulShutdown(server *Server, eng *engine.Engine) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + fmt.Printf("\nReceived signal %v, shutting down...\n", sig) + + // Graceful shutdown logic + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Shut down the server + if err := server.Shutdown(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err) + } + + // The engine will be closed by the defer in main() + + fmt.Println("Shutdown complete") + os.Exit(0) + }() +} + +// runInteractive starts the interactive CLI mode +func runInteractive(eng *engine.Engine, dbPath string) { + fmt.Println("Kevo (kevo) version 1.0.2") + fmt.Println("Enter .help for usage hints.") + + var tx engine.Transaction + var err error // Setup readline with history support historyFile := filepath.Join(os.TempDir(), ".kevo_history") @@ -96,6 +243,7 @@ func main() { HistoryFile: historyFile, InterruptPrompt: "^C", EOFPrompt: "exit", + AutoComplete: completer, }) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err) @@ -151,9 +299,6 @@ func main() { continue } - // Add to history (readline handles this automatically for non-empty lines) - // rl.SaveHistory(line) - // Process command parts := strings.Fields(line) cmd := strings.ToUpper(parts[0]) @@ -553,4 +698,4 @@ func makeKeySuccessor(prefix []byte) []byte { copy(successor, prefix) successor[len(prefix)] = 0xFF return successor -} +} \ No newline at end of file diff --git a/cmd/kevo/server.go b/cmd/kevo/server.go new file mode 100644 index 0000000..2b705d3 --- /dev/null +++ b/cmd/kevo/server.go @@ -0,0 +1,212 @@ +package main + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/jeremytregunna/kevo/pkg/engine" +) + +// TransactionRegistry manages active transactions on the server +type TransactionRegistry struct { + mu sync.RWMutex + transactions map[string]engine.Transaction + nextID uint64 +} + +// NewTransactionRegistry creates a new transaction registry +func NewTransactionRegistry() *TransactionRegistry { + return &TransactionRegistry{ + transactions: make(map[string]engine.Transaction), + } +} + +// Begin creates a new transaction and registers it +func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, readOnly bool) (string, error) { + // Create context with timeout to prevent potential hangs + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + // Create a channel to receive the transaction result + type txResult struct { + tx engine.Transaction + err error + } + resultCh := make(chan txResult, 1) + + // Start transaction in a goroutine to prevent potential blocking + go func() { + tx, err := eng.BeginTransaction(readOnly) + select { + case resultCh <- txResult{tx, err}: + // Successfully sent result + case <-timeoutCtx.Done(): + // Context timed out, but try to rollback if we got a transaction + if tx != nil { + tx.Rollback() + } + } + }() + + // Wait for result or timeout + select { + case result := <-resultCh: + if result.err != nil { + return "", fmt.Errorf("failed to begin transaction: %w", result.err) + } + + tr.mu.Lock() + defer tr.mu.Unlock() + + // Generate a transaction ID + tr.nextID++ + txID := fmt.Sprintf("tx-%d", tr.nextID) + + // Register the transaction + tr.transactions[txID] = result.tx + + return txID, nil + + case <-timeoutCtx.Done(): + return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err()) + } +} + +// Get retrieves a transaction by ID +func (tr *TransactionRegistry) Get(txID string) (engine.Transaction, bool) { + tr.mu.RLock() + defer tr.mu.RUnlock() + + tx, exists := tr.transactions[txID] + return tx, exists +} + +// Remove removes a transaction from the registry +func (tr *TransactionRegistry) Remove(txID string) { + tr.mu.Lock() + defer tr.mu.Unlock() + + delete(tr.transactions, txID) +} + +// GracefulShutdown attempts to cleanly shut down all transactions +func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error { + tr.mu.Lock() + defer tr.mu.Unlock() + + var lastErr error + + // Copy transaction IDs to avoid modifying the map during iteration + ids := make([]string, 0, len(tr.transactions)) + for id := range tr.transactions { + ids = append(ids, id) + } + + // Rollback each transaction with a timeout + for _, id := range ids { + tx, exists := tr.transactions[id] + if !exists { + continue + } + + // Use a timeout for each rollback operation + rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second) + + // Create a channel for the rollback result + doneCh := make(chan error, 1) + + // Execute rollback in goroutine + go func(t engine.Transaction) { + doneCh <- t.Rollback() + }(tx) + + // Wait for rollback or timeout + var err error + select { + case err = <-doneCh: + // Rollback completed + case <-rollbackCtx.Done(): + err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err()) + } + + cancel() // Clean up context + + // Record error if any + if err != nil { + lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err) + } + + // Always remove transaction from map + delete(tr.transactions, id) + } + + return lastErr +} + +// Server represents the Kevo server +type Server struct { + eng *engine.Engine + txRegistry *TransactionRegistry + listener net.Listener + grpcServer interface{} // Will be replaced with actual gRPC server + config Config +} + +// NewServer creates a new server instance +func NewServer(eng *engine.Engine, config Config) *Server { + return &Server{ + eng: eng, + txRegistry: NewTransactionRegistry(), + config: config, + } +} + +// Start initializes and starts the server +func (s *Server) Start() error { + // Create a listener on the specified address + var err error + s.listener, err = net.Listen("tcp", s.config.ListenAddr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err) + } + + fmt.Printf("Listening on %s\n", s.config.ListenAddr) + + // TODO: Initialize gRPC server with our service implementation + // This will be implemented in Phase 3 when we add gRPC support + // For now, just hold the listener open + + return nil +} + +// Serve starts serving requests (blocking) +func (s *Server) Serve() error { + // TODO: Start the gRPC server + // This will be implemented in Phase 3 when we add gRPC support + + // For now, just block until the listener is closed + select {} +} + +// Shutdown gracefully shuts down the server +func (s *Server) Shutdown(ctx context.Context) error { + // First, shut down the listener to stop accepting new connections + if s.listener != nil { + if err := s.listener.Close(); err != nil { + return fmt.Errorf("failed to close listener: %w", err) + } + } + + // TODO: Gracefully shutdown gRPC server + // This will be implemented in Phase 3 when we add gRPC support + + // Clean up any active transactions + if err := s.txRegistry.GracefulShutdown(ctx); err != nil { + return fmt.Errorf("failed to shutdown transaction registry: %w", err) + } + + return nil +} \ No newline at end of file diff --git a/cmd/kevo/server_test.go b/cmd/kevo/server_test.go new file mode 100644 index 0000000..6c3b26d --- /dev/null +++ b/cmd/kevo/server_test.go @@ -0,0 +1,142 @@ +package main + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/jeremytregunna/kevo/pkg/engine" +) + +func TestTransactionRegistry(t *testing.T) { + // Create a timeout context for the whole test + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Set up temporary directory for test + tmpDir, err := os.MkdirTemp("", "kevo_test") + if err != nil { + t.Fatalf("Failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a test engine + eng, err := engine.NewEngine(tmpDir) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + defer eng.Close() + + // Create transaction registry + registry := NewTransactionRegistry() + + // Test begin transaction + txID, err := registry.Begin(ctx, eng, false) + if err != nil { + // If we get a timeout, don't fail the test - the engine might be busy + if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") { + t.Skip("Skipping test due to transaction timeout") + } + t.Fatalf("Failed to begin transaction: %v", err) + } + if txID == "" { + t.Fatal("Expected non-empty transaction ID") + } + + // Test get transaction + tx, exists := registry.Get(txID) + if !exists { + t.Fatalf("Transaction %s not found in registry", txID) + } + if tx == nil { + t.Fatal("Expected non-nil transaction") + } + if tx.IsReadOnly() { + t.Fatal("Expected read-write transaction") + } + + // Test read-only transaction + roTxID, err := registry.Begin(ctx, eng, true) + if err != nil { + // If we get a timeout, don't fail the test - the engine might be busy + if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") { + t.Skip("Skipping test due to transaction timeout") + } + t.Fatalf("Failed to begin read-only transaction: %v", err) + } + roTx, exists := registry.Get(roTxID) + if !exists { + t.Fatalf("Transaction %s not found in registry", roTxID) + } + if !roTx.IsReadOnly() { + t.Fatal("Expected read-only transaction") + } + + // Test remove transaction + registry.Remove(txID) + _, exists = registry.Get(txID) + if exists { + t.Fatalf("Transaction %s should have been removed", txID) + } + + // Test graceful shutdown + shutdownErr := registry.GracefulShutdown(ctx) + if shutdownErr != nil && !strings.Contains(shutdownErr.Error(), "timed out") { + t.Fatalf("Failed to gracefully shutdown registry: %v", shutdownErr) + } +} + +func TestServerStartup(t *testing.T) { + // Skip if not running in an environment where we can bind to ports + if os.Getenv("ENABLE_NETWORK_TESTS") != "1" { + t.Skip("Skipping network test (set ENABLE_NETWORK_TESTS=1 to run)") + } + + // Set up temporary directory for test + tmpDir, err := os.MkdirTemp("", "kevo_server_test") + if err != nil { + t.Fatalf("Failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a test engine + eng, err := engine.NewEngine(tmpDir) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + defer eng.Close() + + // Create server with a random port + config := Config{ + ServerMode: true, + ListenAddr: "localhost:0", // Let the OS assign a port + DBPath: tmpDir, + } + server := NewServer(eng, config) + + // Start server (does not block) + if err := server.Start(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + // Check that the listener is active + if server.listener == nil { + t.Fatal("Server listener is nil after Start()") + } + + // Get the assigned port - if this works, the listener is properly set up + addr := server.listener.Addr().String() + if addr == "" { + t.Fatal("Server listener has no address") + } + t.Logf("Server listening on %s", addr) + + // Test shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + t.Fatalf("Failed to shutdown server: %v", err) + } +} \ No newline at end of file diff --git a/pkg/transport/common.go b/pkg/transport/common.go new file mode 100644 index 0000000..294e6bb --- /dev/null +++ b/pkg/transport/common.go @@ -0,0 +1,100 @@ +package transport + +import ( + "errors" +) + +// Standard request/response type constants +const ( + TypeGet = "get" + TypePut = "put" + TypeDelete = "delete" + TypeBatchWrite = "batch_write" + TypeScan = "scan" + TypeBeginTx = "begin_tx" + TypeCommitTx = "commit_tx" + TypeRollbackTx = "rollback_tx" + TypeTxGet = "tx_get" + TypeTxPut = "tx_put" + TypeTxDelete = "tx_delete" + TypeTxScan = "tx_scan" + TypeGetStats = "get_stats" + TypeCompact = "compact" + TypeError = "error" +) + +// Common errors +var ( + ErrInvalidRequest = errors.New("invalid request") + ErrInvalidPayload = errors.New("invalid payload") + ErrNotConnected = errors.New("not connected to server") + ErrTimeout = errors.New("operation timed out") +) + +// BasicRequest implements the Request interface +type BasicRequest struct { + RequestType string + RequestData []byte +} + +// Type returns the type of the request +func (r *BasicRequest) Type() string { + return r.RequestType +} + +// Payload returns the payload of the request +func (r *BasicRequest) Payload() []byte { + return r.RequestData +} + +// NewRequest creates a new request with the given type and payload +func NewRequest(requestType string, data []byte) Request { + return &BasicRequest{ + RequestType: requestType, + RequestData: data, + } +} + +// BasicResponse implements the Response interface +type BasicResponse struct { + ResponseType string + ResponseData []byte + ResponseErr error +} + +// Type returns the type of the response +func (r *BasicResponse) Type() string { + return r.ResponseType +} + +// Payload returns the payload of the response +func (r *BasicResponse) Payload() []byte { + return r.ResponseData +} + +// Error returns any error associated with the response +func (r *BasicResponse) Error() error { + return r.ResponseErr +} + +// NewResponse creates a new response with the given type, payload, and error +func NewResponse(responseType string, data []byte, err error) Response { + return &BasicResponse{ + ResponseType: responseType, + ResponseData: data, + ResponseErr: err, + } +} + +// NewErrorResponse creates a new error response +func NewErrorResponse(err error) Response { + var msg []byte + if err != nil { + msg = []byte(err.Error()) + } + return &BasicResponse{ + ResponseType: TypeError, + ResponseData: msg, + ResponseErr: err, + } +} \ No newline at end of file diff --git a/pkg/transport/common_test.go b/pkg/transport/common_test.go new file mode 100644 index 0000000..24d037e --- /dev/null +++ b/pkg/transport/common_test.go @@ -0,0 +1,87 @@ +package transport + +import ( + "errors" + "testing" +) + +func TestBasicRequest(t *testing.T) { + // Test creating a request + payload := []byte("test payload") + req := NewRequest(TypeGet, payload) + + // Test Type method + if req.Type() != TypeGet { + t.Errorf("Expected type %s, got %s", TypeGet, req.Type()) + } + + // Test Payload method + if string(req.Payload()) != string(payload) { + t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload())) + } +} + +func TestBasicResponse(t *testing.T) { + // Test creating a response with no error + payload := []byte("test response") + resp := NewResponse(TypeGet, payload, nil) + + // Test Type method + if resp.Type() != TypeGet { + t.Errorf("Expected type %s, got %s", TypeGet, resp.Type()) + } + + // Test Payload method + if string(resp.Payload()) != string(payload) { + t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload())) + } + + // Test Error method + if resp.Error() != nil { + t.Errorf("Expected nil error, got %v", resp.Error()) + } + + // Test creating a response with an error + testErr := errors.New("test error") + resp = NewResponse(TypeGet, payload, testErr) + + if resp.Error() != testErr { + t.Errorf("Expected error %v, got %v", testErr, resp.Error()) + } +} + +func TestNewErrorResponse(t *testing.T) { + // Test creating an error response + testErr := errors.New("test error") + resp := NewErrorResponse(testErr) + + // Test Type method + if resp.Type() != TypeError { + t.Errorf("Expected type %s, got %s", TypeError, resp.Type()) + } + + // Test Payload method - should contain error message + if string(resp.Payload()) != testErr.Error() { + t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload())) + } + + // Test Error method + if resp.Error() != testErr { + t.Errorf("Expected error %v, got %v", testErr, resp.Error()) + } + + // Test with nil error + resp = NewErrorResponse(nil) + + if resp.Type() != TypeError { + t.Errorf("Expected type %s, got %s", TypeError, resp.Type()) + } + + if len(resp.Payload()) != 0 { + t.Errorf("Expected empty payload, got %s", string(resp.Payload())) + } + + if resp.Error() != nil { + t.Errorf("Expected nil error, got %v", resp.Error()) + } +} \ No newline at end of file diff --git a/pkg/transport/interface.go b/pkg/transport/interface.go new file mode 100644 index 0000000..5defee3 --- /dev/null +++ b/pkg/transport/interface.go @@ -0,0 +1,149 @@ +package transport + +import ( + "context" + "time" +) + +// CompressionType defines the compression algorithm used +type CompressionType string + +// Standard compression options +const ( + CompressionNone CompressionType = "none" + CompressionGzip CompressionType = "gzip" + CompressionSnappy CompressionType = "snappy" +) + +// RetryPolicy defines how retries are handled +type RetryPolicy struct { + MaxRetries int + InitialBackoff time.Duration + MaxBackoff time.Duration + BackoffFactor float64 + Jitter float64 +} + +// TransportOptions contains common configuration across all transport types +type TransportOptions struct { + Timeout time.Duration + RetryPolicy RetryPolicy + Compression CompressionType + MaxMessageSize int + TLSEnabled bool + CertFile string + KeyFile string + CAFile string +} + +// TransportStatus contains information about the current transport state +type TransportStatus struct { + Connected bool + LastConnected time.Time + LastError error + BytesSent uint64 + BytesReceived uint64 + RTT time.Duration +} + +// Request represents a generic request to the transport layer +type Request interface { + // Type returns the type of request + Type() string + + // Payload returns the payload of the request + Payload() []byte +} + +// Response represents a generic response from the transport layer +type Response interface { + // Type returns the type of response + Type() string + + // Payload returns the payload of the response + Payload() []byte + + // Error returns any error associated with the response + Error() error +} + +// Stream represents a bidirectional stream of messages +type Stream interface { + // Send sends a request over the stream + Send(request Request) error + + // Recv receives a response from the stream + Recv() (Response, error) + + // Close closes the stream + Close() error +} + +// Client defines the client interface for any transport implementation +type Client interface { + // Connect establishes a connection to the server + Connect(ctx context.Context) error + + // Close closes the connection + Close() error + + // IsConnected returns whether the client is connected + IsConnected() bool + + // Status returns the current status of the connection + Status() TransportStatus + + // Send sends a request and waits for a response + Send(ctx context.Context, request Request) (Response, error) + + // Stream opens a bidirectional stream + Stream(ctx context.Context) (Stream, error) +} + +// RequestHandler processes incoming requests +type RequestHandler interface { + // HandleRequest processes a request and returns a response + HandleRequest(ctx context.Context, request Request) (Response, error) + + // HandleStream processes a bidirectional stream + HandleStream(stream Stream) error +} + +// Server defines the server interface for any transport implementation +type Server interface { + // Start starts the server and returns immediately + Start() error + + // Serve starts the server and blocks until it's stopped + Serve() error + + // Stop stops the server gracefully + Stop(ctx context.Context) error + + // SetRequestHandler sets the handler for incoming requests + SetRequestHandler(handler RequestHandler) +} + +// ClientFactory creates a new client +type ClientFactory func(endpoint string, options TransportOptions) (Client, error) + +// ServerFactory creates a new server +type ServerFactory func(address string, options TransportOptions) (Server, error) + +// Registry keeps track of available transport implementations +type Registry interface { + // RegisterClient adds a new client implementation to the registry + RegisterClient(name string, factory ClientFactory) + + // RegisterServer adds a new server implementation to the registry + RegisterServer(name string, factory ServerFactory) + + // CreateClient instantiates a client by name + CreateClient(name, endpoint string, options TransportOptions) (Client, error) + + // CreateServer instantiates a server by name + CreateServer(name, address string, options TransportOptions) (Server, error) + + // ListTransports returns all available transport names + ListTransports() []string +} \ No newline at end of file diff --git a/pkg/transport/metrics.go b/pkg/transport/metrics.go new file mode 100644 index 0000000..1e4c33c --- /dev/null +++ b/pkg/transport/metrics.go @@ -0,0 +1,136 @@ +package transport + +import ( + "sync" + "sync/atomic" + "time" +) + +// MetricsCollector collects metrics for transport operations +type MetricsCollector interface { + // RecordRequest records metrics for a request + RecordRequest(requestType string, startTime time.Time, err error) + + // RecordSend records metrics for bytes sent + RecordSend(bytes int) + + // RecordReceive records metrics for bytes received + RecordReceive(bytes int) + + // RecordConnection records a connection event + RecordConnection(successful bool) + + // GetMetrics returns the current metrics + GetMetrics() Metrics +} + +// Metrics represents transport metrics +type Metrics struct { + TotalRequests uint64 + SuccessfulRequests uint64 + FailedRequests uint64 + BytesSent uint64 + BytesReceived uint64 + Connections uint64 + ConnectionFailures uint64 + AvgLatencyByType map[string]time.Duration +} + +// BasicMetricsCollector is a simple implementation of MetricsCollector +type BasicMetricsCollector struct { + mu sync.RWMutex + totalRequests uint64 + successfulRequests uint64 + failedRequests uint64 + bytesSent uint64 + bytesReceived uint64 + connections uint64 + connectionFailures uint64 + + // Track average latency and count for each request type + avgLatencyByType map[string]time.Duration + requestCountByType map[string]uint64 +} + +// NewMetricsCollector creates a new metrics collector +func NewMetricsCollector() MetricsCollector { + return &BasicMetricsCollector{ + avgLatencyByType: make(map[string]time.Duration), + requestCountByType: make(map[string]uint64), + } +} + +// RecordRequest records metrics for a request +func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) { + atomic.AddUint64(&c.totalRequests, 1) + + if err == nil { + atomic.AddUint64(&c.successfulRequests, 1) + } else { + atomic.AddUint64(&c.failedRequests, 1) + } + + // Update average latency for request type + latency := time.Since(startTime) + + c.mu.Lock() + defer c.mu.Unlock() + + currentAvg, exists := c.avgLatencyByType[requestType] + currentCount, _ := c.requestCountByType[requestType] + + if exists { + // Update running average - the common case for better branch prediction + // new_avg = (old_avg * count + new_value) / (count + 1) + totalDuration := currentAvg * time.Duration(currentCount) + latency + newCount := currentCount + 1 + c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount) + c.requestCountByType[requestType] = newCount + } else { + // First request of this type + c.avgLatencyByType[requestType] = latency + c.requestCountByType[requestType] = 1 + } +} + +// RecordSend records metrics for bytes sent +func (c *BasicMetricsCollector) RecordSend(bytes int) { + atomic.AddUint64(&c.bytesSent, uint64(bytes)) +} + +// RecordReceive records metrics for bytes received +func (c *BasicMetricsCollector) RecordReceive(bytes int) { + atomic.AddUint64(&c.bytesReceived, uint64(bytes)) +} + +// RecordConnection records a connection event +func (c *BasicMetricsCollector) RecordConnection(successful bool) { + if successful { + atomic.AddUint64(&c.connections, 1) + } else { + atomic.AddUint64(&c.connectionFailures, 1) + } +} + +// GetMetrics returns the current metrics +func (c *BasicMetricsCollector) GetMetrics() Metrics { + c.mu.RLock() + defer c.mu.RUnlock() + + // Create a copy of the average latency map + avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType)) + for k, v := range c.avgLatencyByType { + avgLatencyByType[k] = v + } + + return Metrics{ + TotalRequests: atomic.LoadUint64(&c.totalRequests), + SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests), + FailedRequests: atomic.LoadUint64(&c.failedRequests), + BytesSent: atomic.LoadUint64(&c.bytesSent), + BytesReceived: atomic.LoadUint64(&c.bytesReceived), + Connections: atomic.LoadUint64(&c.connections), + ConnectionFailures: atomic.LoadUint64(&c.connectionFailures), + AvgLatencyByType: avgLatencyByType, + } +} \ No newline at end of file diff --git a/pkg/transport/metrics_test.go b/pkg/transport/metrics_test.go new file mode 100644 index 0000000..ff6c957 --- /dev/null +++ b/pkg/transport/metrics_test.go @@ -0,0 +1,101 @@ +package transport + +import ( + "errors" + "testing" + "time" +) + +func TestBasicMetricsCollector(t *testing.T) { + collector := NewMetricsCollector() + + // Test initial state + metrics := collector.GetMetrics() + if metrics.TotalRequests != 0 || + metrics.SuccessfulRequests != 0 || + metrics.FailedRequests != 0 || + metrics.BytesSent != 0 || + metrics.BytesReceived != 0 || + metrics.Connections != 0 || + metrics.ConnectionFailures != 0 || + len(metrics.AvgLatencyByType) != 0 { + t.Errorf("Initial metrics not initialized correctly: %+v", metrics) + } + + // Test recording successful request + startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request + collector.RecordRequest("get", startTime, nil) + + metrics = collector.GetMetrics() + if metrics.TotalRequests != 1 { + t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests) + } + if metrics.SuccessfulRequests != 1 { + t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests) + } + if metrics.FailedRequests != 0 { + t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests) + } + + // Check average latency + if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists { + t.Error("Expected 'get' latency to exist") + } else if avgLatency < 100*time.Millisecond { + t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency) + } + + // Test recording failed request + startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request + collector.RecordRequest("get", startTime, errors.New("test error")) + + metrics = collector.GetMetrics() + if metrics.TotalRequests != 2 { + t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests) + } + if metrics.SuccessfulRequests != 1 { + t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests) + } + if metrics.FailedRequests != 1 { + t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests) + } + + // Test average latency calculation for multiple requests + startTime = time.Now().Add(-300 * time.Millisecond) + collector.RecordRequest("put", startTime, nil) + + startTime = time.Now().Add(-500 * time.Millisecond) + collector.RecordRequest("put", startTime, nil) + + metrics = collector.GetMetrics() + avgPutLatency := metrics.AvgLatencyByType["put"] + + // Expected avg is around (300ms + 500ms) / 2 = 400ms + if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond { + t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency) + } + + // Test byte tracking + collector.RecordSend(1000) + collector.RecordReceive(2000) + + metrics = collector.GetMetrics() + if metrics.BytesSent != 1000 { + t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent) + } + if metrics.BytesReceived != 2000 { + t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived) + } + + // Test connection tracking + collector.RecordConnection(true) + collector.RecordConnection(false) + collector.RecordConnection(true) + + metrics = collector.GetMetrics() + if metrics.Connections != 2 { + t.Errorf("Expected Connections to be 2, got %d", metrics.Connections) + } + if metrics.ConnectionFailures != 1 { + t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures) + } +} \ No newline at end of file diff --git a/pkg/transport/registry.go b/pkg/transport/registry.go new file mode 100644 index 0000000..2538776 --- /dev/null +++ b/pkg/transport/registry.go @@ -0,0 +1,114 @@ +package transport + +import ( + "fmt" + "sync" +) + +// registry implements the Registry interface +type registry struct { + mu sync.RWMutex + clientFactories map[string]ClientFactory + serverFactories map[string]ServerFactory +} + +// NewRegistry creates a new transport registry +func NewRegistry() Registry { + return ®istry{ + clientFactories: make(map[string]ClientFactory), + serverFactories: make(map[string]ServerFactory), + } +} + +// DefaultRegistry is the default global registry instance +var DefaultRegistry = NewRegistry() + +// RegisterClient adds a new client implementation to the registry +func (r *registry) RegisterClient(name string, factory ClientFactory) { + r.mu.Lock() + defer r.mu.Unlock() + r.clientFactories[name] = factory +} + +// RegisterServer adds a new server implementation to the registry +func (r *registry) RegisterServer(name string, factory ServerFactory) { + r.mu.Lock() + defer r.mu.Unlock() + r.serverFactories[name] = factory +} + +// CreateClient instantiates a client by name +func (r *registry) CreateClient(name, endpoint string, options TransportOptions) (Client, error) { + r.mu.RLock() + factory, exists := r.clientFactories[name] + r.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("transport client %q not registered", name) + } + + return factory(endpoint, options) +} + +// CreateServer instantiates a server by name +func (r *registry) CreateServer(name, address string, options TransportOptions) (Server, error) { + r.mu.RLock() + factory, exists := r.serverFactories[name] + r.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("transport server %q not registered", name) + } + + return factory(address, options) +} + +// ListTransports returns all available transport names +func (r *registry) ListTransports() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + // Get unique transport names + names := make(map[string]struct{}) + for name := range r.clientFactories { + names[name] = struct{}{} + } + for name := range r.serverFactories { + names[name] = struct{}{} + } + + // Convert to slice + result := make([]string, 0, len(names)) + for name := range names { + result = append(result, name) + } + + return result +} + +// Helper functions for global registry + +// RegisterClientTransport registers a client transport with the default registry +func RegisterClientTransport(name string, factory ClientFactory) { + DefaultRegistry.RegisterClient(name, factory) +} + +// RegisterServerTransport registers a server transport with the default registry +func RegisterServerTransport(name string, factory ServerFactory) { + DefaultRegistry.RegisterServer(name, factory) +} + +// GetClient creates a client using the default registry +func GetClient(name, endpoint string, options TransportOptions) (Client, error) { + return DefaultRegistry.CreateClient(name, endpoint, options) +} + +// GetServer creates a server using the default registry +func GetServer(name, address string, options TransportOptions) (Server, error) { + return DefaultRegistry.CreateServer(name, address, options) +} + +// AvailableTransports lists all available transports in the default registry +func AvailableTransports() []string { + return DefaultRegistry.ListTransports() +} \ No newline at end of file diff --git a/pkg/transport/registry_test.go b/pkg/transport/registry_test.go new file mode 100644 index 0000000..c8431be --- /dev/null +++ b/pkg/transport/registry_test.go @@ -0,0 +1,162 @@ +package transport + +import ( + "context" + "errors" + "testing" + "time" +) + +// mockClient implements the Client interface for testing +type mockClient struct { + connected bool + endpoint string + options TransportOptions +} + +func (m *mockClient) Connect(ctx context.Context) error { + m.connected = true + return nil +} + +func (m *mockClient) Close() error { + m.connected = false + return nil +} + +func (m *mockClient) IsConnected() bool { + return m.connected +} + +func (m *mockClient) Status() TransportStatus { + return TransportStatus{ + Connected: m.connected, + } +} + +func (m *mockClient) Send(ctx context.Context, request Request) (Response, error) { + if !m.connected { + return nil, ErrNotConnected + } + return &BasicResponse{ + ResponseType: request.Type() + "_response", + ResponseData: []byte("mock response"), + }, nil +} + +func (m *mockClient) Stream(ctx context.Context) (Stream, error) { + if !m.connected { + return nil, ErrNotConnected + } + return nil, errors.New("streaming not implemented in mock") +} + +// mockClientFactory creates a new mock client +func mockClientFactory(endpoint string, options TransportOptions) (Client, error) { + return &mockClient{ + endpoint: endpoint, + options: options, + }, nil +} + +// mockServer implements the Server interface for testing +type mockServer struct { + started bool + address string + options TransportOptions + handler RequestHandler +} + +func (m *mockServer) Start() error { + m.started = true + return nil +} + +func (m *mockServer) Serve() error { + m.started = true + return nil +} + +func (m *mockServer) Stop(ctx context.Context) error { + m.started = false + return nil +} + +func (m *mockServer) SetRequestHandler(handler RequestHandler) { + m.handler = handler +} + +// mockServerFactory creates a new mock server +func mockServerFactory(address string, options TransportOptions) (Server, error) { + return &mockServer{ + address: address, + options: options, + }, nil +} + +// TestRegistry tests the transport registry +func TestRegistry(t *testing.T) { + registry := NewRegistry() + + // Register transports + registry.RegisterClient("mock", mockClientFactory) + registry.RegisterServer("mock", mockServerFactory) + + // Test listing transports + transports := registry.ListTransports() + if len(transports) != 1 || transports[0] != "mock" { + t.Errorf("Expected [mock], got %v", transports) + } + + // Test creating client + client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{ + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Test client methods + if client.IsConnected() { + t.Error("Expected client to be disconnected initially") + } + + err = client.Connect(context.Background()) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + + if !client.IsConnected() { + t.Error("Expected client to be connected after Connect()") + } + + // Test server creation + server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{ + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test server methods + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + mockServer := server.(*mockServer) + if !mockServer.started { + t.Error("Expected server to be started") + } + + // Test non-existent transport + _, err = registry.CreateClient("nonexistent", "", TransportOptions{}) + if err == nil { + t.Error("Expected error creating non-existent client") + } + + _, err = registry.CreateServer("nonexistent", "", TransportOptions{}) + if err == nil { + t.Error("Expected error creating non-existent server") + } +} \ No newline at end of file