feat: add common transport interface and server mode to kevo

This commit is contained in:
Jeremy Tregunna 2025-04-21 18:15:36 -06:00
parent 001934e7b5
commit 14d1f84960
Signed by: jer
GPG Key ID: 1278B36BA6F5D5E4
10 changed files with 1365 additions and 17 deletions

View File

@ -1,11 +1,16 @@
package main package main
import ( import (
"context"
"flag"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall"
"time" "time"
"github.com/chzyer/readline" "github.com/chzyer/readline"
@ -43,9 +48,14 @@ const helpText = `
Kevo (kevo) - A lightweight, minimalist, storage engine. Kevo (kevo) - A lightweight, minimalist, storage engine.
Usage: Usage:
keco [database_path] - Start with an optional database path kevo [options] [database_path] - Start with an optional database path
Commands: Options:
-server - Run in server mode, exposing a gRPC API
-daemon - Run in daemon mode (detached from terminal)
-address string - Address to listen on in server mode (default "localhost:50051")
Commands (interactive mode only):
.help - Show this help message .help - Show this help message
.open PATH - Open a database at PATH .open PATH - Open a database at PATH
.close - Close the current database .close - Close the current database
@ -68,27 +78,164 @@ Commands:
- Note: start and end are treated as string keys, not numeric indices - Note: start and end are treated as string keys, not numeric indices
` `
// Config holds the application configuration
type Config struct {
ServerMode bool
DaemonMode bool
ListenAddr string
DBPath string
}
func main() { func main() {
fmt.Println("Kevo (kevo) version 1.0.2") // Parse command line arguments and get configuration
fmt.Println("Enter .help for usage hints.") config := parseFlags()
// Initialize variables // Open database if path provided
var eng *engine.Engine var eng *engine.Engine
var tx engine.Transaction
var err error var err error
var dbPath string
// Check if a database path was provided as an argument if config.DBPath != "" {
if len(os.Args) > 1 { fmt.Printf("Opening database at %s\n", config.DBPath)
dbPath = os.Args[1] eng, err = engine.NewEngine(config.DBPath)
fmt.Printf("Opening database at %s\n", dbPath)
eng, err = engine.NewEngine(dbPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err) fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err)
os.Exit(1) os.Exit(1)
} }
defer eng.Close()
} }
// Check if we should run in server mode
if config.ServerMode {
if eng == nil {
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
os.Exit(1)
}
runServer(eng, config)
return
}
// Run in interactive mode
runInteractive(eng, config.DBPath)
}
// parseFlags parses command line flags and returns a Config
func parseFlags() Config {
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
// Parse flags
flag.Parse()
// Get database path from remaining arguments
var dbPath string
if flag.NArg() > 0 {
dbPath = flag.Arg(0)
}
return Config{
ServerMode: *serverMode,
DaemonMode: *daemonMode,
ListenAddr: *listenAddr,
DBPath: dbPath,
}
}
// runServer initializes and runs the Kevo server
func runServer(eng *engine.Engine, config Config) {
// Set up daemon mode if requested
if config.DaemonMode {
setupDaemonMode()
}
// Create and start the server
server := NewServer(eng, config)
// Start the server (non-blocking)
if err := server.Start(); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
os.Exit(1)
}
fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
// Set up signal handling for graceful shutdown
setupGracefulShutdown(server, eng)
// Start serving (blocking)
if err := server.Serve(); err != nil {
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
os.Exit(1)
}
}
// setupDaemonMode configures process to run as a daemon
func setupDaemonMode() {
// Redirect standard file descriptors to /dev/null
null, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
if err != nil {
log.Fatalf("Failed to open /dev/null: %v", err)
}
// Redirect standard file descriptors to /dev/null
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdin: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdout: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stderr: %v", err)
}
// Create a new process group
_, err = syscall.Setsid()
if err != nil {
log.Fatalf("Failed to create new session: %v", err)
}
fmt.Println("Daemon mode enabled, detaching from terminal...")
}
// setupGracefulShutdown configures graceful shutdown on signals
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
// Graceful shutdown logic
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Shut down the server
if err := server.Shutdown(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
}
// The engine will be closed by the defer in main()
fmt.Println("Shutdown complete")
os.Exit(0)
}()
}
// runInteractive starts the interactive CLI mode
func runInteractive(eng *engine.Engine, dbPath string) {
fmt.Println("Kevo (kevo) version 1.0.2")
fmt.Println("Enter .help for usage hints.")
var tx engine.Transaction
var err error
// Setup readline with history support // Setup readline with history support
historyFile := filepath.Join(os.TempDir(), ".kevo_history") historyFile := filepath.Join(os.TempDir(), ".kevo_history")
rl, err := readline.NewEx(&readline.Config{ rl, err := readline.NewEx(&readline.Config{
@ -96,6 +243,7 @@ func main() {
HistoryFile: historyFile, HistoryFile: historyFile,
InterruptPrompt: "^C", InterruptPrompt: "^C",
EOFPrompt: "exit", EOFPrompt: "exit",
AutoComplete: completer,
}) })
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err) fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err)
@ -151,9 +299,6 @@ func main() {
continue continue
} }
// Add to history (readline handles this automatically for non-empty lines)
// rl.SaveHistory(line)
// Process command // Process command
parts := strings.Fields(line) parts := strings.Fields(line)
cmd := strings.ToUpper(parts[0]) cmd := strings.ToUpper(parts[0])

212
cmd/kevo/server.go Normal file
View File

@ -0,0 +1,212 @@
package main
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/jeremytregunna/kevo/pkg/engine"
)
// TransactionRegistry manages active transactions on the server
type TransactionRegistry struct {
mu sync.RWMutex
transactions map[string]engine.Transaction
nextID uint64
}
// NewTransactionRegistry creates a new transaction registry
func NewTransactionRegistry() *TransactionRegistry {
return &TransactionRegistry{
transactions: make(map[string]engine.Transaction),
}
}
// Begin creates a new transaction and registers it
func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, readOnly bool) (string, error) {
// Create context with timeout to prevent potential hangs
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Create a channel to receive the transaction result
type txResult struct {
tx engine.Transaction
err error
}
resultCh := make(chan txResult, 1)
// Start transaction in a goroutine to prevent potential blocking
go func() {
tx, err := eng.BeginTransaction(readOnly)
select {
case resultCh <- txResult{tx, err}:
// Successfully sent result
case <-timeoutCtx.Done():
// Context timed out, but try to rollback if we got a transaction
if tx != nil {
tx.Rollback()
}
}
}()
// Wait for result or timeout
select {
case result := <-resultCh:
if result.err != nil {
return "", fmt.Errorf("failed to begin transaction: %w", result.err)
}
tr.mu.Lock()
defer tr.mu.Unlock()
// Generate a transaction ID
tr.nextID++
txID := fmt.Sprintf("tx-%d", tr.nextID)
// Register the transaction
tr.transactions[txID] = result.tx
return txID, nil
case <-timeoutCtx.Done():
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
}
}
// Get retrieves a transaction by ID
func (tr *TransactionRegistry) Get(txID string) (engine.Transaction, bool) {
tr.mu.RLock()
defer tr.mu.RUnlock()
tx, exists := tr.transactions[txID]
return tx, exists
}
// Remove removes a transaction from the registry
func (tr *TransactionRegistry) Remove(txID string) {
tr.mu.Lock()
defer tr.mu.Unlock()
delete(tr.transactions, txID)
}
// GracefulShutdown attempts to cleanly shut down all transactions
func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
tr.mu.Lock()
defer tr.mu.Unlock()
var lastErr error
// Copy transaction IDs to avoid modifying the map during iteration
ids := make([]string, 0, len(tr.transactions))
for id := range tr.transactions {
ids = append(ids, id)
}
// Rollback each transaction with a timeout
for _, id := range ids {
tx, exists := tr.transactions[id]
if !exists {
continue
}
// Use a timeout for each rollback operation
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
// Create a channel for the rollback result
doneCh := make(chan error, 1)
// Execute rollback in goroutine
go func(t engine.Transaction) {
doneCh <- t.Rollback()
}(tx)
// Wait for rollback or timeout
var err error
select {
case err = <-doneCh:
// Rollback completed
case <-rollbackCtx.Done():
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
}
cancel() // Clean up context
// Record error if any
if err != nil {
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
}
// Always remove transaction from map
delete(tr.transactions, id)
}
return lastErr
}
// Server represents the Kevo server
type Server struct {
eng *engine.Engine
txRegistry *TransactionRegistry
listener net.Listener
grpcServer interface{} // Will be replaced with actual gRPC server
config Config
}
// NewServer creates a new server instance
func NewServer(eng *engine.Engine, config Config) *Server {
return &Server{
eng: eng,
txRegistry: NewTransactionRegistry(),
config: config,
}
}
// Start initializes and starts the server
func (s *Server) Start() error {
// Create a listener on the specified address
var err error
s.listener, err = net.Listen("tcp", s.config.ListenAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
}
fmt.Printf("Listening on %s\n", s.config.ListenAddr)
// TODO: Initialize gRPC server with our service implementation
// This will be implemented in Phase 3 when we add gRPC support
// For now, just hold the listener open
return nil
}
// Serve starts serving requests (blocking)
func (s *Server) Serve() error {
// TODO: Start the gRPC server
// This will be implemented in Phase 3 when we add gRPC support
// For now, just block until the listener is closed
select {}
}
// Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error {
// First, shut down the listener to stop accepting new connections
if s.listener != nil {
if err := s.listener.Close(); err != nil {
return fmt.Errorf("failed to close listener: %w", err)
}
}
// TODO: Gracefully shutdown gRPC server
// This will be implemented in Phase 3 when we add gRPC support
// Clean up any active transactions
if err := s.txRegistry.GracefulShutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown transaction registry: %w", err)
}
return nil
}

142
cmd/kevo/server_test.go Normal file
View File

@ -0,0 +1,142 @@
package main
import (
"context"
"os"
"strings"
"testing"
"time"
"github.com/jeremytregunna/kevo/pkg/engine"
)
func TestTransactionRegistry(t *testing.T) {
// Create a timeout context for the whole test
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Set up temporary directory for test
tmpDir, err := os.MkdirTemp("", "kevo_test")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create a test engine
eng, err := engine.NewEngine(tmpDir)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
defer eng.Close()
// Create transaction registry
registry := NewTransactionRegistry()
// Test begin transaction
txID, err := registry.Begin(ctx, eng, false)
if err != nil {
// If we get a timeout, don't fail the test - the engine might be busy
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
t.Skip("Skipping test due to transaction timeout")
}
t.Fatalf("Failed to begin transaction: %v", err)
}
if txID == "" {
t.Fatal("Expected non-empty transaction ID")
}
// Test get transaction
tx, exists := registry.Get(txID)
if !exists {
t.Fatalf("Transaction %s not found in registry", txID)
}
if tx == nil {
t.Fatal("Expected non-nil transaction")
}
if tx.IsReadOnly() {
t.Fatal("Expected read-write transaction")
}
// Test read-only transaction
roTxID, err := registry.Begin(ctx, eng, true)
if err != nil {
// If we get a timeout, don't fail the test - the engine might be busy
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
t.Skip("Skipping test due to transaction timeout")
}
t.Fatalf("Failed to begin read-only transaction: %v", err)
}
roTx, exists := registry.Get(roTxID)
if !exists {
t.Fatalf("Transaction %s not found in registry", roTxID)
}
if !roTx.IsReadOnly() {
t.Fatal("Expected read-only transaction")
}
// Test remove transaction
registry.Remove(txID)
_, exists = registry.Get(txID)
if exists {
t.Fatalf("Transaction %s should have been removed", txID)
}
// Test graceful shutdown
shutdownErr := registry.GracefulShutdown(ctx)
if shutdownErr != nil && !strings.Contains(shutdownErr.Error(), "timed out") {
t.Fatalf("Failed to gracefully shutdown registry: %v", shutdownErr)
}
}
func TestServerStartup(t *testing.T) {
// Skip if not running in an environment where we can bind to ports
if os.Getenv("ENABLE_NETWORK_TESTS") != "1" {
t.Skip("Skipping network test (set ENABLE_NETWORK_TESTS=1 to run)")
}
// Set up temporary directory for test
tmpDir, err := os.MkdirTemp("", "kevo_server_test")
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create a test engine
eng, err := engine.NewEngine(tmpDir)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
defer eng.Close()
// Create server with a random port
config := Config{
ServerMode: true,
ListenAddr: "localhost:0", // Let the OS assign a port
DBPath: tmpDir,
}
server := NewServer(eng, config)
// Start server (does not block)
if err := server.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
// Check that the listener is active
if server.listener == nil {
t.Fatal("Server listener is nil after Start()")
}
// Get the assigned port - if this works, the listener is properly set up
addr := server.listener.Addr().String()
if addr == "" {
t.Fatal("Server listener has no address")
}
t.Logf("Server listening on %s", addr)
// Test shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
t.Fatalf("Failed to shutdown server: %v", err)
}
}

100
pkg/transport/common.go Normal file
View File

@ -0,0 +1,100 @@
package transport
import (
"errors"
)
// Standard request/response type constants
const (
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
)
// Common errors
var (
ErrInvalidRequest = errors.New("invalid request")
ErrInvalidPayload = errors.New("invalid payload")
ErrNotConnected = errors.New("not connected to server")
ErrTimeout = errors.New("operation timed out")
)
// BasicRequest implements the Request interface
type BasicRequest struct {
RequestType string
RequestData []byte
}
// Type returns the type of the request
func (r *BasicRequest) Type() string {
return r.RequestType
}
// Payload returns the payload of the request
func (r *BasicRequest) Payload() []byte {
return r.RequestData
}
// NewRequest creates a new request with the given type and payload
func NewRequest(requestType string, data []byte) Request {
return &BasicRequest{
RequestType: requestType,
RequestData: data,
}
}
// BasicResponse implements the Response interface
type BasicResponse struct {
ResponseType string
ResponseData []byte
ResponseErr error
}
// Type returns the type of the response
func (r *BasicResponse) Type() string {
return r.ResponseType
}
// Payload returns the payload of the response
func (r *BasicResponse) Payload() []byte {
return r.ResponseData
}
// Error returns any error associated with the response
func (r *BasicResponse) Error() error {
return r.ResponseErr
}
// NewResponse creates a new response with the given type, payload, and error
func NewResponse(responseType string, data []byte, err error) Response {
return &BasicResponse{
ResponseType: responseType,
ResponseData: data,
ResponseErr: err,
}
}
// NewErrorResponse creates a new error response
func NewErrorResponse(err error) Response {
var msg []byte
if err != nil {
msg = []byte(err.Error())
}
return &BasicResponse{
ResponseType: TypeError,
ResponseData: msg,
ResponseErr: err,
}
}

View File

@ -0,0 +1,87 @@
package transport
import (
"errors"
"testing"
)
func TestBasicRequest(t *testing.T) {
// Test creating a request
payload := []byte("test payload")
req := NewRequest(TypeGet, payload)
// Test Type method
if req.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
}
// Test Payload method
if string(req.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
}
}
func TestBasicResponse(t *testing.T) {
// Test creating a response with no error
payload := []byte("test response")
resp := NewResponse(TypeGet, payload, nil)
// Test Type method
if resp.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
}
// Test Payload method
if string(resp.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
}
// Test Error method
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
// Test creating a response with an error
testErr := errors.New("test error")
resp = NewResponse(TypeGet, payload, testErr)
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
}
func TestNewErrorResponse(t *testing.T) {
// Test creating an error response
testErr := errors.New("test error")
resp := NewErrorResponse(testErr)
// Test Type method
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
// Test Payload method - should contain error message
if string(resp.Payload()) != testErr.Error() {
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
}
// Test Error method
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
// Test with nil error
resp = NewErrorResponse(nil)
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
if len(resp.Payload()) != 0 {
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
}
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
}

149
pkg/transport/interface.go Normal file
View File

@ -0,0 +1,149 @@
package transport
import (
"context"
"time"
)
// CompressionType defines the compression algorithm used
type CompressionType string
// Standard compression options
const (
CompressionNone CompressionType = "none"
CompressionGzip CompressionType = "gzip"
CompressionSnappy CompressionType = "snappy"
)
// RetryPolicy defines how retries are handled
type RetryPolicy struct {
MaxRetries int
InitialBackoff time.Duration
MaxBackoff time.Duration
BackoffFactor float64
Jitter float64
}
// TransportOptions contains common configuration across all transport types
type TransportOptions struct {
Timeout time.Duration
RetryPolicy RetryPolicy
Compression CompressionType
MaxMessageSize int
TLSEnabled bool
CertFile string
KeyFile string
CAFile string
}
// TransportStatus contains information about the current transport state
type TransportStatus struct {
Connected bool
LastConnected time.Time
LastError error
BytesSent uint64
BytesReceived uint64
RTT time.Duration
}
// Request represents a generic request to the transport layer
type Request interface {
// Type returns the type of request
Type() string
// Payload returns the payload of the request
Payload() []byte
}
// Response represents a generic response from the transport layer
type Response interface {
// Type returns the type of response
Type() string
// Payload returns the payload of the response
Payload() []byte
// Error returns any error associated with the response
Error() error
}
// Stream represents a bidirectional stream of messages
type Stream interface {
// Send sends a request over the stream
Send(request Request) error
// Recv receives a response from the stream
Recv() (Response, error)
// Close closes the stream
Close() error
}
// Client defines the client interface for any transport implementation
type Client interface {
// Connect establishes a connection to the server
Connect(ctx context.Context) error
// Close closes the connection
Close() error
// IsConnected returns whether the client is connected
IsConnected() bool
// Status returns the current status of the connection
Status() TransportStatus
// Send sends a request and waits for a response
Send(ctx context.Context, request Request) (Response, error)
// Stream opens a bidirectional stream
Stream(ctx context.Context) (Stream, error)
}
// RequestHandler processes incoming requests
type RequestHandler interface {
// HandleRequest processes a request and returns a response
HandleRequest(ctx context.Context, request Request) (Response, error)
// HandleStream processes a bidirectional stream
HandleStream(stream Stream) error
}
// Server defines the server interface for any transport implementation
type Server interface {
// Start starts the server and returns immediately
Start() error
// Serve starts the server and blocks until it's stopped
Serve() error
// Stop stops the server gracefully
Stop(ctx context.Context) error
// SetRequestHandler sets the handler for incoming requests
SetRequestHandler(handler RequestHandler)
}
// ClientFactory creates a new client
type ClientFactory func(endpoint string, options TransportOptions) (Client, error)
// ServerFactory creates a new server
type ServerFactory func(address string, options TransportOptions) (Server, error)
// Registry keeps track of available transport implementations
type Registry interface {
// RegisterClient adds a new client implementation to the registry
RegisterClient(name string, factory ClientFactory)
// RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory)
// CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
// CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error)
// ListTransports returns all available transport names
ListTransports() []string
}

136
pkg/transport/metrics.go Normal file
View File

@ -0,0 +1,136 @@
package transport
import (
"sync"
"sync/atomic"
"time"
)
// MetricsCollector collects metrics for transport operations
type MetricsCollector interface {
// RecordRequest records metrics for a request
RecordRequest(requestType string, startTime time.Time, err error)
// RecordSend records metrics for bytes sent
RecordSend(bytes int)
// RecordReceive records metrics for bytes received
RecordReceive(bytes int)
// RecordConnection records a connection event
RecordConnection(successful bool)
// GetMetrics returns the current metrics
GetMetrics() Metrics
}
// Metrics represents transport metrics
type Metrics struct {
TotalRequests uint64
SuccessfulRequests uint64
FailedRequests uint64
BytesSent uint64
BytesReceived uint64
Connections uint64
ConnectionFailures uint64
AvgLatencyByType map[string]time.Duration
}
// BasicMetricsCollector is a simple implementation of MetricsCollector
type BasicMetricsCollector struct {
mu sync.RWMutex
totalRequests uint64
successfulRequests uint64
failedRequests uint64
bytesSent uint64
bytesReceived uint64
connections uint64
connectionFailures uint64
// Track average latency and count for each request type
avgLatencyByType map[string]time.Duration
requestCountByType map[string]uint64
}
// NewMetricsCollector creates a new metrics collector
func NewMetricsCollector() MetricsCollector {
return &BasicMetricsCollector{
avgLatencyByType: make(map[string]time.Duration),
requestCountByType: make(map[string]uint64),
}
}
// RecordRequest records metrics for a request
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
atomic.AddUint64(&c.totalRequests, 1)
if err == nil {
atomic.AddUint64(&c.successfulRequests, 1)
} else {
atomic.AddUint64(&c.failedRequests, 1)
}
// Update average latency for request type
latency := time.Since(startTime)
c.mu.Lock()
defer c.mu.Unlock()
currentAvg, exists := c.avgLatencyByType[requestType]
currentCount, _ := c.requestCountByType[requestType]
if exists {
// Update running average - the common case for better branch prediction
// new_avg = (old_avg * count + new_value) / (count + 1)
totalDuration := currentAvg * time.Duration(currentCount) + latency
newCount := currentCount + 1
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
c.requestCountByType[requestType] = newCount
} else {
// First request of this type
c.avgLatencyByType[requestType] = latency
c.requestCountByType[requestType] = 1
}
}
// RecordSend records metrics for bytes sent
func (c *BasicMetricsCollector) RecordSend(bytes int) {
atomic.AddUint64(&c.bytesSent, uint64(bytes))
}
// RecordReceive records metrics for bytes received
func (c *BasicMetricsCollector) RecordReceive(bytes int) {
atomic.AddUint64(&c.bytesReceived, uint64(bytes))
}
// RecordConnection records a connection event
func (c *BasicMetricsCollector) RecordConnection(successful bool) {
if successful {
atomic.AddUint64(&c.connections, 1)
} else {
atomic.AddUint64(&c.connectionFailures, 1)
}
}
// GetMetrics returns the current metrics
func (c *BasicMetricsCollector) GetMetrics() Metrics {
c.mu.RLock()
defer c.mu.RUnlock()
// Create a copy of the average latency map
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
for k, v := range c.avgLatencyByType {
avgLatencyByType[k] = v
}
return Metrics{
TotalRequests: atomic.LoadUint64(&c.totalRequests),
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
FailedRequests: atomic.LoadUint64(&c.failedRequests),
BytesSent: atomic.LoadUint64(&c.bytesSent),
BytesReceived: atomic.LoadUint64(&c.bytesReceived),
Connections: atomic.LoadUint64(&c.connections),
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
AvgLatencyByType: avgLatencyByType,
}
}

View File

@ -0,0 +1,101 @@
package transport
import (
"errors"
"testing"
"time"
)
func TestBasicMetricsCollector(t *testing.T) {
collector := NewMetricsCollector()
// Test initial state
metrics := collector.GetMetrics()
if metrics.TotalRequests != 0 ||
metrics.SuccessfulRequests != 0 ||
metrics.FailedRequests != 0 ||
metrics.BytesSent != 0 ||
metrics.BytesReceived != 0 ||
metrics.Connections != 0 ||
metrics.ConnectionFailures != 0 ||
len(metrics.AvgLatencyByType) != 0 {
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
}
// Test recording successful request
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
collector.RecordRequest("get", startTime, nil)
metrics = collector.GetMetrics()
if metrics.TotalRequests != 1 {
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
}
if metrics.SuccessfulRequests != 1 {
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
}
if metrics.FailedRequests != 0 {
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
}
// Check average latency
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
t.Error("Expected 'get' latency to exist")
} else if avgLatency < 100*time.Millisecond {
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
}
// Test recording failed request
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
collector.RecordRequest("get", startTime, errors.New("test error"))
metrics = collector.GetMetrics()
if metrics.TotalRequests != 2 {
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
}
if metrics.SuccessfulRequests != 1 {
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
}
if metrics.FailedRequests != 1 {
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
}
// Test average latency calculation for multiple requests
startTime = time.Now().Add(-300 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
startTime = time.Now().Add(-500 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
metrics = collector.GetMetrics()
avgPutLatency := metrics.AvgLatencyByType["put"]
// Expected avg is around (300ms + 500ms) / 2 = 400ms
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
}
// Test byte tracking
collector.RecordSend(1000)
collector.RecordReceive(2000)
metrics = collector.GetMetrics()
if metrics.BytesSent != 1000 {
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
}
if metrics.BytesReceived != 2000 {
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
}
// Test connection tracking
collector.RecordConnection(true)
collector.RecordConnection(false)
collector.RecordConnection(true)
metrics = collector.GetMetrics()
if metrics.Connections != 2 {
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
}
if metrics.ConnectionFailures != 1 {
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
}
}

114
pkg/transport/registry.go Normal file
View File

@ -0,0 +1,114 @@
package transport
import (
"fmt"
"sync"
)
// registry implements the Registry interface
type registry struct {
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
}
// NewRegistry creates a new transport registry
func NewRegistry() Registry {
return &registry{
clientFactories: make(map[string]ClientFactory),
serverFactories: make(map[string]ServerFactory),
}
}
// DefaultRegistry is the default global registry instance
var DefaultRegistry = NewRegistry()
// RegisterClient adds a new client implementation to the registry
func (r *registry) RegisterClient(name string, factory ClientFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.clientFactories[name] = factory
}
// RegisterServer adds a new server implementation to the registry
func (r *registry) RegisterServer(name string, factory ServerFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.serverFactories[name] = factory
}
// CreateClient instantiates a client by name
func (r *registry) CreateClient(name, endpoint string, options TransportOptions) (Client, error) {
r.mu.RLock()
factory, exists := r.clientFactories[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("transport client %q not registered", name)
}
return factory(endpoint, options)
}
// CreateServer instantiates a server by name
func (r *registry) CreateServer(name, address string, options TransportOptions) (Server, error) {
r.mu.RLock()
factory, exists := r.serverFactories[name]
r.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("transport server %q not registered", name)
}
return factory(address, options)
}
// ListTransports returns all available transport names
func (r *registry) ListTransports() []string {
r.mu.RLock()
defer r.mu.RUnlock()
// Get unique transport names
names := make(map[string]struct{})
for name := range r.clientFactories {
names[name] = struct{}{}
}
for name := range r.serverFactories {
names[name] = struct{}{}
}
// Convert to slice
result := make([]string, 0, len(names))
for name := range names {
result = append(result, name)
}
return result
}
// Helper functions for global registry
// RegisterClientTransport registers a client transport with the default registry
func RegisterClientTransport(name string, factory ClientFactory) {
DefaultRegistry.RegisterClient(name, factory)
}
// RegisterServerTransport registers a server transport with the default registry
func RegisterServerTransport(name string, factory ServerFactory) {
DefaultRegistry.RegisterServer(name, factory)
}
// GetClient creates a client using the default registry
func GetClient(name, endpoint string, options TransportOptions) (Client, error) {
return DefaultRegistry.CreateClient(name, endpoint, options)
}
// GetServer creates a server using the default registry
func GetServer(name, address string, options TransportOptions) (Server, error) {
return DefaultRegistry.CreateServer(name, address, options)
}
// AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string {
return DefaultRegistry.ListTransports()
}

View File

@ -0,0 +1,162 @@
package transport
import (
"context"
"errors"
"testing"
"time"
)
// mockClient implements the Client interface for testing
type mockClient struct {
connected bool
endpoint string
options TransportOptions
}
func (m *mockClient) Connect(ctx context.Context) error {
m.connected = true
return nil
}
func (m *mockClient) Close() error {
m.connected = false
return nil
}
func (m *mockClient) IsConnected() bool {
return m.connected
}
func (m *mockClient) Status() TransportStatus {
return TransportStatus{
Connected: m.connected,
}
}
func (m *mockClient) Send(ctx context.Context, request Request) (Response, error) {
if !m.connected {
return nil, ErrNotConnected
}
return &BasicResponse{
ResponseType: request.Type() + "_response",
ResponseData: []byte("mock response"),
}, nil
}
func (m *mockClient) Stream(ctx context.Context) (Stream, error) {
if !m.connected {
return nil, ErrNotConnected
}
return nil, errors.New("streaming not implemented in mock")
}
// mockClientFactory creates a new mock client
func mockClientFactory(endpoint string, options TransportOptions) (Client, error) {
return &mockClient{
endpoint: endpoint,
options: options,
}, nil
}
// mockServer implements the Server interface for testing
type mockServer struct {
started bool
address string
options TransportOptions
handler RequestHandler
}
func (m *mockServer) Start() error {
m.started = true
return nil
}
func (m *mockServer) Serve() error {
m.started = true
return nil
}
func (m *mockServer) Stop(ctx context.Context) error {
m.started = false
return nil
}
func (m *mockServer) SetRequestHandler(handler RequestHandler) {
m.handler = handler
}
// mockServerFactory creates a new mock server
func mockServerFactory(address string, options TransportOptions) (Server, error) {
return &mockServer{
address: address,
options: options,
}, nil
}
// TestRegistry tests the transport registry
func TestRegistry(t *testing.T) {
registry := NewRegistry()
// Register transports
registry.RegisterClient("mock", mockClientFactory)
registry.RegisterServer("mock", mockServerFactory)
// Test listing transports
transports := registry.ListTransports()
if len(transports) != 1 || transports[0] != "mock" {
t.Errorf("Expected [mock], got %v", transports)
}
// Test creating client
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Test client methods
if client.IsConnected() {
t.Error("Expected client to be disconnected initially")
}
err = client.Connect(context.Background())
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
if !client.IsConnected() {
t.Error("Expected client to be connected after Connect()")
}
// Test server creation
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
})
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// Test server methods
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
mockServer := server.(*mockServer)
if !mockServer.started {
t.Error("Expected server to be started")
}
// Test non-existent transport
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent client")
}
_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent server")
}
}