chore: formatting
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 9m50s

This commit is contained in:
Jeremy Tregunna 2025-04-22 14:09:54 -06:00
parent e7974e008d
commit a0a1c0512f
Signed by: jer
GPG Key ID: 1278B36BA6F5D5E4
33 changed files with 440 additions and 440 deletions

View File

@ -81,14 +81,14 @@ Commands (interactive mode only):
// Config holds the application configuration
type Config struct {
ServerMode bool
DaemonMode bool
ListenAddr string
DBPath string
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
TLSCAFile string
ServerMode bool
DaemonMode bool
ListenAddr string
DBPath string
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
TLSCAFile string
}
func main() {
@ -98,7 +98,7 @@ func main() {
// Open database if path provided
var eng *engine.Engine
var err error
if config.DBPath != "" {
fmt.Printf("Opening database at %s\n", config.DBPath)
eng, err = engine.NewEngine(config.DBPath)
@ -108,18 +108,18 @@ func main() {
}
defer eng.Close()
}
// Check if we should run in server mode
if config.ServerMode {
if eng == nil {
fmt.Fprintf(os.Stderr, "Error: Server mode requires a database path\n")
os.Exit(1)
}
runServer(eng, config)
return
}
// Run in interactive mode
runInteractive(eng, config.DBPath)
}
@ -151,31 +151,31 @@ func parseFlags() Config {
serverMode := flag.Bool("server", false, "Run in server mode, exposing a gRPC API")
daemonMode := flag.Bool("daemon", false, "Run in daemon mode (detached from terminal)")
listenAddr := flag.String("address", "localhost:50051", "Address to listen on in server mode")
// TLS options
tlsEnabled := flag.Bool("tls", false, "Enable TLS for secure connections")
tlsCertFile := flag.String("cert", "", "TLS certificate file path")
tlsKeyFile := flag.String("key", "", "TLS private key file path")
tlsCAFile := flag.String("ca", "", "TLS CA certificate file for client verification")
// Parse flags
flag.Parse()
// Get database path from remaining arguments
var dbPath string
if flag.NArg() > 0 {
dbPath = flag.Arg(0)
}
return Config{
ServerMode: *serverMode,
DaemonMode: *daemonMode,
ListenAddr: *listenAddr,
DBPath: dbPath,
TLSEnabled: *tlsEnabled,
TLSCertFile: *tlsCertFile,
TLSKeyFile: *tlsKeyFile,
TLSCAFile: *tlsCAFile,
ServerMode: *serverMode,
DaemonMode: *daemonMode,
ListenAddr: *listenAddr,
DBPath: dbPath,
TLSEnabled: *tlsEnabled,
TLSCertFile: *tlsCertFile,
TLSKeyFile: *tlsKeyFile,
TLSCAFile: *tlsCAFile,
}
}
@ -185,10 +185,10 @@ func runServer(eng *engine.Engine, config Config) {
if config.DaemonMode {
setupDaemonMode()
}
// Create and start the server
server := NewServer(eng, config)
// Start the server (non-blocking)
if err := server.Start(); err != nil {
fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
@ -196,10 +196,10 @@ func runServer(eng *engine.Engine, config Config) {
}
fmt.Printf("Kevo server started on %s\n", config.ListenAddr)
// Set up signal handling for graceful shutdown
setupGracefulShutdown(server, eng)
// Start serving (blocking)
if err := server.Serve(); err != nil {
fmt.Fprintf(os.Stderr, "Error serving: %v\n", err)
@ -214,29 +214,29 @@ func setupDaemonMode() {
if err != nil {
log.Fatalf("Failed to open /dev/null: %v", err)
}
// Redirect standard file descriptors to /dev/null
err = syscall.Dup2(int(null.Fd()), int(os.Stdin.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdin: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stdout.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stdout: %v", err)
}
err = syscall.Dup2(int(null.Fd()), int(os.Stderr.Fd()))
if err != nil {
log.Fatalf("Failed to redirect stderr: %v", err)
}
// Create a new process group
_, err = syscall.Setsid()
if err != nil {
log.Fatalf("Failed to create new session: %v", err)
}
fmt.Println("Daemon mode enabled, detaching from terminal...")
}
@ -244,22 +244,22 @@ func setupDaemonMode() {
func setupGracefulShutdown(server *Server, eng *engine.Engine) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
fmt.Printf("\nReceived signal %v, shutting down...\n", sig)
// Graceful shutdown logic
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Shut down the server
if err := server.Shutdown(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error shutting down server: %v\n", err)
}
// The engine will be closed by the defer in main()
fmt.Println("Shutdown complete")
os.Exit(0)
}()
@ -269,7 +269,7 @@ func setupGracefulShutdown(server *Server, eng *engine.Engine) {
func runInteractive(eng *engine.Engine, dbPath string) {
fmt.Println("Kevo (kevo) version 1.0.2")
fmt.Println("Enter .help for usage hints.")
var tx engine.Transaction
var err error
@ -412,7 +412,7 @@ func runInteractive(eng *engine.Engine, dbPath string) {
// Print statistics
stats := eng.GetStats()
// Format human-readable time for the last operation timestamps
var lastPutTime, lastGetTime, lastDeleteTime time.Time
if putTime, ok := stats["last_put_time"].(int64); ok && putTime > 0 {
@ -424,13 +424,13 @@ func runInteractive(eng *engine.Engine, dbPath string) {
if deleteTime, ok := stats["last_delete_time"].(int64); ok && deleteTime > 0 {
lastDeleteTime = time.Unix(0, deleteTime)
}
// Operations section
fmt.Println("📊 Operations:")
fmt.Printf(" • Puts: %d\n", stats["put_ops"])
fmt.Printf(" • Gets: %d (Hits: %d, Misses: %d)\n", stats["get_ops"], stats["get_hits"], stats["get_misses"])
fmt.Printf(" • Deletes: %d\n", stats["delete_ops"])
// Last Operation Times
fmt.Println("\n⏱ Last Operation Times:")
if !lastPutTime.IsZero() {
@ -448,25 +448,25 @@ func runInteractive(eng *engine.Engine, dbPath string) {
} else {
fmt.Printf(" • Last Delete: Never\n")
}
// Transactions
fmt.Println("\n💼 Transactions:")
fmt.Printf(" • Started: %d\n", stats["tx_started"])
fmt.Printf(" • Completed: %d\n", stats["tx_completed"])
fmt.Printf(" • Aborted: %d\n", stats["tx_aborted"])
// Storage metrics
fmt.Println("\n💾 Storage:")
fmt.Printf(" • Total Bytes Read: %d\n", stats["total_bytes_read"])
fmt.Printf(" • Total Bytes Written: %d\n", stats["total_bytes_written"])
fmt.Printf(" • Flush Count: %d\n", stats["flush_count"])
// Table stats
fmt.Println("\n📋 Tables:")
fmt.Printf(" • SSTable Count: %d\n", stats["sstable_count"])
fmt.Printf(" • Immutable MemTable Count: %d\n", stats["immutable_memtable_count"])
fmt.Printf(" • Current MemTable Size: %d bytes\n", stats["memtable_size"])
// WAL recovery stats
fmt.Println("\n🔄 WAL Recovery:")
fmt.Printf(" • Files Recovered: %d\n", stats["wal_files_recovered"])
@ -475,17 +475,17 @@ func runInteractive(eng *engine.Engine, dbPath string) {
if recoveryDuration, ok := stats["wal_recovery_duration_ms"]; ok {
fmt.Printf(" • Recovery Duration: %d ms\n", recoveryDuration)
}
// Error counts
fmt.Println("\n⚠ Errors:")
fmt.Printf(" • Read Errors: %d\n", stats["read_errors"])
fmt.Printf(" • Write Errors: %d\n", stats["write_errors"])
// Compaction stats (if available)
if compactionOutputCount, ok := stats["compaction_last_outputs_count"]; ok {
fmt.Println("\n🧹 Compaction:")
fmt.Printf(" • Last Output Files Count: %d\n", compactionOutputCount)
// Display other compaction stats as available
for key, value := range stats {
if strings.HasPrefix(key, "compaction_") && key != "compaction_last_outputs_count" && key != "compaction_last_outputs" {
@ -825,4 +825,4 @@ func toTitle(s string) string {
return r
},
s)
}
}

View File

@ -35,14 +35,14 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
// Create context with timeout to prevent potential hangs
timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Create a channel to receive the transaction result
type txResult struct {
tx engine.Transaction
err error
}
resultCh := make(chan txResult, 1)
// Start transaction in a goroutine to prevent potential blocking
go func() {
tx, err := eng.BeginTransaction(readOnly)
@ -56,26 +56,26 @@ func (tr *TransactionRegistry) Begin(ctx context.Context, eng *engine.Engine, re
}
}
}()
// Wait for result or timeout
select {
case result := <-resultCh:
if result.err != nil {
return "", fmt.Errorf("failed to begin transaction: %w", result.err)
}
tr.mu.Lock()
defer tr.mu.Unlock()
// Generate a transaction ID
tr.nextID++
txID := fmt.Sprintf("tx-%d", tr.nextID)
// Register the transaction
tr.transactions[txID] = result.tx
return txID, nil
case <-timeoutCtx.Done():
return "", fmt.Errorf("transaction creation timed out: %w", timeoutCtx.Err())
}
@ -104,31 +104,31 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
defer tr.mu.Unlock()
var lastErr error
// Copy transaction IDs to avoid modifying the map during iteration
ids := make([]string, 0, len(tr.transactions))
for id := range tr.transactions {
ids = append(ids, id)
}
// Rollback each transaction with a timeout
for _, id := range ids {
tx, exists := tr.transactions[id]
if !exists {
continue
}
// Use a timeout for each rollback operation
rollbackCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
// Create a channel for the rollback result
doneCh := make(chan error, 1)
// Execute rollback in goroutine
go func(t engine.Transaction) {
doneCh <- t.Rollback()
}(tx)
// Wait for rollback or timeout
var err error
select {
@ -137,14 +137,14 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
case <-rollbackCtx.Done():
err = fmt.Errorf("rollback timed out: %w", rollbackCtx.Err())
}
cancel() // Clean up context
// Record error if any
if err != nil {
lastErr = fmt.Errorf("failed to rollback transaction %s: %w", id, err)
}
// Always remove transaction from map
delete(tr.transactions, id)
}
@ -154,12 +154,12 @@ func (tr *TransactionRegistry) GracefulShutdown(ctx context.Context) error {
// Server represents the Kevo server
type Server struct {
eng *engine.Engine
txRegistry *TransactionRegistry
listener net.Listener
grpcServer *grpc.Server
eng *engine.Engine
txRegistry *TransactionRegistry
listener net.Listener
grpcServer *grpc.Server
kevoService *grpcservice.KevoServiceServer
config Config
config Config
}
// NewServer creates a new server instance
@ -249,14 +249,14 @@ func (s *Server) Shutdown(ctx context.Context) error {
// First, gracefully stop the gRPC server if it exists
if s.grpcServer != nil {
fmt.Println("Gracefully stopping gRPC server...")
// Create a channel to signal when the server has stopped
stopped := make(chan struct{})
go func() {
s.grpcServer.GracefulStop()
close(stopped)
}()
// Wait for graceful stop or context deadline
select {
case <-stopped:
@ -266,7 +266,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
s.grpcServer.Stop()
}
}
// Shut down the listener if it's still open
if s.listener != nil {
if err := s.listener.Close(); err != nil {
@ -280,4 +280,4 @@ func (s *Server) Shutdown(ctx context.Context) error {
}
return nil
}
}

View File

@ -14,7 +14,7 @@ func TestTransactionRegistry(t *testing.T) {
// Create a timeout context for the whole test
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Set up temporary directory for test
tmpDir, err := os.MkdirTemp("", "kevo_test")
if err != nil {
@ -153,47 +153,47 @@ func TestGRPCServer(t *testing.T) {
t.Fatalf("Failed to create temporary directory: %v", err)
}
defer os.RemoveAll(tempDBPath)
// Create engine
eng, err := engine.NewEngine(tempDBPath)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
defer eng.Close()
// Create server configuration
config := Config{
ServerMode: true,
ListenAddr: "localhost:50052", // Use a different port for tests
DBPath: tempDBPath,
}
// Create and start the server
server := NewServer(eng, config)
if err := server.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
// Run server in a goroutine
go func() {
if err := server.Serve(); err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give the server a moment to start
time.Sleep(200 * time.Millisecond)
// Clean up at the end
defer func() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := server.Shutdown(shutdownCtx); err != nil {
t.Logf("Failed to shut down server: %v", err)
}
}()
// TODO: Add gRPC client tests here when client implementation is complete
t.Log("gRPC server integration test scaffolding added")
}
}

View File

@ -23,11 +23,11 @@ const (
// ClientOptions configures a Kevo client
type ClientOptions struct {
// Connection options
Endpoint string // Server address
ConnectTimeout time.Duration // Timeout for connection attempts
RequestTimeout time.Duration // Default timeout for requests
TransportType string // Transport type (e.g. "grpc")
PoolSize int // Connection pool size
Endpoint string // Server address
ConnectTimeout time.Duration // Timeout for connection attempts
RequestTimeout time.Duration // Default timeout for requests
TransportType string // Transport type (e.g. "grpc")
PoolSize int // Connection pool size
// Security options
TLSEnabled bool // Enable TLS
@ -50,19 +50,19 @@ type ClientOptions struct {
// DefaultClientOptions returns sensible default client options
func DefaultClientOptions() ClientOptions {
return ClientOptions{
Endpoint: "localhost:50051",
ConnectTimeout: time.Second * 5,
RequestTimeout: time.Second * 10,
TransportType: "grpc",
PoolSize: 5,
TLSEnabled: false,
MaxRetries: 3,
InitialBackoff: time.Millisecond * 100,
MaxBackoff: time.Second * 2,
BackoffFactor: 1.5,
RetryJitter: 0.2,
Compression: CompressionNone,
MaxMessageSize: 16 * 1024 * 1024, // 16MB
Endpoint: "localhost:50051",
ConnectTimeout: time.Second * 5,
RequestTimeout: time.Second * 10,
TransportType: "grpc",
PoolSize: 5,
TLSEnabled: false,
MaxRetries: 3,
InitialBackoff: time.Millisecond * 100,
MaxBackoff: time.Second * 2,
BackoffFactor: 1.5,
RetryJitter: 0.2,
Compression: CompressionNone,
MaxMessageSize: 16 * 1024 * 1024, // 16MB
}
}
@ -378,4 +378,4 @@ type Stats struct {
SstableCount int32
WriteAmplification float64
ReadAmplification float64
}
}

View File

@ -105,13 +105,13 @@ func TestClientConnect(t *testing.T) {
// Modify default options to use mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
// Create a client with the mock transport
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
@ -139,12 +139,12 @@ func TestClientGet(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -186,12 +186,12 @@ func TestClientPut(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -220,12 +220,12 @@ func TestClientDelete(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -254,12 +254,12 @@ func TestClientBatchWrite(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -295,12 +295,12 @@ func TestClientGetStats(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -317,12 +317,12 @@ func TestClientGetStats(t *testing.T) {
"read_amplification": 2.0
}`
mock.setResponse(transport.TypeGetStats, []byte(statsJSON))
stats, err := client.GetStats(ctx)
if err != nil {
t.Errorf("Expected successful get stats, got error: %v", err)
}
if stats.KeyCount != 1000 {
t.Errorf("Expected KeyCount 1000, got %d", stats.KeyCount)
}
@ -354,12 +354,12 @@ func TestClientCompact(t *testing.T) {
// Create a client with the mock transport
options := DefaultClientOptions()
options.TransportType = "mock"
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Get the underlying mock client for test assertions
mock := client.client.(*mockClient)
mock.connected = true
@ -386,7 +386,7 @@ func TestClientCompact(t *testing.T) {
func TestRetryWithBackoff(t *testing.T) {
ctx := context.Background()
// Test successful retry
attempts := 0
err := RetryWithBackoff(
@ -404,14 +404,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err != nil {
t.Errorf("Expected successful retry, got error: %v", err)
}
if attempts != 3 {
t.Errorf("Expected 3 attempts, got %d", attempts)
}
// Test max retries exceeded
attempts = 0
err = RetryWithBackoff(
@ -426,14 +426,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err == nil {
t.Error("Expected error after max retries, got nil")
}
if attempts != 4 { // Initial + 3 retries
t.Errorf("Expected 4 attempts, got %d", attempts)
}
// Test non-retryable error
attempts = 0
err = RetryWithBackoff(
@ -448,14 +448,14 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if err == nil {
t.Error("Expected non-retryable error to be returned, got nil")
}
if attempts != 1 {
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
}
// Test context cancellation
attempts = 0
cancelCtx, cancel := context.WithCancel(ctx)
@ -463,7 +463,7 @@ func TestRetryWithBackoff(t *testing.T) {
time.Sleep(20 * time.Millisecond)
cancel()
}()
err = RetryWithBackoff(
cancelCtx,
func() error {
@ -476,8 +476,8 @@ func TestRetryWithBackoff(t *testing.T) {
2.0, // backoffFactor
0.1, // jitter
)
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got: %v", err)
}
}
}

View File

@ -304,4 +304,4 @@ func (s *transactionScanIterator) Close() error {
s.closed = true
s.cancelFunc()
return s.stream.Close()
}
}

View File

@ -7,33 +7,33 @@ import (
func TestDefaultClientOptions(t *testing.T) {
options := DefaultClientOptions()
// Verify the default options have sensible values
if options.Endpoint != "localhost:50051" {
t.Errorf("Expected default endpoint to be localhost:50051, got %s", options.Endpoint)
}
if options.ConnectTimeout != 5*time.Second {
t.Errorf("Expected default connect timeout to be 5s, got %s", options.ConnectTimeout)
}
if options.RequestTimeout != 10*time.Second {
t.Errorf("Expected default request timeout to be 10s, got %s", options.RequestTimeout)
}
if options.TransportType != "grpc" {
t.Errorf("Expected default transport type to be grpc, got %s", options.TransportType)
}
if options.PoolSize != 5 {
t.Errorf("Expected default pool size to be 5, got %d", options.PoolSize)
}
if options.TLSEnabled != false {
t.Errorf("Expected default TLS enabled to be false")
}
if options.MaxRetries != 3 {
t.Errorf("Expected default max retries to be 3, got %d", options.MaxRetries)
}
}
}

View File

@ -17,19 +17,19 @@ func mockClientFactory(endpoint string, options transport.TransportOptions) (tra
func TestClientCreation(t *testing.T) {
// First, register our mock transport
transport.RegisterClientTransport("mock_test", mockClientFactory)
// Create client options using our mock transport
options := DefaultClientOptions()
options.TransportType = "mock_test"
// Create a client
client, err := NewClient(options)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Verify the client was created
if client == nil {
t.Fatal("Client is nil")
}
}
}

View File

@ -12,11 +12,11 @@ import (
// Transaction represents a database transaction
type Transaction struct {
client *Client
id string
readOnly bool
closed bool
mu sync.RWMutex
client *Client
id string
readOnly bool
closed bool
mu sync.RWMutex
}
// ErrTransactionClosed is returned when attempting to use a closed transaction
@ -285,4 +285,4 @@ func (tx *Transaction) Delete(ctx context.Context, key []byte) (bool, error) {
}
return deleteResp.Success, nil
}
}

View File

@ -117,4 +117,4 @@ func CalculateExponentialBackoff(
}
return backoff
}
}

View File

@ -130,14 +130,14 @@ func (h *HierarchicalIterator) Seek(target []byte) bool {
if !iter.Valid() {
continue
}
// If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match
break // Since iterators are in newest-to-oldest order, we can stop at the first match
}
}
// Set the found key/value
h.key = bestKey
h.value = bestValue
@ -253,7 +253,7 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
// Get the current key
key := iter.Key()
// If we haven't found a valid key yet, or this key is smaller than the current best key
if bestIterIdx == -1 || bytes.Compare(key, bestKey) < 0 {
// This becomes our best candidate so far
@ -271,14 +271,14 @@ func (h *HierarchicalIterator) findNextUniqueKey(prevKey []byte) bool {
if !iter.Valid() {
continue
}
// If a newer iterator has the same key, use its value
if bytes.Equal(iter.Key(), bestKey) {
bestValue = iter.Value()
break // Since iterators are in newest-to-oldest order, we can stop at the first match
break // Since iterators are in newest-to-oldest order, we can stop at the first match
}
}
// Set the found key/value
h.key = bestKey
h.value = bestValue

View File

@ -62,7 +62,7 @@ type EngineStats struct {
TxAborted atomic.Uint64
// Recovery stats
WALFilesRecovered atomic.Uint64
WALFilesRecovered atomic.Uint64
WALEntriesRecovered atomic.Uint64
WALCorruptedEntries atomic.Uint64
WALRecoveryDuration atomic.Int64 // nanoseconds
@ -524,15 +524,15 @@ func (e *Engine) flushMemTable(mem *memtable.MemTable) error {
for iter.SeekToFirst(); iter.Valid(); iter.Next() {
key := iter.Key()
keyStr := string(key) // Use as map key
// Skip keys we've already processed (including tombstones)
if _, seen := processedKeys[keyStr]; seen {
continue
}
// Mark this key as processed regardless of whether it's a value or tombstone
processedKeys[keyStr] = struct{}{}
// Only write non-tombstone entries to the SSTable
if value := iter.Value(); value != nil {
bytesWritten += uint64(len(key) + len(value))
@ -673,7 +673,7 @@ func (e *Engine) loadSSTables() error {
// recoverFromWAL recovers memtables from existing WAL files
func (e *Engine) recoverFromWAL() error {
startTime := time.Now()
// Check if WAL directory exists
if _, err := os.Stat(e.walDir); os.IsNotExist(err) {
return nil // No WAL directory, nothing to recover
@ -685,7 +685,7 @@ func (e *Engine) recoverFromWAL() error {
e.stats.ReadErrors.Add(1)
return fmt.Errorf("error listing WAL files: %w", err)
}
if len(walFiles) > 0 {
e.stats.WALFilesRecovered.Add(uint64(len(walFiles)))
}
@ -698,7 +698,7 @@ func (e *Engine) recoverFromWAL() error {
if err != nil {
// If recovery fails, let's try cleaning up WAL files
e.stats.ReadErrors.Add(1)
// Create a backup directory
backupDir := filepath.Join(e.walDir, "backup_"+time.Now().Format("20060102_150405"))
if err := os.MkdirAll(backupDir, 0755); err != nil {
@ -724,14 +724,14 @@ func (e *Engine) recoverFromWAL() error {
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil
}
// Update recovery statistics based on actual entries recovered
if len(walFiles) > 0 {
// Use WALDir function directly to get stats
recoveryStats, statErr := wal.ReplayWALDir(e.cfg.WALDir, func(entry *wal.Entry) error {
return nil // Just counting, not processing
})
if statErr == nil && recoveryStats != nil {
e.stats.WALEntriesRecovered.Add(recoveryStats.EntriesProcessed)
e.stats.WALCorruptedEntries.Add(recoveryStats.EntriesSkipped)
@ -767,7 +767,7 @@ func (e *Engine) recoverFromWAL() error {
// Record recovery stats
e.stats.WALRecoveryDuration.Store(time.Since(startTime).Nanoseconds())
return nil
}
@ -933,7 +933,7 @@ func (e *Engine) GetStats() map[string]interface{} {
// Add error statistics
stats["read_errors"] = e.stats.ReadErrors.Load()
stats["write_errors"] = e.stats.WriteErrors.Load()
// Add WAL recovery statistics
stats["wal_files_recovered"] = e.stats.WALFilesRecovered.Load()
stats["wal_entries_recovered"] = e.stats.WALEntriesRecovered.Load()

View File

@ -81,8 +81,8 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
// Simulate exactly the bug scenario from the CLI
// Add the same key multiple times with different values
key := []byte("foo")
// First add
// First add
if err := engine.Put(key, []byte("23")); err != nil {
t.Fatalf("Failed to put first value: %v", err)
}
@ -91,17 +91,17 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if err := engine.Delete(key); err != nil {
t.Fatalf("Failed to delete key: %v", err)
}
// Add it again with different value
if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to re-add key: %v", err)
}
// Add another key
if err := engine.Put([]byte("bar"), []byte("23")); err != nil {
t.Fatalf("Failed to add another key: %v", err)
}
// Add another key
if err := engine.Put([]byte("user:1"), []byte(`{"name":"John"}`)); err != nil {
t.Fatalf("Failed to add another key: %v", err)
@ -130,7 +130,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("42")) {
t.Errorf("Got incorrect value after flush. Expected: %s, Got: %s", "42", string(value))
}
value, err = engine.Get([]byte("bar"))
if err != nil {
t.Fatalf("Failed to get 'bar' after flush: %v", err)
@ -138,7 +138,7 @@ func TestEngine_SameKeyMultipleOperationsFlush(t *testing.T) {
if !bytes.Equal(value, []byte("23")) {
t.Errorf("Got incorrect value for 'bar' after flush. Expected: %s, Got: %s", "23", string(value))
}
value, err = engine.Get([]byte("user:1"))
if err != nil {
t.Fatalf("Failed to get 'user:1' after flush: %v", err)
@ -154,7 +154,7 @@ func TestEngine_DuplicateKeysFlush(t *testing.T) {
// Test with a key that will be deleted and re-added multiple times
key := []byte("foo")
// Add the key
if err := engine.Put(key, []byte("42")); err != nil {
t.Fatalf("Failed to put initial value: %v", err)

View File

@ -442,13 +442,13 @@ func (c *chainedIterator) SeekToFirst() {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i
@ -499,13 +499,13 @@ func (c *chainedIterator) Seek(target []byte) bool {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i
@ -537,18 +537,18 @@ func (c *chainedIterator) Next() bool {
// Find the iterator with the smallest key from the newest source
c.current = -1
// Find the smallest valid key that is greater than the current key
for i, iter := range c.iterators {
if !iter.Valid() {
continue
}
// Skip if the key is the same as the current key (we've already advanced past it)
if bytes.Equal(iter.Key(), currentKey) {
continue
}
// If we haven't found a key yet, or this key is smaller than the current smallest
if c.current == -1 || bytes.Compare(iter.Key(), c.iterators[c.current].Key()) < 0 {
c.current = i

View File

@ -109,7 +109,7 @@ func (s *KevoServiceServer) BatchWrite(ctx context.Context, req *pb.BatchWriteRe
if err != nil {
return &pb.BatchWriteResponse{Success: false}, fmt.Errorf("failed to start transaction: %w", err)
}
// Ensure we either commit or rollback
defer func() {
if err != nil {
@ -182,7 +182,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
count := int32(0)
// Position iterator at the first entry
iter.SeekToFirst()
// Iterate through all valid entries
for iter.Valid() {
if limit > 0 && count >= limit {
@ -199,7 +199,7 @@ func (s *KevoServiceServer) Scan(req *pb.ScanRequest, stream pb.KevoService_Scan
}
count++
}
// Move to the next entry
iter.Next()
}
@ -225,8 +225,8 @@ func (pi *prefixIterator) Next() bool {
for pi.iter.Next() {
// Check if current key has the prefix
key := pi.iter.Key()
if len(key) >= len(pi.prefix) &&
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
if len(key) >= len(pi.prefix) &&
equalByteSlice(key[:len(pi.prefix)], pi.prefix) {
return true
}
}
@ -415,7 +415,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
count := int32(0)
// Position iterator at the first entry
iter.SeekToFirst()
// Iterate through all valid entries
for iter.Valid() {
if limit > 0 && count >= limit {
@ -432,7 +432,7 @@ func (s *KevoServiceServer) TxScan(req *pb.TxScanRequest, stream pb.KevoService_
}
count++
}
// Move to the next entry
iter.Next()
}
@ -446,17 +446,17 @@ func (s *KevoServiceServer) GetStats(ctx context.Context, req *pb.GetStatsReques
keyCount := int64(0)
sstableCount := int32(0)
memtableCount := int32(1) // At least 1 active memtable
// Create a read-only transaction to count keys
tx, err := s.engine.BeginTransaction(true)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction for stats: %w", err)
}
defer tx.Rollback()
// Use an iterator to count keys
iter := tx.NewIterator()
// Count keys and estimate size
var totalSize int64
for iter.Next() {
@ -492,7 +492,7 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
if err != nil {
return &pb.CompactResponse{Success: false}, err
}
// Do a dummy write to force a flush
if req.Force {
err = tx.Put([]byte("__compact_marker__"), []byte("force"))
@ -501,11 +501,11 @@ func (s *KevoServiceServer) Compact(ctx context.Context, req *pb.CompactRequest)
return &pb.CompactResponse{Success: false}, err
}
}
err = tx.Commit()
if err != nil {
return &pb.CompactResponse{Success: false}, err
}
return &pb.CompactResponse{Success: true}, nil
}
}

View File

@ -56,4 +56,4 @@ func sortDurations(durations []time.Duration) {
}
}
}
}
}

View File

@ -9,8 +9,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@ -19,13 +19,13 @@ import (
// GRPCClient implements the transport.Client interface for gRPC
type GRPCClient struct {
endpoint string
options transport.TransportOptions
conn *grpc.ClientConn
client pb.KevoServiceClient
status transport.TransportStatus
statusMu sync.RWMutex
metrics transport.MetricsCollector
endpoint string
options transport.TransportOptions
conn *grpc.ClientConn
client pb.KevoServiceClient
status transport.TransportStatus
statusMu sync.RWMutex
metrics transport.MetricsCollector
}
// NewGRPCClient creates a new gRPC client
@ -123,10 +123,10 @@ func (c *GRPCClient) Status() transport.TransportStatus {
func (c *GRPCClient) setStatus(connected bool, err error) {
c.statusMu.Lock()
defer c.statusMu.Unlock()
c.status.Connected = connected
c.status.LastError = err
if connected {
c.status.LastConnected = time.Now()
}
@ -141,11 +141,11 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record request metrics
startTime := time.Now()
requestType := request.Type()
// Record bytes sent
requestPayload := request.Payload()
c.metrics.RecordSend(len(requestPayload))
var resp transport.Response
var err error
@ -182,12 +182,12 @@ func (c *GRPCClient) Send(ctx context.Context, request transport.Request) (trans
// Record metrics for the request
c.metrics.RecordRequest(requestType, startTime, err)
// If we got a response, record received bytes
if resp != nil {
c.metrics.RecordReceive(len(resp.Payload()))
}
return resp, err
}
@ -206,20 +206,20 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
var req struct {
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid get request payload: %w", err)), err
}
grpcReq := &pb.GetRequest{
Key: req.Key,
}
grpcResp, err := c.client.Get(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Value []byte `json:"value"`
Found bool `json:"found"`
@ -227,12 +227,12 @@ func (c *GRPCClient) handleGet(ctx context.Context, payload []byte) (transport.R
Value: grpcResp.Value,
Found: grpcResp.Found,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeGet, respData, nil), nil
}
@ -242,33 +242,33 @@ func (c *GRPCClient) handlePut(ctx context.Context, payload []byte) (transport.R
Value []byte `json:"value"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid put request payload: %w", err)), err
}
grpcReq := &pb.PutRequest{
Key: req.Key,
Value: req.Value,
Sync: req.Sync,
}
grpcResp, err := c.client.Put(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypePut, respData, nil), nil
}
@ -277,32 +277,32 @@ func (c *GRPCClient) handleDelete(ctx context.Context, payload []byte) (transpor
Key []byte `json:"key"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid delete request payload: %w", err)), err
}
grpcReq := &pb.DeleteRequest{
Key: req.Key,
Sync: req.Sync,
}
grpcResp, err := c.client.Delete(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeDelete, respData, nil), nil
}
@ -315,18 +315,18 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
} `json:"operations"`
Sync bool `json:"sync"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid batch write request payload: %w", err)), err
}
operations := make([]*pb.Operation, len(req.Operations))
for i, op := range req.Operations {
pbOp := &pb.Operation{
Key: op.Key,
Value: op.Value,
}
switch op.Type {
case "put":
pbOp.Type = pb.Operation_PUT
@ -335,31 +335,31 @@ func (c *GRPCClient) handleBatchWrite(ctx context.Context, payload []byte) (tran
default:
return transport.NewErrorResponse(fmt.Errorf("invalid operation type: %s", op.Type)), fmt.Errorf("invalid operation type: %s", op.Type)
}
operations[i] = pbOp
}
grpcReq := &pb.BatchWriteRequest{
Operations: operations,
Sync: req.Sync,
}
grpcResp, err := c.client.BatchWrite(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeBatchWrite, respData, nil), nil
}
@ -367,31 +367,31 @@ func (c *GRPCClient) handleBeginTransaction(ctx context.Context, payload []byte)
var req struct {
ReadOnly bool `json:"read_only"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid begin transaction request payload: %w", err)), err
}
grpcReq := &pb.BeginTransactionRequest{
ReadOnly: req.ReadOnly,
}
grpcResp, err := c.client.BeginTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
TransactionID string `json:"transaction_id"`
}{
TransactionID: grpcResp.TransactionId,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeBeginTx, respData, nil), nil
}
@ -399,31 +399,31 @@ func (c *GRPCClient) handleCommitTransaction(ctx context.Context, payload []byte
var req struct {
TransactionID string `json:"transaction_id"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid commit transaction request payload: %w", err)), err
}
grpcReq := &pb.CommitTransactionRequest{
TransactionId: req.TransactionID,
}
grpcResp, err := c.client.CommitTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeCommitTx, respData, nil), nil
}
@ -431,31 +431,31 @@ func (c *GRPCClient) handleRollbackTransaction(ctx context.Context, payload []by
var req struct {
TransactionID string `json:"transaction_id"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid rollback transaction request payload: %w", err)), err
}
grpcReq := &pb.RollbackTransactionRequest{
TransactionId: req.TransactionID,
}
grpcResp, err := c.client.RollbackTransaction(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeRollbackTx, respData, nil), nil
}
@ -464,21 +464,21 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
TransactionID string `json:"transaction_id"`
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx get request payload: %w", err)), err
}
grpcReq := &pb.TxGetRequest{
TransactionId: req.TransactionID,
Key: req.Key,
}
grpcResp, err := c.client.TxGet(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Value []byte `json:"value"`
Found bool `json:"found"`
@ -486,12 +486,12 @@ func (c *GRPCClient) handleTxGet(ctx context.Context, payload []byte) (transport
Value: grpcResp.Value,
Found: grpcResp.Found,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxGet, respData, nil), nil
}
@ -501,33 +501,33 @@ func (c *GRPCClient) handleTxPut(ctx context.Context, payload []byte) (transport
Key []byte `json:"key"`
Value []byte `json:"value"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx put request payload: %w", err)), err
}
grpcReq := &pb.TxPutRequest{
TransactionId: req.TransactionID,
Key: req.Key,
Value: req.Value,
}
grpcResp, err := c.client.TxPut(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxPut, respData, nil), nil
}
@ -536,43 +536,43 @@ func (c *GRPCClient) handleTxDelete(ctx context.Context, payload []byte) (transp
TransactionID string `json:"transaction_id"`
Key []byte `json:"key"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid tx delete request payload: %w", err)), err
}
grpcReq := &pb.TxDeleteRequest{
TransactionId: req.TransactionID,
Key: req.Key,
}
grpcResp, err := c.client.TxDelete(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeTxDelete, respData, nil), nil
}
func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transport.Response, error) {
grpcReq := &pb.GetStatsRequest{}
grpcResp, err := c.client.GetStats(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
KeyCount int64 `json:"key_count"`
StorageSize int64 `json:"storage_size"`
@ -588,12 +588,12 @@ func (c *GRPCClient) handleGetStats(ctx context.Context, payload []byte) (transp
WriteAmplification: grpcResp.WriteAmplification,
ReadAmplification: grpcResp.ReadAmplification,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeGetStats, respData, nil), nil
}
@ -601,31 +601,31 @@ func (c *GRPCClient) handleCompact(ctx context.Context, payload []byte) (transpo
var req struct {
Force bool `json:"force"`
}
if err := json.Unmarshal(payload, &req); err != nil {
return transport.NewErrorResponse(fmt.Errorf("invalid compact request payload: %w", err)), err
}
grpcReq := &pb.CompactRequest{
Force: req.Force,
}
grpcResp, err := c.client.Compact(ctx, grpcReq)
if err != nil {
return transport.NewErrorResponse(err), err
}
resp := struct {
Success bool `json:"success"`
}{
Success: grpcResp.Success,
}
respData, err := json.Marshal(resp)
if err != nil {
return transport.NewErrorResponse(err), err
}
return transport.NewResponse(transport.TypeCompact, respData, nil), nil
}
@ -650,7 +650,7 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
}
return transport.NewErrorResponse(err), err
}
// Build response based on scan type
scanResp := struct {
Key []byte `json:"key"`
@ -659,12 +659,12 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
Key: resp.Key,
Value: resp.Value,
}
respData, err := json.Marshal(scanResp)
if err != nil {
return transport.NewErrorResponse(err), err
}
s.client.metrics.RecordReceive(len(respData))
return transport.NewResponse(s.streamType, respData, nil), nil
}
@ -672,4 +672,4 @@ func (s *GRPCScanStream) Recv() (transport.Response, error) {
func (s *GRPCScanStream) Close() error {
s.cancel()
return nil
}
}

View File

@ -7,8 +7,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@ -72,7 +72,7 @@ func (g *GRPCTransportManager) Serve() error {
if err := g.Start(); err != nil {
return err
}
// Block until server is stopped
<-ctx.Done()
return nil
@ -284,4 +284,4 @@ func init() {
transport.RegisterClientTransport("grpc", func(endpoint string, options transport.TransportOptions) (transport.Client, error) {
return NewGRPCClient(endpoint, options)
})
}
}

View File

@ -7,15 +7,15 @@ import (
// Simple smoke test for the gRPC transport
func TestNewGRPCTransportManager(t *testing.T) {
opts := DefaultGRPCTransportOptions()
// Override the listen address to avoid port conflicts
opts.ListenAddr = ":0" // use random available port
manager, err := NewGRPCTransportManager(opts)
if err != nil {
t.Fatalf("Failed to create transport manager: %v", err)
}
// Verify the manager was created
if manager == nil {
t.Fatal("Expected non-nil manager")
@ -60,4 +60,4 @@ func TestLoadClientTLSConfigFromStruct(t *testing.T) {
if !config.InsecureSkipVerify {
t.Fatal("Expected InsecureSkipVerify to be true")
}
}
}

View File

@ -207,4 +207,4 @@ func (m *ConnectionPoolManager) CloseAll() {
m.pools.Delete(key)
return true
})
}
}

View File

@ -7,8 +7,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@ -16,30 +16,30 @@ import (
// GRPCServer implements the transport.Server interface for gRPC
type GRPCServer struct {
address string
tlsConfig *tls.Config
server *grpc.Server
address string
tlsConfig *tls.Config
server *grpc.Server
requestHandler transport.RequestHandler
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
started bool
mu sync.Mutex
metrics *transport.ExtendedMetricsCollector
}
// NewGRPCServer creates a new gRPC server
func NewGRPCServer(address string, options transport.TransportOptions) (transport.Server, error) {
// Create server options
var serverOpts []grpc.ServerOption
// Configure TLS if enabled
if options.TLSEnabled {
tlsConfig, err := LoadServerTLSConfig(options.CertFile, options.KeyFile, options.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConfig)))
}
// Configure keepalive parameters
kaProps := keepalive.ServerParameters{
MaxConnectionIdle: 30 * time.Minute,
@ -47,20 +47,20 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
Time: 15 * time.Second,
Timeout: 5 * time.Second,
}
kaPolicy := keepalive.EnforcementPolicy{
MinTime: 10 * time.Second,
PermitWithoutStream: true,
}
serverOpts = append(serverOpts,
grpc.KeepaliveParams(kaProps),
grpc.KeepaliveEnforcementPolicy(kaPolicy),
)
// Create the server
server := grpc.NewServer(serverOpts...)
return &GRPCServer{
address: address,
server: server,
@ -72,18 +72,18 @@ func NewGRPCServer(address string, options transport.TransportOptions) (transpor
func (s *GRPCServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.started {
return fmt.Errorf("server already started")
}
// Start the server in a goroutine
go func() {
if err := s.Serve(); err != nil {
fmt.Printf("gRPC server error: %v\n", err)
}
}()
s.started = true
return nil
}
@ -93,31 +93,31 @@ func (s *GRPCServer) Serve() error {
if s.requestHandler == nil {
return fmt.Errorf("no request handler set")
}
// Create the service implementation
service := &kevoServiceServer{
handler: s.requestHandler,
}
// Register the service
pb.RegisterKevoServiceServer(s.server, service)
// Start listening
listener, err := transport.CreateListener("tcp", s.address, s.tlsConfig)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.address, err)
}
s.metrics.ServerStarted()
// Serve requests
err = s.server.Serve(listener)
if err != nil {
s.metrics.ServerErrored()
return fmt.Errorf("failed to serve: %w", err)
}
s.metrics.ServerStopped()
return nil
}
@ -126,14 +126,14 @@ func (s *GRPCServer) Serve() error {
func (s *GRPCServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.started {
return nil
}
s.server.GracefulStop()
s.started = false
return nil
}
@ -141,7 +141,7 @@ func (s *GRPCServer) Stop(ctx context.Context) error {
func (s *GRPCServer) SetRequestHandler(handler transport.RequestHandler) {
s.mu.Lock()
defer s.mu.Unlock()
s.requestHandler = handler
}
@ -151,4 +151,4 @@ type kevoServiceServer struct {
handler transport.RequestHandler
}
// TODO: Implement service methods
// TODO: Implement service methods

View File

@ -92,4 +92,4 @@ func LoadClientTLSConfigFromStruct(config *TLSConfig) (*tls.Config, error) {
return &tls.Config{MinVersion: tls.VersionTLS12}, nil
}
return LoadClientTLSConfig(config.CertFile, config.KeyFile, config.CAFile, config.SkipVerify)
}
}

View File

@ -5,8 +5,8 @@ import (
"sync"
"time"
pb "github.com/KevoDB/kevo/proto/kevo"
"github.com/KevoDB/kevo/pkg/transport"
pb "github.com/KevoDB/kevo/proto/kevo"
"google.golang.org/grpc"
)
@ -30,17 +30,17 @@ func (c *GRPCConnection) Execute(fn func(interface{}) error) error {
// Create a new client from the connection
client := pb.NewKevoServiceClient(c.conn)
// Execute the provided function with the client
err := fn(client)
// Update metrics if there was an error
if err != nil {
c.mu.Lock()
c.errCount++
c.mu.Unlock()
}
return err
}
@ -58,10 +58,10 @@ func (c *GRPCConnection) Address() string {
func (c *GRPCConnection) Status() transport.ConnectionStatus {
c.mu.RLock()
defer c.mu.RUnlock()
// Check the connection state
isConnected := c.conn != nil
return transport.ConnectionStatus{
Connected: isConnected,
LastActivity: c.lastUsed,
@ -81,4 +81,4 @@ type GRPCTransportOptions struct {
MaxConnectionIdle time.Duration
MaxConnectionAge time.Duration
MaxPoolConnections int
}
}

View File

@ -6,21 +6,21 @@ import (
// Standard request/response type constants
const (
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
TypeGet = "get"
TypePut = "put"
TypeDelete = "delete"
TypeBatchWrite = "batch_write"
TypeScan = "scan"
TypeBeginTx = "begin_tx"
TypeCommitTx = "commit_tx"
TypeRollbackTx = "rollback_tx"
TypeTxGet = "tx_get"
TypeTxPut = "tx_put"
TypeTxDelete = "tx_delete"
TypeTxScan = "tx_scan"
TypeGetStats = "get_stats"
TypeCompact = "compact"
TypeError = "error"
)
// Common errors
@ -97,4 +97,4 @@ func NewErrorResponse(err error) Response {
ResponseData: msg,
ResponseErr: err,
}
}
}

View File

@ -9,12 +9,12 @@ func TestBasicRequest(t *testing.T) {
// Test creating a request
payload := []byte("test payload")
req := NewRequest(TypeGet, payload)
// Test Type method
if req.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, req.Type())
}
// Test Payload method
if string(req.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(req.Payload()))
@ -25,26 +25,26 @@ func TestBasicResponse(t *testing.T) {
// Test creating a response with no error
payload := []byte("test response")
resp := NewResponse(TypeGet, payload, nil)
// Test Type method
if resp.Type() != TypeGet {
t.Errorf("Expected type %s, got %s", TypeGet, resp.Type())
}
// Test Payload method
if string(resp.Payload()) != string(payload) {
t.Errorf("Expected payload %s, got %s", string(payload), string(resp.Payload()))
}
// Test Error method
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
// Test creating a response with an error
testErr := errors.New("test error")
resp = NewResponse(TypeGet, payload, testErr)
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
@ -54,34 +54,34 @@ func TestNewErrorResponse(t *testing.T) {
// Test creating an error response
testErr := errors.New("test error")
resp := NewErrorResponse(testErr)
// Test Type method
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
// Test Payload method - should contain error message
if string(resp.Payload()) != testErr.Error() {
t.Errorf("Expected payload %s, got %s", testErr.Error(), string(resp.Payload()))
}
// Test Error method
if resp.Error() != testErr {
t.Errorf("Expected error %v, got %v", testErr, resp.Error())
}
// Test with nil error
resp = NewErrorResponse(nil)
if resp.Type() != TypeError {
t.Errorf("Expected type %s, got %s", TypeError, resp.Type())
}
if len(resp.Payload()) != 0 {
t.Errorf("Expected empty payload, got %s", string(resp.Payload()))
}
if resp.Error() != nil {
t.Errorf("Expected nil error, got %v", resp.Error())
}
}
}

View File

@ -50,7 +50,7 @@ type TransportStatus struct {
type Request interface {
// Type returns the type of request
Type() string
// Payload returns the payload of the request
Payload() []byte
}
@ -59,10 +59,10 @@ type Request interface {
type Response interface {
// Type returns the type of response
Type() string
// Payload returns the payload of the response
Payload() []byte
// Error returns any error associated with the response
Error() error
}
@ -71,10 +71,10 @@ type Response interface {
type Stream interface {
// Send sends a request over the stream
Send(request Request) error
// Recv receives a response from the stream
Recv() (Response, error)
// Close closes the stream
Close() error
}
@ -83,19 +83,19 @@ type Stream interface {
type Client interface {
// Connect establishes a connection to the server
Connect(ctx context.Context) error
// Close closes the connection
Close() error
// IsConnected returns whether the client is connected
IsConnected() bool
// Status returns the current status of the connection
Status() TransportStatus
// Send sends a request and waits for a response
Send(ctx context.Context, request Request) (Response, error)
// Stream opens a bidirectional stream
Stream(ctx context.Context) (Stream, error)
}
@ -104,7 +104,7 @@ type Client interface {
type RequestHandler interface {
// HandleRequest processes a request and returns a response
HandleRequest(ctx context.Context, request Request) (Response, error)
// HandleStream processes a bidirectional stream
HandleStream(stream Stream) error
}
@ -113,13 +113,13 @@ type RequestHandler interface {
type Server interface {
// Start starts the server and returns immediately
Start() error
// Serve starts the server and blocks until it's stopped
Serve() error
// Stop stops the server gracefully
Stop(ctx context.Context) error
// SetRequestHandler sets the handler for incoming requests
SetRequestHandler(handler RequestHandler)
}
@ -134,16 +134,16 @@ type ServerFactory func(address string, options TransportOptions) (Server, error
type Registry interface {
// RegisterClient adds a new client implementation to the registry
RegisterClient(name string, factory ClientFactory)
// RegisterServer adds a new server implementation to the registry
RegisterServer(name string, factory ServerFactory)
// CreateClient instantiates a client by name
CreateClient(name, endpoint string, options TransportOptions) (Client, error)
// CreateServer instantiates a server by name
CreateServer(name, address string, options TransportOptions) (Server, error)
// ListTransports returns all available transport names
ListTransports() []string
}
}

View File

@ -10,16 +10,16 @@ import (
type MetricsCollector interface {
// RecordRequest records metrics for a request
RecordRequest(requestType string, startTime time.Time, err error)
// RecordSend records metrics for bytes sent
RecordSend(bytes int)
// RecordReceive records metrics for bytes received
RecordReceive(bytes int)
// RecordConnection records a connection event
RecordConnection(successful bool)
// GetMetrics returns the current metrics
GetMetrics() Metrics
}
@ -46,7 +46,7 @@ type BasicMetricsCollector struct {
bytesReceived uint64
connections uint64
connectionFailures uint64
// Track average latency and count for each request type
avgLatencyByType map[string]time.Duration
requestCountByType map[string]uint64
@ -63,26 +63,26 @@ func NewMetricsCollector() MetricsCollector {
// RecordRequest records metrics for a request
func (c *BasicMetricsCollector) RecordRequest(requestType string, startTime time.Time, err error) {
atomic.AddUint64(&c.totalRequests, 1)
if err == nil {
atomic.AddUint64(&c.successfulRequests, 1)
} else {
atomic.AddUint64(&c.failedRequests, 1)
}
// Update average latency for request type
latency := time.Since(startTime)
c.mu.Lock()
defer c.mu.Unlock()
currentAvg, exists := c.avgLatencyByType[requestType]
currentCount, _ := c.requestCountByType[requestType]
if exists {
// Update running average - the common case for better branch prediction
// new_avg = (old_avg * count + new_value) / (count + 1)
totalDuration := currentAvg * time.Duration(currentCount) + latency
totalDuration := currentAvg*time.Duration(currentCount) + latency
newCount := currentCount + 1
c.avgLatencyByType[requestType] = totalDuration / time.Duration(newCount)
c.requestCountByType[requestType] = newCount
@ -116,13 +116,13 @@ func (c *BasicMetricsCollector) RecordConnection(successful bool) {
func (c *BasicMetricsCollector) GetMetrics() Metrics {
c.mu.RLock()
defer c.mu.RUnlock()
// Create a copy of the average latency map
avgLatencyByType := make(map[string]time.Duration, len(c.avgLatencyByType))
for k, v := range c.avgLatencyByType {
avgLatencyByType[k] = v
}
return Metrics{
TotalRequests: atomic.LoadUint64(&c.totalRequests),
SuccessfulRequests: atomic.LoadUint64(&c.successfulRequests),
@ -133,4 +133,4 @@ func (c *BasicMetricsCollector) GetMetrics() Metrics {
ConnectionFailures: atomic.LoadUint64(&c.connectionFailures),
AvgLatencyByType: avgLatencyByType,
}
}
}

View File

@ -9,9 +9,9 @@ import (
// Metrics struct extensions for server metrics
type ServerMetrics struct {
Metrics
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
ServerStarted uint64
ServerErrored uint64
ServerStopped uint64
}
// Connection represents a connection to a remote endpoint
@ -31,11 +31,11 @@ type Connection interface {
// ConnectionStatus represents the status of a connection
type ConnectionStatus struct {
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
Connected bool
LastActivity time.Time
ErrorCount int
RequestCount int
LatencyAvg time.Duration
}
// TransportManager is an interface for managing transport layer operations

View File

@ -8,7 +8,7 @@ import (
func TestBasicMetricsCollector(t *testing.T) {
collector := NewMetricsCollector()
// Test initial state
metrics := collector.GetMetrics()
if metrics.TotalRequests != 0 ||
@ -21,11 +21,11 @@ func TestBasicMetricsCollector(t *testing.T) {
len(metrics.AvgLatencyByType) != 0 {
t.Errorf("Initial metrics not initialized correctly: %+v", metrics)
}
// Test recording successful request
startTime := time.Now().Add(-100 * time.Millisecond) // Simulate 100ms request
collector.RecordRequest("get", startTime, nil)
metrics = collector.GetMetrics()
if metrics.TotalRequests != 1 {
t.Errorf("Expected TotalRequests to be 1, got %d", metrics.TotalRequests)
@ -36,18 +36,18 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 0 {
t.Errorf("Expected FailedRequests to be 0, got %d", metrics.FailedRequests)
}
// Check average latency
if avgLatency, exists := metrics.AvgLatencyByType["get"]; !exists {
t.Error("Expected 'get' latency to exist")
} else if avgLatency < 100*time.Millisecond {
t.Errorf("Expected latency to be at least 100ms, got %v", avgLatency)
}
// Test recording failed request
startTime = time.Now().Add(-200 * time.Millisecond) // Simulate 200ms request
collector.RecordRequest("get", startTime, errors.New("test error"))
metrics = collector.GetMetrics()
if metrics.TotalRequests != 2 {
t.Errorf("Expected TotalRequests to be 2, got %d", metrics.TotalRequests)
@ -58,26 +58,26 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.FailedRequests != 1 {
t.Errorf("Expected FailedRequests to be 1, got %d", metrics.FailedRequests)
}
// Test average latency calculation for multiple requests
startTime = time.Now().Add(-300 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
startTime = time.Now().Add(-500 * time.Millisecond)
collector.RecordRequest("put", startTime, nil)
metrics = collector.GetMetrics()
avgPutLatency := metrics.AvgLatencyByType["put"]
// Expected avg is around (300ms + 500ms) / 2 = 400ms
if avgPutLatency < 390*time.Millisecond || avgPutLatency > 410*time.Millisecond {
t.Errorf("Expected average 'put' latency to be around 400ms, got %v", avgPutLatency)
}
// Test byte tracking
collector.RecordSend(1000)
collector.RecordReceive(2000)
metrics = collector.GetMetrics()
if metrics.BytesSent != 1000 {
t.Errorf("Expected BytesSent to be 1000, got %d", metrics.BytesSent)
@ -85,12 +85,12 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.BytesReceived != 2000 {
t.Errorf("Expected BytesReceived to be 2000, got %d", metrics.BytesReceived)
}
// Test connection tracking
collector.RecordConnection(true)
collector.RecordConnection(false)
collector.RecordConnection(true)
metrics = collector.GetMetrics()
if metrics.Connections != 2 {
t.Errorf("Expected Connections to be 2, got %d", metrics.Connections)
@ -98,4 +98,4 @@ func TestBasicMetricsCollector(t *testing.T) {
if metrics.ConnectionFailures != 1 {
t.Errorf("Expected ConnectionFailures to be 1, got %d", metrics.ConnectionFailures)
}
}
}

View File

@ -12,11 +12,11 @@ func CreateListener(network, address string, tlsConfig *tls.Config) (net.Listene
if err != nil {
return nil, err
}
// If TLS is configured, wrap the listener
if tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
return listener, nil
}
}

View File

@ -7,7 +7,7 @@ import (
// registry implements the Registry interface
type registry struct {
mu sync.RWMutex
mu sync.RWMutex
clientFactories map[string]ClientFactory
serverFactories map[string]ServerFactory
}
@ -111,4 +111,4 @@ func GetServer(name, address string, options TransportOptions) (Server, error) {
// AvailableTransports lists all available transports in the default registry
func AvailableTransports() []string {
return DefaultRegistry.ListTransports()
}
}

View File

@ -97,17 +97,17 @@ func mockServerFactory(address string, options TransportOptions) (Server, error)
// TestRegistry tests the transport registry
func TestRegistry(t *testing.T) {
registry := NewRegistry()
// Register transports
registry.RegisterClient("mock", mockClientFactory)
registry.RegisterServer("mock", mockServerFactory)
// Test listing transports
transports := registry.ListTransports()
if len(transports) != 1 || transports[0] != "mock" {
t.Errorf("Expected [mock], got %v", transports)
}
// Test creating client
client, err := registry.CreateClient("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -115,21 +115,21 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
// Test client methods
if client.IsConnected() {
t.Error("Expected client to be disconnected initially")
}
err = client.Connect(context.Background())
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
if !client.IsConnected() {
t.Error("Expected client to be connected after Connect()")
}
// Test server creation
server, err := registry.CreateServer("mock", "localhost:8080", TransportOptions{
Timeout: 5 * time.Second,
@ -137,26 +137,26 @@ func TestRegistry(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// Test server methods
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
mockServer := server.(*mockServer)
if !mockServer.started {
t.Error("Expected server to be started")
}
// Test non-existent transport
_, err = registry.CreateClient("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent client")
}
_, err = registry.CreateServer("nonexistent", "", TransportOptions{})
if err == nil {
t.Error("Expected error creating non-existent server")
}
}
}

View File

@ -366,7 +366,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
// Track overall recovery stats
totalStats := NewRecoveryStats()
// Track number of files processed successfully
successfulFiles := 0
var lastErr error
@ -397,7 +397,7 @@ func ReplayWALDir(dir string, handler EntryHandler) (*RecoveryStats, error) {
// Add stats from this file to our totals
totalStats.EntriesProcessed += fileStats.EntriesProcessed
totalStats.EntriesSkipped += fileStats.EntriesSkipped
successfulFiles++
}