kevo/pkg/replication/state_test.go

187 lines
5.3 KiB
Go

package replication
import (
"errors"
"testing"
"time"
)
func TestStateTracker(t *testing.T) {
// Create a new state tracker
tracker := NewStateTracker()
// Test initial state
if tracker.GetState() != StateConnecting {
t.Errorf("Expected initial state to be StateConnecting, got %s", tracker.GetState())
}
// Test valid state transition
err := tracker.SetState(StateStreamingEntries)
if err != nil {
t.Errorf("Unexpected error for valid transition: %v", err)
}
if tracker.GetState() != StateStreamingEntries {
t.Errorf("Expected state to be StateStreamingEntries, got %s", tracker.GetState())
}
// Test invalid state transition
err = tracker.SetState(StateAcknowledging)
if err == nil {
t.Errorf("Expected error for invalid transition, got nil")
}
if !errors.Is(err, ErrInvalidStateTransition) {
t.Errorf("Expected ErrInvalidStateTransition, got %v", err)
}
if tracker.GetState() != StateStreamingEntries {
t.Errorf("State should not change after invalid transition, got %s", tracker.GetState())
}
// Test complete valid path
validPath := []ReplicaState{
StateApplyingEntries,
StateFsyncPending,
StateAcknowledging,
StateWaitingForData,
StateStreamingEntries,
StateApplyingEntries,
StateFsyncPending,
StateAcknowledging,
StateStreamingEntries,
}
for i, state := range validPath {
err := tracker.SetState(state)
if err != nil {
t.Errorf("Unexpected error at step %d: %v", i, err)
}
if tracker.GetState() != state {
t.Errorf("Expected state to be %s at step %d, got %s", state, i, tracker.GetState())
}
}
// Test error state transition
err = tracker.SetError(errors.New("test error"))
if err != nil {
t.Errorf("Unexpected error setting error state: %v", err)
}
if tracker.GetState() != StateError {
t.Errorf("Expected state to be StateError, got %s", tracker.GetState())
}
if tracker.GetError() == nil {
t.Errorf("Expected error to be set, got nil")
}
if tracker.GetError().Error() != "test error" {
t.Errorf("Expected error message 'test error', got '%s'", tracker.GetError().Error())
}
// Test recovery from error
err = tracker.SetState(StateConnecting)
if err != nil {
t.Errorf("Unexpected error recovering from error state: %v", err)
}
if tracker.GetState() != StateConnecting {
t.Errorf("Expected state to be StateConnecting after recovery, got %s", tracker.GetState())
}
// Test transitions tracking
transitions := tracker.GetTransitions()
// Count the actual transitions we made
transitionCount := len(validPath) + 1 // +1 for error state
if len(transitions) < transitionCount {
t.Errorf("Expected at least %d transitions, got %d", transitionCount, len(transitions))
}
// Test reset
tracker.ResetState()
if tracker.GetState() != StateConnecting {
t.Errorf("Expected state to be StateConnecting after reset, got %s", tracker.GetState())
}
if tracker.GetError() != nil {
t.Errorf("Expected error to be nil after reset, got %v", tracker.GetError())
}
if len(tracker.GetTransitions()) != 0 {
t.Errorf("Expected 0 transitions after reset, got %d", len(tracker.GetTransitions()))
}
}
func TestStateDuration(t *testing.T) {
// Create a new state tracker
tracker := NewStateTracker()
// Initial state duration should be small
initialDuration := tracker.GetStateDuration()
if initialDuration > 100*time.Millisecond {
t.Errorf("Initial state duration too large: %v", initialDuration)
}
// Wait a bit
time.Sleep(200 * time.Millisecond)
// Duration should have increased
afterWaitDuration := tracker.GetStateDuration()
if afterWaitDuration < 200*time.Millisecond {
t.Errorf("Duration did not increase as expected: %v", afterWaitDuration)
}
// Transition to a new state
err := tracker.SetState(StateStreamingEntries)
if err != nil {
t.Fatalf("Unexpected error transitioning states: %v", err)
}
// New state duration should be small again
newStateDuration := tracker.GetStateDuration()
if newStateDuration > 100*time.Millisecond {
t.Errorf("New state duration too large: %v", newStateDuration)
}
}
func TestStateStringRepresentation(t *testing.T) {
testCases := []struct {
state ReplicaState
expected string
}{
{StateConnecting, "CONNECTING"},
{StateStreamingEntries, "STREAMING_ENTRIES"},
{StateApplyingEntries, "APPLYING_ENTRIES"},
{StateFsyncPending, "FSYNC_PENDING"},
{StateAcknowledging, "ACKNOWLEDGING"},
{StateWaitingForData, "WAITING_FOR_DATA"},
{StateError, "ERROR"},
{ReplicaState(999), "UNKNOWN(999)"},
}
for _, tc := range testCases {
t.Run(tc.expected, func(t *testing.T) {
if tc.state.String() != tc.expected {
t.Errorf("Expected state string %s, got %s", tc.expected, tc.state.String())
}
})
}
}
func TestGetStateString(t *testing.T) {
tracker := NewStateTracker()
// Test initial state string
if tracker.GetStateString() != "CONNECTING" {
t.Errorf("Expected state string CONNECTING, got %s", tracker.GetStateString())
}
// Change state and test string
err := tracker.SetState(StateStreamingEntries)
if err != nil {
t.Fatalf("Unexpected error transitioning states: %v", err)
}
if tracker.GetStateString() != "STREAMING_ENTRIES" {
t.Errorf("Expected state string STREAMING_ENTRIES, got %s", tracker.GetStateString())
}
// Set error state and test string
tracker.SetError(errors.New("test error"))
if tracker.GetStateString() != "ERROR" {
t.Errorf("Expected state string ERROR, got %s", tracker.GetStateString())
}
}