feat: add common transport interface and server mode to kevo
This commit is contained in:
parent
001934e7b5
commit
14d1f84960
175
cmd/kevo/main.go
175
cmd/kevo/main.go
@ -1,11 +1,16 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/chzyer/readline"
|
"github.com/chzyer/readline"
|
||||||
@ -43,9 +48,14 @@ const helpText = `
|
|||||||
Kevo (kevo) - A lightweight, minimalist, storage engine.
|
Kevo (kevo) - A lightweight, minimalist, storage engine.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
keco [database_path] - Start with an optional database path
|
kevo [options] [database_path] - Start with an optional database path
|
||||||
|
|
||||||
Commands:
|
Options:
|
||||||
|
-server - Run in server mode, exposing a gRPC API
|
||||||
|
-daemon - Run in daemon mode (detached from terminal)
|
||||||
|
-address string - Address to listen on in server mode (default "localhost:50051")
|
||||||
|
|
||||||
|
Commands (interactive mode only):
|
||||||
.help - Show this help message
|
.help - Show this help message
|
||||||
.open PATH - Open a database at PATH
|
.open PATH - Open a database at PATH
|
||||||
.close - Close the current database
|
.close - Close the current database
|
||||||
@ -68,27 +78,164 @@ Commands:
|
|||||||
- Note: start and end are treated as string keys, not numeric indices
|
- Note: start and end are treated as string keys, not numeric indices
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// Config holds the application configuration
|
||||||
|
type Config struct {
|
||||||
|
ServerMode bool
|
||||||
|
DaemonMode bool
|
||||||
|
ListenAddr string
|
||||||
|
DBPath string
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fmt.Println("Kevo (kevo) version 1.0.2")
|
// Parse command line arguments and get configuration
|
||||||
fmt.Println("Enter .help for usage hints.")
|
config := parseFlags()
|
||||||
|
|
||||||
// Initialize variables
|
// Open database if path provided
|
||||||
var eng *engine.Engine
|
var eng *engine.Engine
|
||||||
var tx engine.Transaction
|
|
||||||
var err error
|
var err error
|
||||||
var dbPath string
|
|
||||||
|
|
||||||
// Check if a database path was provided as an argument
|
if config.DBPath != "" {
|
||||||
if len(os.Args) > 1 {
|
fmt.Printf("Opening database at %s\n", config.DBPath)
|
||||||
dbPath = os.Args[1]
|
eng, err = engine.NewEngine(config.DBPath)
|
||||||
fmt.Printf("Opening database at %s\n", dbPath)
|
|
||||||
eng, err = engine.NewEngine(dbPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err)
|
fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
defer eng.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we should run in server mode
|
||||||
|
if config.ServerMode {
|
||||||
|
if eng == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
runServer(eng, config)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run in interactive mode
|
||||||
|
runInteractive(eng, config.DBPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFlags parses command line flags and returns a Config
|
||||||
|
func parseFlags() Config {
|
||||||
|
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
|
||||||
|
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
|
||||||
|
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
|
||||||
|
|
||||||
|
// Parse flags
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Get database path from remaining arguments
|
||||||
|
var dbPath string
|
||||||
|
if flag.NArg() > 0 {
|
||||||
|
dbPath = flag.Arg(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Config{
|
||||||
|
ServerMode: *serverMode,
|
||||||
|
DaemonMode: *daemonMode,
|
||||||
|
ListenAddr: *listenAddr,
|
||||||
|
DBPath: dbPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runServer initializes and runs the Kevo server
|
||||||
|
func runServer(eng *engine.Engine, config Config) {
|
||||||
|
// Set up daemon mode if requested
|
||||||
|
if config.DaemonMode {
|
||||||
|
setupDaemonMode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start the server
|
||||||
|
server := NewServer(eng, config)
|
||||||
|
|
||||||
|
// Start the server (non-blocking)
|
||||||
|
if err := server.Start(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
|
||||||
|
|
||||||
|
// Set up signal handling for graceful shutdown
|
||||||
|
setupGracefulShutdown(server, eng)
|
||||||
|
|
||||||
|
// Start serving (blocking)
|
||||||
|
if err := server.Serve(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDaemonMode configures process to run as a daemon
|
||||||
|
func setupDaemonMode() {
|
||||||
|
// Redirect standard file descriptors to /dev/null
|
||||||
|
null, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open /dev/null: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect standard file descriptors to /dev/null
|
||||||
|
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to redirect stdin: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to redirect stdout: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to redirect stderr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new process group
|
||||||
|
_, err = syscall.Setsid()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create new session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Daemon mode enabled, detaching from terminal...")
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupGracefulShutdown configures graceful shutdown on signals
|
||||||
|
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
sig := <-sigChan
|
||||||
|
fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
|
||||||
|
|
||||||
|
// Graceful shutdown logic
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Shut down the server
|
||||||
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The engine will be closed by the defer in main()
|
||||||
|
|
||||||
|
fmt.Println("Shutdown complete")
|
||||||
|
os.Exit(0)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runInteractive starts the interactive CLI mode
|
||||||
|
func runInteractive(eng *engine.Engine, dbPath string) {
|
||||||
|
fmt.Println("Kevo (kevo) version 1.0.2")
|
||||||
|
fmt.Println("Enter .help for usage hints.")
|
||||||
|
|
||||||
|
var tx engine.Transaction
|
||||||
|
var err error
|
||||||
|
|
||||||
// Setup readline with history support
|
// Setup readline with history support
|
||||||
historyFile := filepath.Join(os.TempDir(), ".kevo_history")
|
historyFile := filepath.Join(os.TempDir(), ".kevo_history")
|
||||||
rl, err := readline.NewEx(&readline.Config{
|
rl, err := readline.NewEx(&readline.Config{
|
||||||
@ -96,6 +243,7 @@ func main() {
|
|||||||
HistoryFile: historyFile,
|
HistoryFile: historyFile,
|
||||||
InterruptPrompt: "^C",
|
InterruptPrompt: "^C",
|
||||||
EOFPrompt: "exit",
|
EOFPrompt: "exit",
|
||||||
|
AutoComplete: completer,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err)
|
fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err)
|
||||||
@ -151,9 +299,6 @@ func main() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add to history (readline handles this automatically for non-empty lines)
|
|
||||||
// rl.SaveHistory(line)
|
|
||||||
|
|
||||||
// Process command
|
// Process command
|
||||||
parts := strings.Fields(line)
|
parts := strings.Fields(line)
|
||||||
cmd := strings.ToUpper(parts[0])
|
cmd := strings.ToUpper(parts[0])
|
||||||
|
212
cmd/kevo/server.go
Normal file
212
cmd/kevo/server.go
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jeremytregunna/kevo/pkg/engine"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransactionRegistry manages active transactions on the server
|
||||||
|
type TransactionRegistry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
transactions map[string]engine.Transaction
|
||||||
|
nextID uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTransactionRegistry creates a new transaction registry
|
||||||
|
func NewTransactionRegistry() *TransactionRegistry {
|
||||||
|
return &TransactionRegistry{
|
||||||
|
transactions: make(map[string]engine.Transaction),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin creates a new transaction and registers it
|
||||||
|
func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, readOnly bool) (string, error) {
|
||||||
|
// Create context with timeout to prevent potential hangs
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create a channel to receive the transaction result
|
||||||
|
type txResult struct {
|
||||||
|
tx engine.Transaction
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan txResult, 1)
|
||||||
|
|
||||||
|
// Start transaction in a goroutine to prevent potential blocking
|
||||||
|
go func() {
|
||||||
|
tx, err := eng.BeginTransaction(readOnly)
|
||||||
|
select {
|
||||||
|
case resultCh <- txResult{tx, err}:
|
||||||
|
// Successfully sent result
|
||||||
|
case <-timeoutCtx.Done():
|
||||||
|
// Context timed out, but try to rollback if we got a transaction
|
||||||
|
if tx != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for result or timeout
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
if result.err != nil {
|
||||||
|
return "", fmt.Errorf("failed to begin transaction: %w", result.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tr.mu.Lock()
|
||||||
|
defer tr.mu.Unlock()
|
||||||
|
|
||||||
|
// Generate a transaction ID
|
||||||
|
tr.nextID++
|
||||||
|
txID := fmt.Sprintf("tx-%d", tr.nextID)
|
||||||
|
|
||||||
|
// Register the transaction
|
||||||
|
tr.transactions[txID] = result.tx
|
||||||
|
|
||||||
|
return txID, nil
|
||||||
|
|
||||||
|
case <-timeoutCtx.Done():
|
||||||
|
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a transaction by ID
|
||||||
|
func (tr *TransactionRegistry) Get(txID string) (engine.Transaction, bool) {
|
||||||
|
tr.mu.RLock()
|
||||||
|
defer tr.mu.RUnlock()
|
||||||
|
|
||||||
|
tx, exists := tr.transactions[txID]
|
||||||
|
return tx, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes a transaction from the registry
|
||||||
|
func (tr *TransactionRegistry) Remove(txID string) {
|
||||||
|
tr.mu.Lock()
|
||||||
|
defer tr.mu.Unlock()
|
||||||
|
|
||||||
|
delete(tr.transactions, txID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GracefulShutdown attempts to cleanly shut down all transactions
|
||||||
|
func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
|
||||||
|
tr.mu.Lock()
|
||||||
|
defer tr.mu.Unlock()
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
// Copy transaction IDs to avoid modifying the map during iteration
|
||||||
|
ids := make([]string, 0, len(tr.transactions))
|
||||||
|
for id := range tr.transactions {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rollback each transaction with a timeout
|
||||||
|
for _, id := range ids {
|
||||||
|
tx, exists := tr.transactions[id]
|
||||||
|
if !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a timeout for each rollback operation
|
||||||
|
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||||
|
|
||||||
|
// Create a channel for the rollback result
|
||||||
|
doneCh := make(chan error, 1)
|
||||||
|
|
||||||
|
// Execute rollback in goroutine
|
||||||
|
go func(t engine.Transaction) {
|
||||||
|
doneCh <- t.Rollback()
|
||||||
|
}(tx)
|
||||||
|
|
||||||
|
// Wait for rollback or timeout
|
||||||
|
var err error
|
||||||
|
select {
|
||||||
|
case err = <-doneCh:
|
||||||
|
// Rollback completed
|
||||||
|
case <-rollbackCtx.Done():
|
||||||
|
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel() // Clean up context
|
||||||
|
|
||||||
|
// Record error if any
|
||||||
|
if err != nil {
|
||||||
|
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always remove transaction from map
|
||||||
|
delete(tr.transactions, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server represents the Kevo server
|
||||||
|
type Server struct {
|
||||||
|
eng *engine.Engine
|
||||||
|
txRegistry *TransactionRegistry
|
||||||
|
listener net.Listener
|
||||||
|
grpcServer interface{} // Will be replaced with actual gRPC server
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new server instance
|
||||||
|
func NewServer(eng *engine.Engine, config Config) *Server {
|
||||||
|
return &Server{
|
||||||
|
eng: eng,
|
||||||
|
txRegistry: NewTransactionRegistry(),
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes and starts the server
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
// Create a listener on the specified address
|
||||||
|
var err error
|
||||||
|
s.listener, err = net.Listen("tcp", s.config.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Listening on %s\n", s.config.ListenAddr)
|
||||||
|
|
||||||
|
// TODO: Initialize gRPC server with our service implementation
|
||||||
|
// This will be implemented in Phase 3 when we add gRPC support
|
||||||
|
// For now, just hold the listener open
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve starts serving requests (blocking)
|
||||||
|
func (s *Server) Serve() error {
|
||||||
|
// TODO: Start the gRPC server
|
||||||
|
// This will be implemented in Phase 3 when we add gRPC support
|
||||||
|
|
||||||
|
// For now, just block until the listener is closed
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
// First, shut down the listener to stop accepting new connections
|
||||||
|
if s.listener != nil {
|
||||||
|
if err := s.listener.Close(); err != nil {
|
||||||
|
return fmt.Errorf("failed to close listener: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Gracefully shutdown gRPC server
|
||||||
|
// This will be implemented in Phase 3 when we add gRPC support
|
||||||
|
|
||||||
|
// Clean up any active transactions
|
||||||
|
if err := s.txRegistry.GracefulShutdown(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to shutdown transaction registry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
142
cmd/kevo/server_test.go
Normal file
142
cmd/kevo/server_test.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jeremytregunna/kevo/pkg/engine"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransactionRegistry(t *testing.T) {
|
||||||
|
// Create a timeout context for the whole test
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Set up temporary directory for test
|
||||||
|
tmpDir, err := os.MkdirTemp("", "kevo_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Create a test engine
|
||||||
|
eng, err := engine.NewEngine(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create engine: %v", err)
|
||||||
|
}
|
||||||
|
defer eng.Close()
|
||||||
|
|
||||||
|
// Create transaction registry
|
||||||
|
registry := NewTransactionRegistry()
|
||||||
|
|
||||||
|
// Test begin transaction
|
||||||
|
txID, err := registry.Begin(ctx, eng, false)
|
||||||
|
if err != nil {
|
||||||
|
// If we get a timeout, don't fail the test - the engine might be busy
|
||||||
|
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
|
||||||
|
t.Skip("Skipping test due to transaction timeout")
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to begin transaction: %v", err)
|
||||||
|
}
|
||||||
|
if txID == "" {
|
||||||
|
t.Fatal("Expected non-empty transaction ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test get transaction
|
||||||
|
tx, exists := registry.Get(txID)
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("Transaction %s not found in registry", txID)
|
||||||
|
}
|
||||||
|
if tx == nil {
|
||||||
|
t.Fatal("Expected non-nil transaction")
|
||||||
|
}
|
||||||
|
if tx.IsReadOnly() {
|
||||||
|
t.Fatal("Expected read-write transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test read-only transaction
|
||||||
|
roTxID, err := registry.Begin(ctx, eng, true)
|
||||||
|
if err != nil {
|
||||||
|
// If we get a timeout, don't fail the test - the engine might be busy
|
||||||
|
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
|
||||||
|
t.Skip("Skipping test due to transaction timeout")
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to begin read-only transaction: %v", err)
|
||||||
|
}
|
||||||
|
roTx, exists := registry.Get(roTxID)
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("Transaction %s not found in registry", roTxID)
|
||||||
|
}
|
||||||
|
if !roTx.IsReadOnly() {
|
||||||
|
t.Fatal("Expected read-only transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test remove transaction
|
||||||
|
registry.Remove(txID)
|
||||||
|
_, exists = registry.Get(txID)
|
||||||
|
if exists {
|
||||||
|
t.Fatalf("Transaction %s should have been removed", txID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test graceful shutdown
|
||||||
|
shutdownErr := registry.GracefulShutdown(ctx)
|
||||||
|
if shutdownErr != nil && !strings.Contains(shutdownErr.Error(), "timed out") {
|
||||||
|
t.Fatalf("Failed to gracefully shutdown registry: %v", shutdownErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerStartup(t *testing.T) {
|
||||||
|
// Skip if not running in an environment where we can bind to ports
|
||||||
|
if os.Getenv("ENABLE_NETWORK_TESTS") != "1" {
|
||||||
|
t.Skip("Skipping network test (set ENABLE_NETWORK_TESTS=1 to run)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up temporary directory for test
|
||||||
|
tmpDir, err := os.MkdirTemp("", "kevo_server_test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Create a test engine
|
||||||
|
eng, err := engine.NewEngine(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create engine: %v", err)
|
||||||
|
}
|
||||||
|
defer eng.Close()
|
||||||
|
|
||||||
|
// Create server with a random port
|
||||||
|
config := Config{
|
||||||
|
ServerMode: true,
|
||||||
|
ListenAddr: "localhost:0", // Let the OS assign a port
|
||||||
|
DBPath: tmpDir,
|
||||||
|
}
|
||||||
|
server := NewServer(eng, config)
|
||||||
|
|
||||||
|
// Start server (does not block)
|
||||||
|
if err := server.Start(); err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the listener is active
|
||||||
|
if server.listener == nil {
|
||||||
|
t.Fatal("Server listener is nil after Start()")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the assigned port - if this works, the listener is properly set up
|
||||||
|
addr := server.listener.Addr().String()
|
||||||
|
if addr == "" {
|
||||||
|
t.Fatal("Server listener has no address")
|
||||||
|
}
|
||||||
|
t.Logf("Server listening on %s", addr)
|
||||||
|
|
||||||
|
// Test shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to shutdown server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
100
pkg/transport/common.go
Normal file
100
pkg/transport/common.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Standard request/response type constants
|
||||||
|
const (
|
||||||
|
TypeGet = "get"
|
||||||
|
TypePut = "put"
|
||||||
|
TypeDelete = "delete"
|
||||||
|
TypeBatchWrite = "batch_write"
|
||||||
|
TypeScan = "scan"
|
||||||
|
TypeBeginTx = "begin_tx"
|
||||||
|
TypeCommitTx = "commit_tx"
|
||||||
|
TypeRollbackTx = "rollback_tx"
|
||||||
|
TypeTxGet = "tx_get"
|
||||||
|
TypeTxPut = "tx_put"
|
||||||
|
TypeTxDelete = "tx_delete"
|
||||||
|
TypeTxScan = "tx_scan"
|
||||||
|
TypeGetStats = "get_stats"
|
||||||
|
TypeCompact = "compact"
|
||||||
|
TypeError = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common errors
|
||||||
|
var (
|
||||||
|
ErrInvalidRequest = errors.New("invalid request")
|
||||||
|
ErrInvalidPayload = errors.New("invalid payload")
|
||||||
|
ErrNotConnected = errors.New("not connected to server")
|
||||||
|
ErrTimeout = errors.New("operation timed out")
|
||||||
|
)
|
||||||
|
|
||||||
|
// BasicRequest implements the Request interface
|
||||||
|
type BasicRequest struct {
|
||||||
|
RequestType string
|
||||||
|
RequestData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the type of the request
|
||||||
|
func (r *BasicRequest) Type() string {
|
||||||
|
return r.RequestType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Payload returns the payload of the request
|
||||||
|
func (r *BasicRequest) Payload() []byte {
|
||||||
|
return r.RequestData
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequest creates a new request with the given type and payload
|
||||||
|
func NewRequest(requestType string, data []byte) Request {
|
||||||
|
return &BasicRequest{
|
||||||
|
RequestType: requestType,
|
||||||
|
RequestData: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicResponse implements the Response interface
|
||||||
|
type BasicResponse struct {
|
||||||
|
ResponseType string
|
||||||
|
ResponseData []byte
|
||||||
|
ResponseErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the type of the response
|
||||||
|
func (r *BasicResponse) Type() string {
|
||||||
|
return r.ResponseType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Payload returns the payload of the response
|
||||||
|
func (r *BasicResponse) Payload() []byte {
|
||||||
|
return r.ResponseData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns any error associated with the response
|
||||||
|
func (r *BasicResponse) Error() error {
|
||||||
|
return r.ResponseErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponse creates a new response with the given type, payload, and error
|
||||||
|
func NewResponse(responseType string, data []byte, err error) Response {
|
||||||
|
return &BasicResponse{
|
||||||
|
ResponseType: responseType,
|
||||||
|
ResponseData: data,
|
||||||
|
ResponseErr: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorResponse creates a new error response
|
||||||
|
func NewErrorResponse(err error) Response {
|
||||||
|
var msg []byte
|
||||||
|
if err != nil {
|
||||||
|
msg = []byte(err.Error())
|
||||||
|
}
|
||||||
|
return &BasicResponse{
|
||||||
|
ResponseType: TypeError,
|
||||||
|
ResponseData: msg,
|
||||||
|
ResponseErr: err,
|
||||||
|
}
|
||||||
|
}
|
87
pkg/transport/common_test.go
Normal file
87
pkg/transport/common_test.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasicRequest(t *testing.T) {
|
||||||
|
// Test creating a request
|
||||||
|
payload := []byte("test payload")
|
||||||
|
req := NewRequest(TypeGet, payload)
|
||||||
|
|
||||||
|
// Test Type method
|
||||||
|
if req.Type() != TypeGet {
|
||||||
|
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Payload method
|
||||||
|
if string(req.Payload()) != string(payload) {
|
||||||
|
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicResponse(t *testing.T) {
|
||||||
|
// Test creating a response with no error
|
||||||
|
payload := []byte("test response")
|
||||||
|
resp := NewResponse(TypeGet, payload, nil)
|
||||||
|
|
||||||
|
// Test Type method
|
||||||
|
if resp.Type() != TypeGet {
|
||||||
|
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Payload method
|
||||||
|
if string(resp.Payload()) != string(payload) {
|
||||||
|
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Error method
|
||||||
|
if resp.Error() != nil {
|
||||||
|
t.Errorf("Expected nil error, got %v", resp.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test creating a response with an error
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
resp = NewResponse(TypeGet, payload, testErr)
|
||||||
|
|
||||||
|
if resp.Error() != testErr {
|
||||||
|
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewErrorResponse(t *testing.T) {
|
||||||
|
// Test creating an error response
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
resp := NewErrorResponse(testErr)
|
||||||
|
|
||||||
|
// Test Type method
|
||||||
|
if resp.Type() != TypeError {
|
||||||
|
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Payload method - should contain error message
|
||||||
|
if string(resp.Payload()) != testErr.Error() {
|
||||||
|
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Error method
|
||||||
|
if resp.Error() != testErr {
|
||||||
|
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil error
|
||||||
|
resp = NewErrorResponse(nil)
|
||||||
|
|
||||||
|
if resp.Type() != TypeError {
|
||||||
|
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Payload()) != 0 {
|
||||||
|
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Error() != nil {
|
||||||
|
t.Errorf("Expected nil error, got %v", resp.Error())
|
||||||
|
}
|
||||||
|
}
|
149
pkg/transport/interface.go
Normal file
149
pkg/transport/interface.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompressionType defines the compression algorithm used
|
||||||
|
type CompressionType string
|
||||||
|
|
||||||
|
// Standard compression options
|
||||||
|
const (
|
||||||
|
CompressionNone CompressionType = "none"
|
||||||
|
CompressionGzip CompressionType = "gzip"
|
||||||
|
CompressionSnappy CompressionType = "snappy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RetryPolicy defines how retries are handled
|
||||||
|
type RetryPolicy struct {
|
||||||
|
MaxRetries int
|
||||||
|
InitialBackoff time.Duration
|
||||||
|
MaxBackoff time.Duration
|
||||||
|
BackoffFactor float64
|
||||||
|
Jitter float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransportOptions contains common configuration across all transport types
|
||||||
|
type TransportOptions struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
RetryPolicy RetryPolicy
|
||||||
|
Compression CompressionType
|
||||||
|
MaxMessageSize int
|
||||||
|
TLSEnabled bool
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
CAFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransportStatus contains information about the current transport state
|
||||||
|
type TransportStatus struct {
|
||||||
|
Connected bool
|
||||||
|
LastConnected time.Time
|
||||||
|
LastError error
|
||||||
|
BytesSent uint64
|
||||||
|
BytesReceived uint64
|
||||||
|
RTT time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request represents a generic request to the transport layer
|
||||||
|
type Request interface {
|
||||||
|
// Type returns the type of request
|
||||||
|
Type() string
|
||||||
|
|
||||||
|
// Payload returns the payload of the request
|
||||||
|
Payload() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a generic response from the transport layer
|
||||||
|
type Response interface {
|
||||||
|
// Type returns the type of response
|
||||||
|
Type() string
|
||||||
|
|
||||||
|
// Payload returns the payload of the response
|
||||||
|
Payload() []byte
|
||||||
|
|
||||||
|
// Error returns any error associated with the response
|
||||||
|
Error() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream represents a bidirectional stream of messages
|
||||||
|
type Stream interface {
|
||||||
|
// Send sends a request over the stream
|
||||||
|
Send(request Request) error
|
||||||
|
|
||||||
|
// Recv receives a response from the stream
|
||||||
|
Recv() (Response, error)
|
||||||
|
|
||||||
|
// Close closes the stream
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client defines the client interface for any transport implementation
|
||||||
|
type Client interface {
|
||||||
|
// Connect establishes a connection to the server
|
||||||
|
Connect(ctx context.Context) error
|
||||||
|
|
||||||
|
// Close closes the connection
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// IsConnected returns whether the client is connected
|
||||||
|
IsConnected() bool
|
||||||
|
|
||||||
|
// Status returns the current status of the connection
|
||||||
|
Status() TransportStatus
|
||||||
|
|
||||||
|
// Send sends a request and waits for a response
|
||||||
|
Send(ctx context.Context, request Request) (Response, error)
|
||||||
|
|
||||||
|
// Stream opens a bidirectional stream
|
||||||
|
Stream(ctx context.Context) (Stream, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestHandler processes incoming requests
|
||||||
|
type RequestHandler interface {
|
||||||
|
// HandleRequest processes a request and returns a response
|
||||||
|
HandleRequest(ctx context.Context, request Request) (Response, error)
|
||||||
|
|
||||||
|
// HandleStream processes a bidirectional stream
|
||||||
|
HandleStream(stream Stream) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server defines the server interface for any transport implementation
|
||||||
|
type Server interface {
|
||||||
|
// Start starts the server and returns immediately
|
||||||
|
Start() error
|
||||||
|
|
||||||
|
// Serve starts the server and blocks until it's stopped
|
||||||
|
Serve() error
|
||||||
|
|
||||||
|
// Stop stops the server gracefully
|
||||||
|
Stop(ctx context.Context) error
|
||||||
|
|
||||||
|
// SetRequestHandler sets the handler for incoming requests
|
||||||
|
SetRequestHandler(handler RequestHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientFactory creates a new client
|
||||||
|
type ClientFactory func(endpoint string, options TransportOptions) (Client, error)
|
||||||
|
|
||||||
|
// ServerFactory creates a new server
|
||||||
|
type ServerFactory func(address string, options TransportOptions) (Server, error)
|
||||||
|
|
||||||
|
// Registry keeps track of available transport implementations
|
||||||
|
type Registry interface {
|
||||||
|
// RegisterClient adds a new client implementation to the registry
|
||||||
|
RegisterClient(name string, factory ClientFactory)
|
||||||
|
|
||||||
|
// RegisterServer adds a new server implementation to the registry
|
||||||
|
RegisterServer(name string, factory ServerFactory)
|
||||||
|
|
||||||
|
// CreateClient instantiates a client by name
|
||||||
|
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
|
||||||
|
|
||||||
|
// CreateServer instantiates a server by name
|
||||||
|
CreateServer(name, address string, options TransportOptions) (Server, error)
|
||||||
|
|
||||||
|
// ListTransports returns all available transport names
|
||||||
|
ListTransports() []string
|
||||||
|
}
|
136
pkg/transport/metrics.go
Normal file
136
pkg/transport/metrics.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MetricsCollector collects metrics for transport operations
|
||||||
|
type MetricsCollector interface {
|
||||||
|
// RecordRequest records metrics for a request
|
||||||
|
RecordRequest(requestType string, startTime time.Time, err error)
|
||||||
|
|
||||||
|
// RecordSend records metrics for bytes sent
|
||||||
|
RecordSend(bytes int)
|
||||||
|
|
||||||
|
// RecordReceive records metrics for bytes received
|
||||||
|
RecordReceive(bytes int)
|
||||||
|
|
||||||
|
// RecordConnection records a connection event
|
||||||
|
RecordConnection(successful bool)
|
||||||
|
|
||||||
|
// GetMetrics returns the current metrics
|
||||||
|
GetMetrics() Metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics represents transport metrics
|
||||||
|
type Metrics struct {
|
||||||
|
TotalRequests uint64
|
||||||
|
SuccessfulRequests uint64
|
||||||
|
FailedRequests uint64
|
||||||
|
BytesSent uint64
|
||||||
|
BytesReceived uint64
|
||||||
|
Connections uint64
|
||||||
|
ConnectionFailures uint64
|
||||||
|
AvgLatencyByType map[string]time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicMetricsCollector is a simple implementation of MetricsCollector
|
||||||
|
type BasicMetricsCollector struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
totalRequests uint64
|
||||||
|
successfulRequests uint64
|
||||||
|
failedRequests uint64
|
||||||
|
bytesSent uint64
|
||||||
|
bytesReceived uint64
|
||||||
|
connections uint64
|
||||||
|
connectionFailures uint64
|
||||||
|
|
||||||
|
// Track average latency and count for each request type
|
||||||
|
avgLatencyByType map[string]time.Duration
|
||||||
|
requestCountByType map[string]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMetricsCollector creates a new metrics collector
|
||||||
|
func NewMetricsCollector() MetricsCollector {
|
||||||
|
return &BasicMetricsCollector{
|
||||||
|
avgLatencyByType: make(map[string]time.Duration),
|
||||||
|
requestCountByType: make(map[string]uint64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordRequest records metrics for a request
|
||||||
|
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
|
||||||
|
atomic.AddUint64(&c.totalRequests, 1)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
atomic.AddUint64(&c.successfulRequests, 1)
|
||||||
|
} else {
|
||||||
|
atomic.AddUint64(&c.failedRequests, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update average latency for request type
|
||||||
|
latency := time.Since(startTime)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
currentAvg, exists := c.avgLatencyByType[requestType]
|
||||||
|
currentCount, _ := c.requestCountByType[requestType]
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
// Update running average - the common case for better branch prediction
|
||||||
|
// new_avg = (old_avg * count + new_value) / (count + 1)
|
||||||
|
totalDuration := currentAvg * time.Duration(currentCount) + latency
|
||||||
|
newCount := currentCount + 1
|
||||||
|
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
|
||||||
|
c.requestCountByType[requestType] = newCount
|
||||||
|
} else {
|
||||||
|
// First request of this type
|
||||||
|
c.avgLatencyByType[requestType] = latency
|
||||||
|
c.requestCountByType[requestType] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordSend records metrics for bytes sent
|
||||||
|
func (c *BasicMetricsCollector) RecordSend(bytes int) {
|
||||||
|
atomic.AddUint64(&c.bytesSent, uint64(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordReceive records metrics for bytes received
|
||||||
|
func (c *BasicMetricsCollector) RecordReceive(bytes int) {
|
||||||
|
atomic.AddUint64(&c.bytesReceived, uint64(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordConnection records a connection event
|
||||||
|
func (c *BasicMetricsCollector) RecordConnection(successful bool) {
|
||||||
|
if successful {
|
||||||
|
atomic.AddUint64(&c.connections, 1)
|
||||||
|
} else {
|
||||||
|
atomic.AddUint64(&c.connectionFailures, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns the current metrics
|
||||||
|
func (c *BasicMetricsCollector) GetMetrics() Metrics {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
// Create a copy of the average latency map
|
||||||
|
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
|
||||||
|
for k, v := range c.avgLatencyByType {
|
||||||
|
avgLatencyByType[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return Metrics{
|
||||||
|
TotalRequests: atomic.LoadUint64(&c.totalRequests),
|
||||||
|
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
|
||||||
|
FailedRequests: atomic.LoadUint64(&c.failedRequests),
|
||||||
|
BytesSent: atomic.LoadUint64(&c.bytesSent),
|
||||||
|
BytesReceived: atomic.LoadUint64(&c.bytesReceived),
|
||||||
|
Connections: atomic.LoadUint64(&c.connections),
|
||||||
|
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
|
||||||
|
AvgLatencyByType: avgLatencyByType,
|
||||||
|
}
|
||||||
|
}
|
101
pkg/transport/metrics_test.go
Normal file
101
pkg/transport/metrics_test.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasicMetricsCollector(t *testing.T) {
|
||||||
|
collector := NewMetricsCollector()
|
||||||
|
|
||||||
|
// Test initial state
|
||||||
|
metrics := collector.GetMetrics()
|
||||||
|
if metrics.TotalRequests != 0 ||
|
||||||
|
metrics.SuccessfulRequests != 0 ||
|
||||||
|
metrics.FailedRequests != 0 ||
|
||||||
|
metrics.BytesSent != 0 ||
|
||||||
|
metrics.BytesReceived != 0 ||
|
||||||
|
metrics.Connections != 0 ||
|
||||||
|
metrics.ConnectionFailures != 0 ||
|
||||||
|
len(metrics.AvgLatencyByType) != 0 {
|
||||||
|
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test recording successful request
|
||||||
|
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
|
||||||
|
collector.RecordRequest("get", startTime, nil)
|
||||||
|
|
||||||
|
metrics = collector.GetMetrics()
|
||||||
|
if metrics.TotalRequests != 1 {
|
||||||
|
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
|
||||||
|
}
|
||||||
|
if metrics.SuccessfulRequests != 1 {
|
||||||
|
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
|
||||||
|
}
|
||||||
|
if metrics.FailedRequests != 0 {
|
||||||
|
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check average latency
|
||||||
|
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
|
||||||
|
t.Error("Expected 'get' latency to exist")
|
||||||
|
} else if avgLatency < 100*time.Millisecond {
|
||||||
|
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test recording failed request
|
||||||
|
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
|
||||||
|
collector.RecordRequest("get", startTime, errors.New("test error"))
|
||||||
|
|
||||||
|
metrics = collector.GetMetrics()
|
||||||
|
if metrics.TotalRequests != 2 {
|
||||||
|
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
|
||||||
|
}
|
||||||
|
if metrics.SuccessfulRequests != 1 {
|
||||||
|
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
|
||||||
|
}
|
||||||
|
if metrics.FailedRequests != 1 {
|
||||||
|
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test average latency calculation for multiple requests
|
||||||
|
startTime = time.Now().Add(-300 * time.Millisecond)
|
||||||
|
collector.RecordRequest("put", startTime, nil)
|
||||||
|
|
||||||
|
startTime = time.Now().Add(-500 * time.Millisecond)
|
||||||
|
collector.RecordRequest("put", startTime, nil)
|
||||||
|
|
||||||
|
metrics = collector.GetMetrics()
|
||||||
|
avgPutLatency := metrics.AvgLatencyByType["put"]
|
||||||
|
|
||||||
|
// Expected avg is around (300ms + 500ms) / 2 = 400ms
|
||||||
|
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
|
||||||
|
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test byte tracking
|
||||||
|
collector.RecordSend(1000)
|
||||||
|
collector.RecordReceive(2000)
|
||||||
|
|
||||||
|
metrics = collector.GetMetrics()
|
||||||
|
if metrics.BytesSent != 1000 {
|
||||||
|
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
|
||||||
|
}
|
||||||
|
if metrics.BytesReceived != 2000 {
|
||||||
|
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test connection tracking
|
||||||
|
collector.RecordConnection(true)
|
||||||
|
collector.RecordConnection(false)
|
||||||
|
collector.RecordConnection(true)
|
||||||
|
|
||||||
|
metrics = collector.GetMetrics()
|
||||||
|
if metrics.Connections != 2 {
|
||||||
|
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
|
||||||
|
}
|
||||||
|
if metrics.ConnectionFailures != 1 {
|
||||||
|
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
|
||||||
|
}
|
||||||
|
}
|
114
pkg/transport/registry.go
Normal file
114
pkg/transport/registry.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registry implements the Registry interface
|
||||||
|
type registry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
clientFactories map[string]ClientFactory
|
||||||
|
serverFactories map[string]ServerFactory
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates a new transport registry
|
||||||
|
func NewRegistry() Registry {
|
||||||
|
return ®istry{
|
||||||
|
clientFactories: make(map[string]ClientFactory),
|
||||||
|
serverFactories: make(map[string]ServerFactory),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRegistry is the default global registry instance
|
||||||
|
var DefaultRegistry = NewRegistry()
|
||||||
|
|
||||||
|
// RegisterClient adds a new client implementation to the registry
|
||||||
|
func (r *registry) RegisterClient(name string, factory ClientFactory) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.clientFactories[name] = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterServer adds a new server implementation to the registry
|
||||||
|
func (r *registry) RegisterServer(name string, factory ServerFactory) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.serverFactories[name] = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateClient instantiates a client by name
|
||||||
|
func (r *registry) CreateClient(name, endpoint string, options TransportOptions) (Client, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
factory, exists := r.clientFactories[name]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("transport client %q not registered", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return factory(endpoint, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateServer instantiates a server by name
|
||||||
|
func (r *registry) CreateServer(name, address string, options TransportOptions) (Server, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
factory, exists := r.serverFactories[name]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("transport server %q not registered", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return factory(address, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTransports returns all available transport names
|
||||||
|
func (r *registry) ListTransports() []string {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
// Get unique transport names
|
||||||
|
names := make(map[string]struct{})
|
||||||
|
for name := range r.clientFactories {
|
||||||
|
names[name] = struct{}{}
|
||||||
|
}
|
||||||
|
for name := range r.serverFactories {
|
||||||
|
names[name] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to slice
|
||||||
|
result := make([]string, 0, len(names))
|
||||||
|
for name := range names {
|
||||||
|
result = append(result, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for global registry
|
||||||
|
|
||||||
|
// RegisterClientTransport registers a client transport with the default registry
|
||||||
|
func RegisterClientTransport(name string, factory ClientFactory) {
|
||||||
|
DefaultRegistry.RegisterClient(name, factory)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterServerTransport registers a server transport with the default registry
|
||||||
|
func RegisterServerTransport(name string, factory ServerFactory) {
|
||||||
|
DefaultRegistry.RegisterServer(name, factory)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient creates a client using the default registry
|
||||||
|
func GetClient(name, endpoint string, options TransportOptions) (Client, error) {
|
||||||
|
return DefaultRegistry.CreateClient(name, endpoint, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServer creates a server using the default registry
|
||||||
|
func GetServer(name, address string, options TransportOptions) (Server, error) {
|
||||||
|
return DefaultRegistry.CreateServer(name, address, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableTransports lists all available transports in the default registry
|
||||||
|
func AvailableTransports() []string {
|
||||||
|
return DefaultRegistry.ListTransports()
|
||||||
|
}
|
162
pkg/transport/registry_test.go
Normal file
162
pkg/transport/registry_test.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockClient implements the Client interface for testing
|
||||||
|
type mockClient struct {
|
||||||
|
connected bool
|
||||||
|
endpoint string
|
||||||
|
options TransportOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Connect(ctx context.Context) error {
|
||||||
|
m.connected = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Close() error {
|
||||||
|
m.connected = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) IsConnected() bool {
|
||||||
|
return m.connected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Status() TransportStatus {
|
||||||
|
return TransportStatus{
|
||||||
|
Connected: m.connected,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Send(ctx context.Context, request Request) (Response, error) {
|
||||||
|
if !m.connected {
|
||||||
|
return nil, ErrNotConnected
|
||||||
|
}
|
||||||
|
return &BasicResponse{
|
||||||
|
ResponseType: request.Type() + "_response",
|
||||||
|
ResponseData: []byte("mock response"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) Stream(ctx context.Context) (Stream, error) {
|
||||||
|
if !m.connected {
|
||||||
|
return nil, ErrNotConnected
|
||||||
|
}
|
||||||
|
return nil, errors.New("streaming not implemented in mock")
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockClientFactory creates a new mock client
|
||||||
|
func mockClientFactory(endpoint string, options TransportOptions) (Client, error) {
|
||||||
|
return &mockClient{
|
||||||
|
endpoint: endpoint,
|
||||||
|
options: options,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockServer implements the Server interface for testing
|
||||||
|
type mockServer struct {
|
||||||
|
started bool
|
||||||
|
address string
|
||||||
|
options TransportOptions
|
||||||
|
handler RequestHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Start() error {
|
||||||
|
m.started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Serve() error {
|
||||||
|
m.started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) Stop(ctx context.Context) error {
|
||||||
|
m.started = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockServer) SetRequestHandler(handler RequestHandler) {
|
||||||
|
m.handler = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockServerFactory creates a new mock server
|
||||||
|
func mockServerFactory(address string, options TransportOptions) (Server, error) {
|
||||||
|
return &mockServer{
|
||||||
|
address: address,
|
||||||
|
options: options,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegistry tests the transport registry
|
||||||
|
func TestRegistry(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
// Register transports
|
||||||
|
registry.RegisterClient("mock", mockClientFactory)
|
||||||
|
registry.RegisterServer("mock", mockServerFactory)
|
||||||
|
|
||||||
|
// Test listing transports
|
||||||
|
transports := registry.ListTransports()
|
||||||
|
if len(transports) != 1 || transports[0] != "mock" {
|
||||||
|
t.Errorf("Expected [mock], got %v", transports)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test creating client
|
||||||
|
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test client methods
|
||||||
|
if client.IsConnected() {
|
||||||
|
t.Error("Expected client to be disconnected initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = client.Connect(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !client.IsConnected() {
|
||||||
|
t.Error("Expected client to be connected after Connect()")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test server creation
|
||||||
|
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test server methods
|
||||||
|
err = server.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to start server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockServer := server.(*mockServer)
|
||||||
|
if !mockServer.started {
|
||||||
|
t.Error("Expected server to be started")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test non-existent transport
|
||||||
|
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error creating non-existent client")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error creating non-existent server")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user