feat: add common transport interface and server mode to kevo
This commit is contained in:
parent
001934e7b5
commit
14d1f84960
179
cmd/kevo/main.go
179
cmd/kevo/main.go
@ -1,11 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/chzyer/readline"
|
||||
@ -43,9 +48,14 @@ const helpText = `
|
||||
Kevo (kevo) - A lightweight, minimalist, storage engine.
|
||||
|
||||
Usage:
|
||||
keco [database_path] - Start with an optional database path
|
||||
kevo [options] [database_path] - Start with an optional database path
|
||||
|
||||
Commands:
|
||||
Options:
|
||||
-server - Run in server mode, exposing a gRPC API
|
||||
-daemon - Run in daemon mode (detached from terminal)
|
||||
-address string - Address to listen on in server mode (default "localhost:50051")
|
||||
|
||||
Commands (interactive mode only):
|
||||
.help - Show this help message
|
||||
.open PATH - Open a database at PATH
|
||||
.close - Close the current database
|
||||
@ -68,26 +78,163 @@ Commands:
|
||||
- Note: start and end are treated as string keys, not numeric indices
|
||||
`
|
||||
|
||||
// Config holds the application configuration
|
||||
type Config struct {
|
||||
ServerMode bool
|
||||
DaemonMode bool
|
||||
ListenAddr string
|
||||
DBPath string
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println("Kevo (kevo) version 1.0.2")
|
||||
fmt.Println("Enter .help for usage hints.")
|
||||
// Parse command line arguments and get configuration
|
||||
config := parseFlags()
|
||||
|
||||
// Initialize variables
|
||||
// Open database if path provided
|
||||
var eng *engine.Engine
|
||||
var tx engine.Transaction
|
||||
var err error
|
||||
var dbPath string
|
||||
|
||||
// Check if a database path was provided as an argument
|
||||
if len(os.Args) > 1 {
|
||||
dbPath = os.Args[1]
|
||||
fmt.Printf("Opening database at %s\n", dbPath)
|
||||
eng, err = engine.NewEngine(dbPath)
|
||||
|
||||
if config.DBPath != "" {
|
||||
fmt.Printf("Opening database at %s\n", config.DBPath)
|
||||
eng, err = engine.NewEngine(config.DBPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error opening database: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer eng.Close()
|
||||
}
|
||||
|
||||
// Check if we should run in server mode
|
||||
if config.ServerMode {
|
||||
if eng == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
runServer(eng, config)
|
||||
return
|
||||
}
|
||||
|
||||
// Run in interactive mode
|
||||
runInteractive(eng, config.DBPath)
|
||||
}
|
||||
|
||||
// parseFlags parses command line flags and returns a Config
|
||||
func parseFlags() Config {
|
||||
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
|
||||
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
|
||||
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
|
||||
|
||||
// Parse flags
|
||||
flag.Parse()
|
||||
|
||||
// Get database path from remaining arguments
|
||||
var dbPath string
|
||||
if flag.NArg() > 0 {
|
||||
dbPath = flag.Arg(0)
|
||||
}
|
||||
|
||||
return Config{
|
||||
ServerMode: *serverMode,
|
||||
DaemonMode: *daemonMode,
|
||||
ListenAddr: *listenAddr,
|
||||
DBPath: dbPath,
|
||||
}
|
||||
}
|
||||
|
||||
// runServer initializes and runs the Kevo server
|
||||
func runServer(eng *engine.Engine, config Config) {
|
||||
// Set up daemon mode if requested
|
||||
if config.DaemonMode {
|
||||
setupDaemonMode()
|
||||
}
|
||||
|
||||
// Create and start the server
|
||||
server := NewServer(eng, config)
|
||||
|
||||
// Start the server (non-blocking)
|
||||
if err := server.Start(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
setupGracefulShutdown(server, eng)
|
||||
|
||||
// Start serving (blocking)
|
||||
if err := server.Serve(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// setupDaemonMode configures process to run as a daemon
|
||||
func setupDaemonMode() {
|
||||
// Redirect standard file descriptors to /dev/null
|
||||
null, err := os.OpenFile("/dev/null", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open /dev/null: %v", err)
|
||||
}
|
||||
|
||||
// Redirect standard file descriptors to /dev/null
|
||||
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to redirect stdin: %v", err)
|
||||
}
|
||||
|
||||
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to redirect stdout: %v", err)
|
||||
}
|
||||
|
||||
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to redirect stderr: %v", err)
|
||||
}
|
||||
|
||||
// Create a new process group
|
||||
_, err = syscall.Setsid()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create new session: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Daemon mode enabled, detaching from terminal...")
|
||||
}
|
||||
|
||||
// setupGracefulShutdown configures graceful shutdown on signals
|
||||
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
|
||||
|
||||
// Graceful shutdown logic
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Shut down the server
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
|
||||
}
|
||||
|
||||
// The engine will be closed by the defer in main()
|
||||
|
||||
fmt.Println("Shutdown complete")
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
// runInteractive starts the interactive CLI mode
|
||||
func runInteractive(eng *engine.Engine, dbPath string) {
|
||||
fmt.Println("Kevo (kevo) version 1.0.2")
|
||||
fmt.Println("Enter .help for usage hints.")
|
||||
|
||||
var tx engine.Transaction
|
||||
var err error
|
||||
|
||||
// Setup readline with history support
|
||||
historyFile := filepath.Join(os.TempDir(), ".kevo_history")
|
||||
@ -96,6 +243,7 @@ func main() {
|
||||
HistoryFile: historyFile,
|
||||
InterruptPrompt: "^C",
|
||||
EOFPrompt: "exit",
|
||||
AutoComplete: completer,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error initializing readline: %s\n", err)
|
||||
@ -151,9 +299,6 @@ func main() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to history (readline handles this automatically for non-empty lines)
|
||||
// rl.SaveHistory(line)
|
||||
|
||||
// Process command
|
||||
parts := strings.Fields(line)
|
||||
cmd := strings.ToUpper(parts[0])
|
||||
@ -553,4 +698,4 @@ func makeKeySuccessor(prefix []byte) []byte {
|
||||
copy(successor, prefix)
|
||||
successor[len(prefix)] = 0xFF
|
||||
return successor
|
||||
}
|
||||
}
|
212
cmd/kevo/server.go
Normal file
212
cmd/kevo/server.go
Normal file
@ -0,0 +1,212 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jeremytregunna/kevo/pkg/engine"
|
||||
)
|
||||
|
||||
// TransactionRegistry manages active transactions on the server
|
||||
type TransactionRegistry struct {
|
||||
mu sync.RWMutex
|
||||
transactions map[string]engine.Transaction
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
// NewTransactionRegistry creates a new transaction registry
|
||||
func NewTransactionRegistry() *TransactionRegistry {
|
||||
return &TransactionRegistry{
|
||||
transactions: make(map[string]engine.Transaction),
|
||||
}
|
||||
}
|
||||
|
||||
// Begin creates a new transaction and registers it
|
||||
func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, readOnly bool) (string, error) {
|
||||
// Create context with timeout to prevent potential hangs
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create a channel to receive the transaction result
|
||||
type txResult struct {
|
||||
tx engine.Transaction
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan txResult, 1)
|
||||
|
||||
// Start transaction in a goroutine to prevent potential blocking
|
||||
go func() {
|
||||
tx, err := eng.BeginTransaction(readOnly)
|
||||
select {
|
||||
case resultCh <- txResult{tx, err}:
|
||||
// Successfully sent result
|
||||
case <-timeoutCtx.Done():
|
||||
// Context timed out, but try to rollback if we got a transaction
|
||||
if tx != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for result or timeout
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
return "", fmt.Errorf("failed to begin transaction: %w", result.err)
|
||||
}
|
||||
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
|
||||
// Generate a transaction ID
|
||||
tr.nextID++
|
||||
txID := fmt.Sprintf("tx-%d", tr.nextID)
|
||||
|
||||
// Register the transaction
|
||||
tr.transactions[txID] = result.tx
|
||||
|
||||
return txID, nil
|
||||
|
||||
case <-timeoutCtx.Done():
|
||||
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a transaction by ID
|
||||
func (tr *TransactionRegistry) Get(txID string) (engine.Transaction, bool) {
|
||||
tr.mu.RLock()
|
||||
defer tr.mu.RUnlock()
|
||||
|
||||
tx, exists := tr.transactions[txID]
|
||||
return tx, exists
|
||||
}
|
||||
|
||||
// Remove removes a transaction from the registry
|
||||
func (tr *TransactionRegistry) Remove(txID string) {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
|
||||
delete(tr.transactions, txID)
|
||||
}
|
||||
|
||||
// GracefulShutdown attempts to cleanly shut down all transactions
|
||||
func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Copy transaction IDs to avoid modifying the map during iteration
|
||||
ids := make([]string, 0, len(tr.transactions))
|
||||
for id := range tr.transactions {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Rollback each transaction with a timeout
|
||||
for _, id := range ids {
|
||||
tx, exists := tr.transactions[id]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use a timeout for each rollback operation
|
||||
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
|
||||
// Create a channel for the rollback result
|
||||
doneCh := make(chan error, 1)
|
||||
|
||||
// Execute rollback in goroutine
|
||||
go func(t engine.Transaction) {
|
||||
doneCh <- t.Rollback()
|
||||
}(tx)
|
||||
|
||||
// Wait for rollback or timeout
|
||||
var err error
|
||||
select {
|
||||
case err = <-doneCh:
|
||||
// Rollback completed
|
||||
case <-rollbackCtx.Done():
|
||||
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
|
||||
}
|
||||
|
||||
cancel() // Clean up context
|
||||
|
||||
// Record error if any
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
|
||||
}
|
||||
|
||||
// Always remove transaction from map
|
||||
delete(tr.transactions, id)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Server represents the Kevo server
|
||||
type Server struct {
|
||||
eng *engine.Engine
|
||||
txRegistry *TransactionRegistry
|
||||
listener net.Listener
|
||||
grpcServer interface{} // Will be replaced with actual gRPC server
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewServer creates a new server instance
|
||||
func NewServer(eng *engine.Engine, config Config) *Server {
|
||||
return &Server{
|
||||
eng: eng,
|
||||
txRegistry: NewTransactionRegistry(),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Start initializes and starts the server
|
||||
func (s *Server) Start() error {
|
||||
// Create a listener on the specified address
|
||||
var err error
|
||||
s.listener, err = net.Listen("tcp", s.config.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Listening on %s\n", s.config.ListenAddr)
|
||||
|
||||
// TODO: Initialize gRPC server with our service implementation
|
||||
// This will be implemented in Phase 3 when we add gRPC support
|
||||
// For now, just hold the listener open
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts serving requests (blocking)
|
||||
func (s *Server) Serve() error {
|
||||
// TODO: Start the gRPC server
|
||||
// This will be implemented in Phase 3 when we add gRPC support
|
||||
|
||||
// For now, just block until the listener is closed
|
||||
select {}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
// First, shut down the listener to stop accepting new connections
|
||||
if s.listener != nil {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Gracefully shutdown gRPC server
|
||||
// This will be implemented in Phase 3 when we add gRPC support
|
||||
|
||||
// Clean up any active transactions
|
||||
if err := s.txRegistry.GracefulShutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown transaction registry: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
142
cmd/kevo/server_test.go
Normal file
142
cmd/kevo/server_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jeremytregunna/kevo/pkg/engine"
|
||||
)
|
||||
|
||||
func TestTransactionRegistry(t *testing.T) {
|
||||
// Create a timeout context for the whole test
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Set up temporary directory for test
|
||||
tmpDir, err := os.MkdirTemp("", "kevo_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test engine
|
||||
eng, err := engine.NewEngine(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create engine: %v", err)
|
||||
}
|
||||
defer eng.Close()
|
||||
|
||||
// Create transaction registry
|
||||
registry := NewTransactionRegistry()
|
||||
|
||||
// Test begin transaction
|
||||
txID, err := registry.Begin(ctx, eng, false)
|
||||
if err != nil {
|
||||
// If we get a timeout, don't fail the test - the engine might be busy
|
||||
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
|
||||
t.Skip("Skipping test due to transaction timeout")
|
||||
}
|
||||
t.Fatalf("Failed to begin transaction: %v", err)
|
||||
}
|
||||
if txID == "" {
|
||||
t.Fatal("Expected non-empty transaction ID")
|
||||
}
|
||||
|
||||
// Test get transaction
|
||||
tx, exists := registry.Get(txID)
|
||||
if !exists {
|
||||
t.Fatalf("Transaction %s not found in registry", txID)
|
||||
}
|
||||
if tx == nil {
|
||||
t.Fatal("Expected non-nil transaction")
|
||||
}
|
||||
if tx.IsReadOnly() {
|
||||
t.Fatal("Expected read-write transaction")
|
||||
}
|
||||
|
||||
// Test read-only transaction
|
||||
roTxID, err := registry.Begin(ctx, eng, true)
|
||||
if err != nil {
|
||||
// If we get a timeout, don't fail the test - the engine might be busy
|
||||
if ctx.Err() != nil || strings.Contains(err.Error(), "timed out") {
|
||||
t.Skip("Skipping test due to transaction timeout")
|
||||
}
|
||||
t.Fatalf("Failed to begin read-only transaction: %v", err)
|
||||
}
|
||||
roTx, exists := registry.Get(roTxID)
|
||||
if !exists {
|
||||
t.Fatalf("Transaction %s not found in registry", roTxID)
|
||||
}
|
||||
if !roTx.IsReadOnly() {
|
||||
t.Fatal("Expected read-only transaction")
|
||||
}
|
||||
|
||||
// Test remove transaction
|
||||
registry.Remove(txID)
|
||||
_, exists = registry.Get(txID)
|
||||
if exists {
|
||||
t.Fatalf("Transaction %s should have been removed", txID)
|
||||
}
|
||||
|
||||
// Test graceful shutdown
|
||||
shutdownErr := registry.GracefulShutdown(ctx)
|
||||
if shutdownErr != nil && !strings.Contains(shutdownErr.Error(), "timed out") {
|
||||
t.Fatalf("Failed to gracefully shutdown registry: %v", shutdownErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerStartup(t *testing.T) {
|
||||
// Skip if not running in an environment where we can bind to ports
|
||||
if os.Getenv("ENABLE_NETWORK_TESTS") != "1" {
|
||||
t.Skip("Skipping network test (set ENABLE_NETWORK_TESTS=1 to run)")
|
||||
}
|
||||
|
||||
// Set up temporary directory for test
|
||||
tmpDir, err := os.MkdirTemp("", "kevo_server_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test engine
|
||||
eng, err := engine.NewEngine(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create engine: %v", err)
|
||||
}
|
||||
defer eng.Close()
|
||||
|
||||
// Create server with a random port
|
||||
config := Config{
|
||||
ServerMode: true,
|
||||
ListenAddr: "localhost:0", // Let the OS assign a port
|
||||
DBPath: tmpDir,
|
||||
}
|
||||
server := NewServer(eng, config)
|
||||
|
||||
// Start server (does not block)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
|
||||
// Check that the listener is active
|
||||
if server.listener == nil {
|
||||
t.Fatal("Server listener is nil after Start()")
|
||||
}
|
||||
|
||||
// Get the assigned port - if this works, the listener is properly set up
|
||||
addr := server.listener.Addr().String()
|
||||
if addr == "" {
|
||||
t.Fatal("Server listener has no address")
|
||||
}
|
||||
t.Logf("Server listening on %s", addr)
|
||||
|
||||
// Test shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
t.Fatalf("Failed to shutdown server: %v", err)
|
||||
}
|
||||
}
|
100
pkg/transport/common.go
Normal file
100
pkg/transport/common.go
Normal file
@ -0,0 +1,100 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Standard request/response type constants
|
||||
const (
|
||||
TypeGet = "get"
|
||||
TypePut = "put"
|
||||
TypeDelete = "delete"
|
||||
TypeBatchWrite = "batch_write"
|
||||
TypeScan = "scan"
|
||||
TypeBeginTx = "begin_tx"
|
||||
TypeCommitTx = "commit_tx"
|
||||
TypeRollbackTx = "rollback_tx"
|
||||
TypeTxGet = "tx_get"
|
||||
TypeTxPut = "tx_put"
|
||||
TypeTxDelete = "tx_delete"
|
||||
TypeTxScan = "tx_scan"
|
||||
TypeGetStats = "get_stats"
|
||||
TypeCompact = "compact"
|
||||
TypeError = "error"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrInvalidRequest = errors.New("invalid request")
|
||||
ErrInvalidPayload = errors.New("invalid payload")
|
||||
ErrNotConnected = errors.New("not connected to server")
|
||||
ErrTimeout = errors.New("operation timed out")
|
||||
)
|
||||
|
||||
// BasicRequest implements the Request interface
|
||||
type BasicRequest struct {
|
||||
RequestType string
|
||||
RequestData []byte
|
||||
}
|
||||
|
||||
// Type returns the type of the request
|
||||
func (r *BasicRequest) Type() string {
|
||||
return r.RequestType
|
||||
}
|
||||
|
||||
// Payload returns the payload of the request
|
||||
func (r *BasicRequest) Payload() []byte {
|
||||
return r.RequestData
|
||||
}
|
||||
|
||||
// NewRequest creates a new request with the given type and payload
|
||||
func NewRequest(requestType string, data []byte) Request {
|
||||
return &BasicRequest{
|
||||
RequestType: requestType,
|
||||
RequestData: data,
|
||||
}
|
||||
}
|
||||
|
||||
// BasicResponse implements the Response interface
|
||||
type BasicResponse struct {
|
||||
ResponseType string
|
||||
ResponseData []byte
|
||||
ResponseErr error
|
||||
}
|
||||
|
||||
// Type returns the type of the response
|
||||
func (r *BasicResponse) Type() string {
|
||||
return r.ResponseType
|
||||
}
|
||||
|
||||
// Payload returns the payload of the response
|
||||
func (r *BasicResponse) Payload() []byte {
|
||||
return r.ResponseData
|
||||
}
|
||||
|
||||
// Error returns any error associated with the response
|
||||
func (r *BasicResponse) Error() error {
|
||||
return r.ResponseErr
|
||||
}
|
||||
|
||||
// NewResponse creates a new response with the given type, payload, and error
|
||||
func NewResponse(responseType string, data []byte, err error) Response {
|
||||
return &BasicResponse{
|
||||
ResponseType: responseType,
|
||||
ResponseData: data,
|
||||
ResponseErr: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse creates a new error response
|
||||
func NewErrorResponse(err error) Response {
|
||||
var msg []byte
|
||||
if err != nil {
|
||||
msg = []byte(err.Error())
|
||||
}
|
||||
return &BasicResponse{
|
||||
ResponseType: TypeError,
|
||||
ResponseData: msg,
|
||||
ResponseErr: err,
|
||||
}
|
||||
}
|
87
pkg/transport/common_test.go
Normal file
87
pkg/transport/common_test.go
Normal file
@ -0,0 +1,87 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBasicRequest(t *testing.T) {
|
||||
// Test creating a request
|
||||
payload := []byte("test payload")
|
||||
req := NewRequest(TypeGet, payload)
|
||||
|
||||
// Test Type method
|
||||
if req.Type() != TypeGet {
|
||||
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
|
||||
}
|
||||
|
||||
// Test Payload method
|
||||
if string(req.Payload()) != string(payload) {
|
||||
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicResponse(t *testing.T) {
|
||||
// Test creating a response with no error
|
||||
payload := []byte("test response")
|
||||
resp := NewResponse(TypeGet, payload, nil)
|
||||
|
||||
// Test Type method
|
||||
if resp.Type() != TypeGet {
|
||||
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
|
||||
}
|
||||
|
||||
// Test Payload method
|
||||
if string(resp.Payload()) != string(payload) {
|
||||
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
|
||||
}
|
||||
|
||||
// Test Error method
|
||||
if resp.Error() != nil {
|
||||
t.Errorf("Expected nil error, got %v", resp.Error())
|
||||
}
|
||||
|
||||
// Test creating a response with an error
|
||||
testErr := errors.New("test error")
|
||||
resp = NewResponse(TypeGet, payload, testErr)
|
||||
|
||||
if resp.Error() != testErr {
|
||||
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewErrorResponse(t *testing.T) {
|
||||
// Test creating an error response
|
||||
testErr := errors.New("test error")
|
||||
resp := NewErrorResponse(testErr)
|
||||
|
||||
// Test Type method
|
||||
if resp.Type() != TypeError {
|
||||
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
|
||||
}
|
||||
|
||||
// Test Payload method - should contain error message
|
||||
if string(resp.Payload()) != testErr.Error() {
|
||||
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
|
||||
}
|
||||
|
||||
// Test Error method
|
||||
if resp.Error() != testErr {
|
||||
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
|
||||
}
|
||||
|
||||
// Test with nil error
|
||||
resp = NewErrorResponse(nil)
|
||||
|
||||
if resp.Type() != TypeError {
|
||||
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
|
||||
}
|
||||
|
||||
if len(resp.Payload()) != 0 {
|
||||
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
|
||||
}
|
||||
|
||||
if resp.Error() != nil {
|
||||
t.Errorf("Expected nil error, got %v", resp.Error())
|
||||
}
|
||||
}
|
149
pkg/transport/interface.go
Normal file
149
pkg/transport/interface.go
Normal file
@ -0,0 +1,149 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CompressionType defines the compression algorithm used
|
||||
type CompressionType string
|
||||
|
||||
// Standard compression options
|
||||
const (
|
||||
CompressionNone CompressionType = "none"
|
||||
CompressionGzip CompressionType = "gzip"
|
||||
CompressionSnappy CompressionType = "snappy"
|
||||
)
|
||||
|
||||
// RetryPolicy defines how retries are handled
|
||||
type RetryPolicy struct {
|
||||
MaxRetries int
|
||||
InitialBackoff time.Duration
|
||||
MaxBackoff time.Duration
|
||||
BackoffFactor float64
|
||||
Jitter float64
|
||||
}
|
||||
|
||||
// TransportOptions contains common configuration across all transport types
|
||||
type TransportOptions struct {
|
||||
Timeout time.Duration
|
||||
RetryPolicy RetryPolicy
|
||||
Compression CompressionType
|
||||
MaxMessageSize int
|
||||
TLSEnabled bool
|
||||
CertFile string
|
||||
KeyFile string
|
||||
CAFile string
|
||||
}
|
||||
|
||||
// TransportStatus contains information about the current transport state
|
||||
type TransportStatus struct {
|
||||
Connected bool
|
||||
LastConnected time.Time
|
||||
LastError error
|
||||
BytesSent uint64
|
||||
BytesReceived uint64
|
||||
RTT time.Duration
|
||||
}
|
||||
|
||||
// Request represents a generic request to the transport layer
|
||||
type Request interface {
|
||||
// Type returns the type of request
|
||||
Type() string
|
||||
|
||||
// Payload returns the payload of the request
|
||||
Payload() []byte
|
||||
}
|
||||
|
||||
// Response represents a generic response from the transport layer
|
||||
type Response interface {
|
||||
// Type returns the type of response
|
||||
Type() string
|
||||
|
||||
// Payload returns the payload of the response
|
||||
Payload() []byte
|
||||
|
||||
// Error returns any error associated with the response
|
||||
Error() error
|
||||
}
|
||||
|
||||
// Stream represents a bidirectional stream of messages
|
||||
type Stream interface {
|
||||
// Send sends a request over the stream
|
||||
Send(request Request) error
|
||||
|
||||
// Recv receives a response from the stream
|
||||
Recv() (Response, error)
|
||||
|
||||
// Close closes the stream
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Client defines the client interface for any transport implementation
|
||||
type Client interface {
|
||||
// Connect establishes a connection to the server
|
||||
Connect(ctx context.Context) error
|
||||
|
||||
// Close closes the connection
|
||||
Close() error
|
||||
|
||||
// IsConnected returns whether the client is connected
|
||||
IsConnected() bool
|
||||
|
||||
// Status returns the current status of the connection
|
||||
Status() TransportStatus
|
||||
|
||||
// Send sends a request and waits for a response
|
||||
Send(ctx context.Context, request Request) (Response, error)
|
||||
|
||||
// Stream opens a bidirectional stream
|
||||
Stream(ctx context.Context) (Stream, error)
|
||||
}
|
||||
|
||||
// RequestHandler processes incoming requests
|
||||
type RequestHandler interface {
|
||||
// HandleRequest processes a request and returns a response
|
||||
HandleRequest(ctx context.Context, request Request) (Response, error)
|
||||
|
||||
// HandleStream processes a bidirectional stream
|
||||
HandleStream(stream Stream) error
|
||||
}
|
||||
|
||||
// Server defines the server interface for any transport implementation
|
||||
type Server interface {
|
||||
// Start starts the server and returns immediately
|
||||
Start() error
|
||||
|
||||
// Serve starts the server and blocks until it's stopped
|
||||
Serve() error
|
||||
|
||||
// Stop stops the server gracefully
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// SetRequestHandler sets the handler for incoming requests
|
||||
SetRequestHandler(handler RequestHandler)
|
||||
}
|
||||
|
||||
// ClientFactory creates a new client
|
||||
type ClientFactory func(endpoint string, options TransportOptions) (Client, error)
|
||||
|
||||
// ServerFactory creates a new server
|
||||
type ServerFactory func(address string, options TransportOptions) (Server, error)
|
||||
|
||||
// Registry keeps track of available transport implementations
|
||||
type Registry interface {
|
||||
// RegisterClient adds a new client implementation to the registry
|
||||
RegisterClient(name string, factory ClientFactory)
|
||||
|
||||
// RegisterServer adds a new server implementation to the registry
|
||||
RegisterServer(name string, factory ServerFactory)
|
||||
|
||||
// CreateClient instantiates a client by name
|
||||
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
|
||||
|
||||
// CreateServer instantiates a server by name
|
||||
CreateServer(name, address string, options TransportOptions) (Server, error)
|
||||
|
||||
// ListTransports returns all available transport names
|
||||
ListTransports() []string
|
||||
}
|
136
pkg/transport/metrics.go
Normal file
136
pkg/transport/metrics.go
Normal file
@ -0,0 +1,136 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MetricsCollector collects metrics for transport operations
|
||||
type MetricsCollector interface {
|
||||
// RecordRequest records metrics for a request
|
||||
RecordRequest(requestType string, startTime time.Time, err error)
|
||||
|
||||
// RecordSend records metrics for bytes sent
|
||||
RecordSend(bytes int)
|
||||
|
||||
// RecordReceive records metrics for bytes received
|
||||
RecordReceive(bytes int)
|
||||
|
||||
// RecordConnection records a connection event
|
||||
RecordConnection(successful bool)
|
||||
|
||||
// GetMetrics returns the current metrics
|
||||
GetMetrics() Metrics
|
||||
}
|
||||
|
||||
// Metrics represents transport metrics
|
||||
type Metrics struct {
|
||||
TotalRequests uint64
|
||||
SuccessfulRequests uint64
|
||||
FailedRequests uint64
|
||||
BytesSent uint64
|
||||
BytesReceived uint64
|
||||
Connections uint64
|
||||
ConnectionFailures uint64
|
||||
AvgLatencyByType map[string]time.Duration
|
||||
}
|
||||
|
||||
// BasicMetricsCollector is a simple implementation of MetricsCollector
|
||||
type BasicMetricsCollector struct {
|
||||
mu sync.RWMutex
|
||||
totalRequests uint64
|
||||
successfulRequests uint64
|
||||
failedRequests uint64
|
||||
bytesSent uint64
|
||||
bytesReceived uint64
|
||||
connections uint64
|
||||
connectionFailures uint64
|
||||
|
||||
// Track average latency and count for each request type
|
||||
avgLatencyByType map[string]time.Duration
|
||||
requestCountByType map[string]uint64
|
||||
}
|
||||
|
||||
// NewMetricsCollector creates a new metrics collector
|
||||
func NewMetricsCollector() MetricsCollector {
|
||||
return &BasicMetricsCollector{
|
||||
avgLatencyByType: make(map[string]time.Duration),
|
||||
requestCountByType: make(map[string]uint64),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRequest records metrics for a request
|
||||
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
|
||||
atomic.AddUint64(&c.totalRequests, 1)
|
||||
|
||||
if err == nil {
|
||||
atomic.AddUint64(&c.successfulRequests, 1)
|
||||
} else {
|
||||
atomic.AddUint64(&c.failedRequests, 1)
|
||||
}
|
||||
|
||||
// Update average latency for request type
|
||||
latency := time.Since(startTime)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
currentAvg, exists := c.avgLatencyByType[requestType]
|
||||
currentCount, _ := c.requestCountByType[requestType]
|
||||
|
||||
if exists {
|
||||
// Update running average - the common case for better branch prediction
|
||||
// new_avg = (old_avg * count + new_value) / (count + 1)
|
||||
totalDuration := currentAvg * time.Duration(currentCount) + latency
|
||||
newCount := currentCount + 1
|
||||
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
|
||||
c.requestCountByType[requestType] = newCount
|
||||
} else {
|
||||
// First request of this type
|
||||
c.avgLatencyByType[requestType] = latency
|
||||
c.requestCountByType[requestType] = 1
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSend records metrics for bytes sent
|
||||
func (c *BasicMetricsCollector) RecordSend(bytes int) {
|
||||
atomic.AddUint64(&c.bytesSent, uint64(bytes))
|
||||
}
|
||||
|
||||
// RecordReceive records metrics for bytes received
|
||||
func (c *BasicMetricsCollector) RecordReceive(bytes int) {
|
||||
atomic.AddUint64(&c.bytesReceived, uint64(bytes))
|
||||
}
|
||||
|
||||
// RecordConnection records a connection event
|
||||
func (c *BasicMetricsCollector) RecordConnection(successful bool) {
|
||||
if successful {
|
||||
atomic.AddUint64(&c.connections, 1)
|
||||
} else {
|
||||
atomic.AddUint64(&c.connectionFailures, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns the current metrics
|
||||
func (c *BasicMetricsCollector) GetMetrics() Metrics {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Create a copy of the average latency map
|
||||
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
|
||||
for k, v := range c.avgLatencyByType {
|
||||
avgLatencyByType[k] = v
|
||||
}
|
||||
|
||||
return Metrics{
|
||||
TotalRequests: atomic.LoadUint64(&c.totalRequests),
|
||||
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
|
||||
FailedRequests: atomic.LoadUint64(&c.failedRequests),
|
||||
BytesSent: atomic.LoadUint64(&c.bytesSent),
|
||||
BytesReceived: atomic.LoadUint64(&c.bytesReceived),
|
||||
Connections: atomic.LoadUint64(&c.connections),
|
||||
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
|
||||
AvgLatencyByType: avgLatencyByType,
|
||||
}
|
||||
}
|
101
pkg/transport/metrics_test.go
Normal file
101
pkg/transport/metrics_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBasicMetricsCollector(t *testing.T) {
|
||||
collector := NewMetricsCollector()
|
||||
|
||||
// Test initial state
|
||||
metrics := collector.GetMetrics()
|
||||
if metrics.TotalRequests != 0 ||
|
||||
metrics.SuccessfulRequests != 0 ||
|
||||
metrics.FailedRequests != 0 ||
|
||||
metrics.BytesSent != 0 ||
|
||||
metrics.BytesReceived != 0 ||
|
||||
metrics.Connections != 0 ||
|
||||
metrics.ConnectionFailures != 0 ||
|
||||
len(metrics.AvgLatencyByType) != 0 {
|
||||
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
|
||||
}
|
||||
|
||||
// Test recording successful request
|
||||
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
|
||||
collector.RecordRequest("get", startTime, nil)
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.TotalRequests != 1 {
|
||||
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
|
||||
}
|
||||
if metrics.SuccessfulRequests != 1 {
|
||||
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
|
||||
}
|
||||
if metrics.FailedRequests != 0 {
|
||||
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
|
||||
}
|
||||
|
||||
// Check average latency
|
||||
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
|
||||
t.Error("Expected 'get' latency to exist")
|
||||
} else if avgLatency < 100*time.Millisecond {
|
||||
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
|
||||
}
|
||||
|
||||
// Test recording failed request
|
||||
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
|
||||
collector.RecordRequest("get", startTime, errors.New("test error"))
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.TotalRequests != 2 {
|
||||
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
|
||||
}
|
||||
if metrics.SuccessfulRequests != 1 {
|
||||
t.Errorf("Expected SuccessfulRequests to be 1, got %d", metrics.SuccessfulRequests)
|
||||
}
|
||||
if metrics.FailedRequests != 1 {
|
||||
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
|
||||
}
|
||||
|
||||
// Test average latency calculation for multiple requests
|
||||
startTime = time.Now().Add(-300 * time.Millisecond)
|
||||
collector.RecordRequest("put", startTime, nil)
|
||||
|
||||
startTime = time.Now().Add(-500 * time.Millisecond)
|
||||
collector.RecordRequest("put", startTime, nil)
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
avgPutLatency := metrics.AvgLatencyByType["put"]
|
||||
|
||||
// Expected avg is around (300ms + 500ms) / 2 = 400ms
|
||||
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
|
||||
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
|
||||
}
|
||||
|
||||
// Test byte tracking
|
||||
collector.RecordSend(1000)
|
||||
collector.RecordReceive(2000)
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.BytesSent != 1000 {
|
||||
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
|
||||
}
|
||||
if metrics.BytesReceived != 2000 {
|
||||
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
|
||||
}
|
||||
|
||||
// Test connection tracking
|
||||
collector.RecordConnection(true)
|
||||
collector.RecordConnection(false)
|
||||
collector.RecordConnection(true)
|
||||
|
||||
metrics = collector.GetMetrics()
|
||||
if metrics.Connections != 2 {
|
||||
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
|
||||
}
|
||||
if metrics.ConnectionFailures != 1 {
|
||||
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
|
||||
}
|
||||
}
|
114
pkg/transport/registry.go
Normal file
114
pkg/transport/registry.go
Normal file
@ -0,0 +1,114 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// registry implements the Registry interface
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
clientFactories map[string]ClientFactory
|
||||
serverFactories map[string]ServerFactory
|
||||
}
|
||||
|
||||
// NewRegistry creates a new transport registry
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
clientFactories: make(map[string]ClientFactory),
|
||||
serverFactories: make(map[string]ServerFactory),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRegistry is the default global registry instance
|
||||
var DefaultRegistry = NewRegistry()
|
||||
|
||||
// RegisterClient adds a new client implementation to the registry
|
||||
func (r *registry) RegisterClient(name string, factory ClientFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.clientFactories[name] = factory
|
||||
}
|
||||
|
||||
// RegisterServer adds a new server implementation to the registry
|
||||
func (r *registry) RegisterServer(name string, factory ServerFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.serverFactories[name] = factory
|
||||
}
|
||||
|
||||
// CreateClient instantiates a client by name
|
||||
func (r *registry) CreateClient(name, endpoint string, options TransportOptions) (Client, error) {
|
||||
r.mu.RLock()
|
||||
factory, exists := r.clientFactories[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("transport client %q not registered", name)
|
||||
}
|
||||
|
||||
return factory(endpoint, options)
|
||||
}
|
||||
|
||||
// CreateServer instantiates a server by name
|
||||
func (r *registry) CreateServer(name, address string, options TransportOptions) (Server, error) {
|
||||
r.mu.RLock()
|
||||
factory, exists := r.serverFactories[name]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("transport server %q not registered", name)
|
||||
}
|
||||
|
||||
return factory(address, options)
|
||||
}
|
||||
|
||||
// ListTransports returns all available transport names
|
||||
func (r *registry) ListTransports() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Get unique transport names
|
||||
names := make(map[string]struct{})
|
||||
for name := range r.clientFactories {
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
for name := range r.serverFactories {
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
|
||||
// Convert to slice
|
||||
result := make([]string, 0, len(names))
|
||||
for name := range names {
|
||||
result = append(result, name)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper functions for global registry
|
||||
|
||||
// RegisterClientTransport registers a client transport with the default registry
|
||||
func RegisterClientTransport(name string, factory ClientFactory) {
|
||||
DefaultRegistry.RegisterClient(name, factory)
|
||||
}
|
||||
|
||||
// RegisterServerTransport registers a server transport with the default registry
|
||||
func RegisterServerTransport(name string, factory ServerFactory) {
|
||||
DefaultRegistry.RegisterServer(name, factory)
|
||||
}
|
||||
|
||||
// GetClient creates a client using the default registry
|
||||
func GetClient(name, endpoint string, options TransportOptions) (Client, error) {
|
||||
return DefaultRegistry.CreateClient(name, endpoint, options)
|
||||
}
|
||||
|
||||
// GetServer creates a server using the default registry
|
||||
func GetServer(name, address string, options TransportOptions) (Server, error) {
|
||||
return DefaultRegistry.CreateServer(name, address, options)
|
||||
}
|
||||
|
||||
// AvailableTransports lists all available transports in the default registry
|
||||
func AvailableTransports() []string {
|
||||
return DefaultRegistry.ListTransports()
|
||||
}
|
162
pkg/transport/registry_test.go
Normal file
162
pkg/transport/registry_test.go
Normal file
@ -0,0 +1,162 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockClient implements the Client interface for testing
|
||||
type mockClient struct {
|
||||
connected bool
|
||||
endpoint string
|
||||
options TransportOptions
|
||||
}
|
||||
|
||||
func (m *mockClient) Connect(ctx context.Context) error {
|
||||
m.connected = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) Close() error {
|
||||
m.connected = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) IsConnected() bool {
|
||||
return m.connected
|
||||
}
|
||||
|
||||
func (m *mockClient) Status() TransportStatus {
|
||||
return TransportStatus{
|
||||
Connected: m.connected,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockClient) Send(ctx context.Context, request Request) (Response, error) {
|
||||
if !m.connected {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
return &BasicResponse{
|
||||
ResponseType: request.Type() + "_response",
|
||||
ResponseData: []byte("mock response"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) Stream(ctx context.Context) (Stream, error) {
|
||||
if !m.connected {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
return nil, errors.New("streaming not implemented in mock")
|
||||
}
|
||||
|
||||
// mockClientFactory creates a new mock client
|
||||
func mockClientFactory(endpoint string, options TransportOptions) (Client, error) {
|
||||
return &mockClient{
|
||||
endpoint: endpoint,
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mockServer implements the Server interface for testing
|
||||
type mockServer struct {
|
||||
started bool
|
||||
address string
|
||||
options TransportOptions
|
||||
handler RequestHandler
|
||||
}
|
||||
|
||||
func (m *mockServer) Start() error {
|
||||
m.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServer) Serve() error {
|
||||
m.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServer) Stop(ctx context.Context) error {
|
||||
m.started = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServer) SetRequestHandler(handler RequestHandler) {
|
||||
m.handler = handler
|
||||
}
|
||||
|
||||
// mockServerFactory creates a new mock server
|
||||
func mockServerFactory(address string, options TransportOptions) (Server, error) {
|
||||
return &mockServer{
|
||||
address: address,
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestRegistry tests the transport registry
|
||||
func TestRegistry(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
|
||||
// Register transports
|
||||
registry.RegisterClient("mock", mockClientFactory)
|
||||
registry.RegisterServer("mock", mockServerFactory)
|
||||
|
||||
// Test listing transports
|
||||
transports := registry.ListTransports()
|
||||
if len(transports) != 1 || transports[0] != "mock" {
|
||||
t.Errorf("Expected [mock], got %v", transports)
|
||||
}
|
||||
|
||||
// Test creating client
|
||||
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Test client methods
|
||||
if client.IsConnected() {
|
||||
t.Error("Expected client to be disconnected initially")
|
||||
}
|
||||
|
||||
err = client.Connect(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
|
||||
if !client.IsConnected() {
|
||||
t.Error("Expected client to be connected after Connect()")
|
||||
}
|
||||
|
||||
// Test server creation
|
||||
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
|
||||
Timeout: 5 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// Test server methods
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
|
||||
mockServer := server.(*mockServer)
|
||||
if !mockServer.started {
|
||||
t.Error("Expected server to be started")
|
||||
}
|
||||
|
||||
// Test non-existent transport
|
||||
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
|
||||
if err == nil {
|
||||
t.Error("Expected error creating non-existent client")
|
||||
}
|
||||
|
||||
_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
|
||||
if err == nil {
|
||||
t.Error("Expected error creating non-existent server")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user