task/task.go
Jeremy Tregunna 990c12b3c1
All checks were successful
Go Tests / Run Tests (1.24.2) (push) Successful in 18s
feat: add context parameter to Task.Execute method
2025-04-18 15:46:43 -06:00

375 lines
10 KiB
Go

package task
import (
	"context"
	"fmt"
	"log"
	"sync"
	"time"
)
// Task is a unit of work that a TaskExecutor can schedule.
type Task interface {
	// ID returns the identifier other tasks use to name this task in their
	// Dependencies.
	ID() string
	// Dependencies returns the IDs of tasks that must have completed
	// successfully before this task is run.
	Dependencies() []string
	// Execute performs the work; a non-nil error marks the run as failed.
	Execute(ctx context.Context) error
}
// TaskState is the lifecycle stage of a task as tracked by TaskExecutor.
// The zero value is TaskStateReady.
type TaskState int

const (
	// TaskStateReady means the task is neither queued nor running and may be scheduled.
	TaskStateReady TaskState = iota
	// TaskStatePending means the task has been placed on the ready queue but has not started.
	TaskStatePending
	// TaskStateRunning means the task's Execute method is in progress.
	TaskStateRunning
	// TaskStateCompleted means the most recent run returned a nil error.
	TaskStateCompleted
	// TaskStateFailed means the most recent run returned a non-nil error.
	TaskStateFailed
)
// TaskStatus is the most recently recorded state of a single task, together
// with the time and outcome of that state change.
type TaskStatus struct {
	// taskID is the ID of the task this status describes.
	taskID string
	// state is the task's current lifecycle stage.
	state TaskState
	// lastRunTime is set when the task is marked pending, running, or finished;
	// entries created in the ready state leave it zero.
	lastRunTime time.Time
	// error holds the error returned by the last Execute call (nil on success).
	error error
}
// ScheduledTask pairs a Task with how often it should be run.
type ScheduledTask struct {
	// task is the work to perform.
	task Task
	// interval is the minimum time between runs; an interval of 0 means the
	// task is queued only once.
	interval time.Duration
	// lastRunTime is the reference point from which the next interval is measured.
	lastRunTime time.Time
}
// TaskExecutor runs registered tasks on their intervals, holding back any task
// whose dependencies have not completed successfully.
type TaskExecutor struct {
	// tasks holds every ScheduledTask picked up from taskChan in Start; the
	// dependency checker iterates over it.
	// NOTE(review): accessed from several goroutines — access should be
	// synchronized (e.g. via completedTasksMutex).
	tasks []*ScheduledTask

	// completedTasks maps a task ID to its latest TaskStatus.
	completedTasks map[string]*TaskStatus
	// completedTasksMutex guards completedTasks and taskRegister.
	completedTasksMutex sync.RWMutex
	// taskRegister maps a task ID to its most recently added ScheduledTask.
	taskRegister map[string]*ScheduledTask

	// rateLimit is a counting semaphore: its capacity bounds how many tasks
	// may execute at once.
	rateLimit chan struct{}
	// taskChan carries newly added tasks from AddTask to the scheduler.
	taskChan chan *ScheduledTask
	// readyQueue carries tasks that are due and unblocked to the executor goroutine.
	readyQueue chan *ScheduledTask

	// runOnceFlag marks zero-interval tasks that have already been queued so
	// they are not queued again.
	runOnceFlag map[string]bool
	// runOnceMutex guards runOnceFlag.
	runOnceMutex sync.RWMutex
}
// NewTaskExecutor creates a TaskExecutor that lets at most rateLimit tasks
// execute concurrently.
//
// A rateLimit below 1 is treated as 1. executeTask acquires a slot by sending
// into the rateLimit channel, so a zero-capacity channel would block every
// task forever, and make panics on a negative capacity.
func NewTaskExecutor(rateLimit int) *TaskExecutor {
	if rateLimit < 1 {
		log.Printf("Rate limit %d is not positive, using 1", rateLimit)
		rateLimit = 1
	}
	return &TaskExecutor{
		tasks:          make([]*ScheduledTask, 0),
		completedTasks: make(map[string]*TaskStatus),
		taskRegister:   make(map[string]*ScheduledTask),
		rateLimit:      make(chan struct{}, rateLimit),
		taskChan:       make(chan *ScheduledTask, 100),
		readyQueue:     make(chan *ScheduledTask, 100),
		runOnceFlag:    make(map[string]bool),
	}
}
// Len returns the number of tasks the scheduler has picked up into te.tasks.
// Tasks still waiting in the registration channel are not counted.
//
// te.tasks is appended to from a background goroutine, so the length is read
// under completedTasksMutex to synchronize with writers holding that lock.
func (te *TaskExecutor) Len() int {
	te.completedTasksMutex.RLock()
	defer te.completedTasksMutex.RUnlock()
	return len(te.tasks)
}
// AddTask registers task to be run every interval and hands it to the scheduler.
//
// Rejected with a log message (and not registered):
//   - a nil task or a negative interval,
//   - a task that lists itself as a dependency,
//   - a task whose dependency chain leads back to itself.
//
// An interval of 0 means the task is queued once, as soon as its dependencies
// allow. Dependencies that are not registered yet are permitted; the task waits
// until they complete. Re-adding an existing ID replaces the earlier entry in
// the task register.
//
// The task is recorded in taskRegister only if the non-blocking hand-off to
// taskChan succeeds. Otherwise a task that will never be scheduled could be
// registered, and anything depending on it would wait forever.
func (te *TaskExecutor) AddTask(task Task, interval time.Duration) {
	log.Printf("Adding task %T with interval %v\n", task, interval)
	if task == nil {
		log.Printf("Ignoring nil task")
		return
	}
	if interval < 0 {
		log.Printf("Task %s has a negative interval, ignoring", task.ID())
		return
	}

	taskID := task.ID()
	deps := task.Dependencies()
	st := &ScheduledTask{
		task:        task,
		interval:    interval,
		lastRunTime: time.Now(),
	}

	te.completedTasksMutex.Lock()
	defer te.completedTasksMutex.Unlock()

	if _, exists := te.taskRegister[taskID]; exists {
		log.Printf("Warning: Task with ID %s already exists, overwriting", taskID)
	}

	// Missing dependencies are deliberately allowed; only self-references are
	// rejected here.
	for _, depID := range deps {
		if depID == taskID {
			log.Printf("Error: Task %s depends on itself, ignoring", taskID)
			return
		}
	}

	if te.hasCircularDependency(taskID, deps, make(map[string]bool)) {
		log.Printf("Error: Task %s has circular dependencies, ignoring", taskID)
		return
	}

	// The send never blocks (select with default), so it is safe under the lock
	// and makes hand-off and registration a single atomic step.
	select {
	case te.taskChan <- st:
	default:
		log.Printf("Failed to add task %T with interval %v, channel full\n", task, interval)
		return
	}

	te.taskRegister[taskID] = st
	if _, exists := te.completedTasks[taskID]; !exists {
		te.completedTasks[taskID] = &TaskStatus{
			taskID: taskID,
			state:  TaskStateReady,
		}
	}
	log.Printf("Task %T queued up with interval %v\n", task, interval)
}
// Start launches the executor's background goroutines:
//   - a registration loop that moves tasks sent by AddTask from taskChan into te.tasks,
//   - a dependency checker that, every 100ms, queues tasks that are due and unblocked,
//   - a single executor goroutine that runs queued tasks one at a time.
//
// Tasks already waiting in taskChan are registered synchronously before any
// goroutine starts, so they are in te.tasks (and counted by Len) when Start
// returns. The goroutines run for the life of the process; nothing stops them.
func (te *TaskExecutor) Start() {
	// Drain what AddTask queued before Start, deterministically, instead of
	// sleeping and hoping the registration goroutine has caught up.
drain:
	for {
		select {
		case st := <-te.taskChan:
			te.registerScheduledTask(st)
		default:
			break drain
		}
	}

	go func() {
		for st := range te.taskChan {
			te.registerScheduledTask(st)
		}
	}()

	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for range ticker.C {
			te.checkDependenciesAndQueue()
		}
	}()

	// A single executor goroutine keeps execution order predictable for
	// dependent tasks.
	go func() {
		for st := range te.readyQueue {
			te.executeTask(st)
		}
	}()
}

// registerScheduledTask adds st to te.tasks, replacing any earlier entry with
// the same ID (matching AddTask's "overwriting" semantics). Zero-interval tasks
// are back-dated so the next dependency check finds them due; negative
// intervals are dropped.
func (te *TaskExecutor) registerScheduledTask(st *ScheduledTask) {
	id := st.task.ID()
	switch {
	case st.interval < 0:
		log.Printf("Task %s has a negative interval, nonsensical, ignoring", id)
		return
	case st.interval == 0:
		log.Printf("Task %s has an interval of 0, queuing for immediate execution\n", id)
		st.lastRunTime = time.Now().Add(-24 * time.Hour) // Ensure it's ready to run
	}

	// te.tasks is read from the dependency checker goroutine; mutate it under
	// the lock. st is fully initialized before it becomes visible there.
	te.completedTasksMutex.Lock()
	defer te.completedTasksMutex.Unlock()
	for i, existing := range te.tasks {
		if existing.task.ID() == id {
			te.tasks[i] = st
			return
		}
	}
	te.tasks = append(te.tasks, st)
}
// shouldRun reports whether st is due at time t, i.e. whether at least
// st.interval has elapsed since st.lastRunTime.
func (te *TaskExecutor) shouldRun(st *ScheduledTask, t time.Time) bool {
	elapsed := t.Sub(st.lastRunTime)
	due := st.interval <= elapsed
	return due
}
// hasCircularDependency reports whether following dependency edges from taskID
// can lead back to a task already on the current search path. taskID's own
// dependency list is passed explicitly because taskID may not be registered yet.
// Dependencies absent from taskRegister are treated as leaves.
//
// visited doubles as the depth-first-search colour map:
//   - key absent: not yet explored
//   - true:       on the current path (reaching it again means a cycle)
//   - false:      fully explored, no cycle reachable through it
//
// Fully explored tasks are skipped, so each task is expanded at most once.
// Without that, diamond-shaped graphs are re-walked once per path, which grows
// exponentially with depth. Callers should pass a fresh map.
//
// The completedTasksMutex must be held when calling this function, because it
// reads taskRegister.
func (te *TaskExecutor) hasCircularDependency(taskID string, dependencies []string, visited map[string]bool) bool {
	if onPath, seen := visited[taskID]; seen {
		// On the current path: cycle. Already finished: nothing new to find.
		return onPath
	}
	visited[taskID] = true

	for _, depID := range dependencies {
		if onPath, seen := visited[depID]; seen {
			if onPath {
				return true
			}
			continue // explored earlier without finding a cycle
		}
		depTask, exists := te.taskRegister[depID]
		if !exists {
			// Dependency doesn't exist yet, can't determine circular status.
			continue
		}
		if te.hasCircularDependency(depID, depTask.task.Dependencies(), visited) {
			return true
		}
	}

	// Keep the key but mark it finished (false) rather than deleting it, so
	// later branches can skip this task.
	visited[taskID] = false
	return false
}
// checkDependenciesComplete reports whether every dependency of task has a
// recorded status of TaskStateCompleted. A dependency with no recorded status
// counts as incomplete; a task with no dependencies always passes.
//
// The caller must hold completedTasksMutex (a read lock is sufficient).
func (te *TaskExecutor) checkDependenciesComplete(task Task) bool {
	deps := task.Dependencies()
	for i := range deps {
		if status, ok := te.completedTasks[deps[i]]; ok && status.state == TaskStateCompleted {
			continue
		}
		return false
	}
	return true
}
// isTaskRunnable reports whether taskID is free to be queued: either it has no
// recorded status, or its status is neither pending nor running.
//
// The caller must hold completedTasksMutex (a read lock is sufficient).
func (te *TaskExecutor) isTaskRunnable(taskID string) bool {
	status, ok := te.completedTasks[taskID]
	if !ok {
		return true
	}
	switch status.state {
	case TaskStatePending, TaskStateRunning:
		return false
	default:
		return true
	}
}
// checkDependenciesAndQueue queues every task that meets all of these conditions:
//   - it is due (per shouldRun),
//   - it is not already pending or running,
//   - all of its dependencies have completed.
//
// Zero-interval tasks are queued at most once.
//
// The first pass collects candidates under read locks. The second pass handles
// each candidate under a single write lock:
//   - re-check that it is runnable,
//   - mark it pending,
//   - attempt a non-blocking send to readyQueue.
//
// Because this is one critical section, a failed send can restore the task's
// previous status. No other goroutine observes the temporary pending state, and
// a completed task does not start looking incomplete to its dependents.
func (te *TaskExecutor) checkDependenciesAndQueue() {
	now := time.Now()

	// te.tasks is appended to from the registration goroutine; iterate over a
	// copy taken under the lock.
	te.completedTasksMutex.RLock()
	tasks := append([]*ScheduledTask(nil), te.tasks...)
	te.completedTasksMutex.RUnlock()

	// First pass: find all eligible tasks.
	var eligible []*ScheduledTask
	for _, st := range tasks {
		if !te.shouldRun(st, now) {
			continue
		}
		taskID := st.task.ID()

		// Zero-interval tasks are only ever queued once.
		te.runOnceMutex.RLock()
		alreadyQueued := te.runOnceFlag[taskID]
		te.runOnceMutex.RUnlock()
		if alreadyQueued && st.interval == 0 {
			continue
		}

		te.completedTasksMutex.RLock()
		canRun := te.isTaskRunnable(taskID) && te.checkDependenciesComplete(st.task)
		te.completedTasksMutex.RUnlock()
		if canRun {
			eligible = append(eligible, st)
		}
	}

	// Second pass: mark pending and queue, one task at a time.
	for _, st := range eligible {
		taskID := st.task.ID()

		te.completedTasksMutex.Lock()
		// The state may have changed since the first pass.
		if !te.isTaskRunnable(taskID) {
			te.completedTasksMutex.Unlock()
			continue
		}
		previous := te.completedTasks[taskID]
		te.completedTasks[taskID] = &TaskStatus{
			taskID:      taskID,
			state:       TaskStatePending,
			lastRunTime: now,
		}
		queued := false
		select {
		case te.readyQueue <- st: // non-blocking, so safe under the lock
			queued = true
		default:
			// Queue full: put back the prior status instead of discarding it.
			if previous != nil {
				te.completedTasks[taskID] = previous
			} else {
				te.completedTasks[taskID] = &TaskStatus{
					taskID: taskID,
					state:  TaskStateReady,
				}
			}
		}
		te.completedTasksMutex.Unlock()

		if !queued {
			log.Printf("Task %s queue attempt failed, queue full", taskID)
			continue
		}
		st.lastRunTime = now
		log.Printf("Task %s queued for execution", taskID)
		if st.interval == 0 {
			te.runOnceMutex.Lock()
			te.runOnceFlag[taskID] = true
			te.runOnceMutex.Unlock()
		}
	}
}
// executeTask runs st's task, provided its recorded state is still
// TaskStatePending, and records the outcome in completedTasks.
//
// A token is sent into te.rateLimit first and received back on return, so at
// most cap(te.rateLimit) calls proceed at once. The state moves to
// TaskStateRunning before Execute, then to TaskStateCompleted or
// TaskStateFailed afterwards.
//
// A panic inside Execute is recovered and recorded as a failure. Unrecovered,
// it would terminate the whole process and leave the task marked running.
func (te *TaskExecutor) executeTask(st *ScheduledTask) {
	taskID := st.task.ID()

	// Acquire rate limit token.
	te.rateLimit <- struct{}{}
	defer func() {
		<-te.rateLimit
	}()

	// The state could have changed since the task was queued; only run it if
	// it is still pending.
	te.completedTasksMutex.Lock()
	status, exists := te.completedTasks[taskID]
	if !exists || status.state != TaskStatePending {
		te.completedTasksMutex.Unlock()
		log.Printf("Task %s skipped execution - state changed", taskID)
		return
	}
	te.completedTasks[taskID] = &TaskStatus{
		taskID:      taskID,
		state:       TaskStateRunning,
		lastRunTime: time.Now(),
	}
	te.completedTasksMutex.Unlock()

	err := executeRecovering(context.Background(), taskID, st.task)

	finalState := TaskStateCompleted
	if err != nil {
		finalState = TaskStateFailed
	}
	te.completedTasksMutex.Lock()
	te.completedTasks[taskID] = &TaskStatus{
		taskID:      taskID,
		state:       finalState,
		error:       err,
		lastRunTime: time.Now(),
	}
	te.completedTasksMutex.Unlock()

	if err != nil {
		log.Printf("Task %s failed: %v", taskID, err)
		return
	}
	log.Printf("Task %s completed successfully", taskID)
}

// executeRecovering calls task.Execute(ctx) and converts a panic raised by it
// into an error, so the caller can record the failure and keep running.
func executeRecovering(ctx context.Context, taskID string, task Task) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("task %s panicked: %v", taskID, r)
		}
	}()
	return task.Execute(ctx)
}