Commit 116ff5d: TaskRunner config

lukasbindreiter committed May 13, 2024
1 parent 0558ab2
Showing 3 changed files with 150 additions and 54 deletions.
workflows/v1/client.go (6 changes: 3 additions & 3 deletions)
@@ -69,13 +69,13 @@ func (o *connectOptions) applyToClient(config *clientConfig) {
config.connectOptions = append(config.connectOptions, o.options...)
}

- func newClientConfig(options []ClientOption) clientConfig {
- cfg := clientConfig{
+ func newClientConfig(options []ClientOption) *clientConfig {
+ cfg := &clientConfig{
httpClient: grpc.RetryHTTPClient(),
url: "https://api.tilebox.com",
}
for _, opt := range options {
- opt.applyToClient(&cfg)
+ opt.applyToClient(cfg)
}
return cfg
}
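The hunk above switches newClientConfig to build and return a *clientConfig, so every option mutates the same instance instead of a local copy's address. A minimal, self-contained sketch of that pattern follows; WithURL and the other names here are illustrative only and are not the package's actual API.

package main

import "fmt"

type config struct{ url string }

// Option mirrors the ClientOption idea: something that can mutate a *config.
type Option interface{ apply(cfg *config) }

type withURL struct{ url string }

func (o withURL) apply(cfg *config) { cfg.url = o.url }

// WithURL is a hypothetical option, shown only to demonstrate the pattern.
func WithURL(url string) Option { return withURL{url} }

func newConfig(options []Option) *config {
    cfg := &config{url: "https://api.tilebox.com"} // default, as in the diff
    for _, opt := range options {
        opt.apply(cfg) // options receive the pointer directly, no &cfg copy needed
    }
    return cfg
}

func main() {
    cfg := newConfig([]Option{WithURL("https://example.invalid")})
    fmt.Println(cfg.url) // https://example.invalid
}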
workflows/v1/runner.go (132 changes: 101 additions & 31 deletions)
@@ -35,20 +35,90 @@ const DefaultClusterSlug = "testing-4qgCk4qHH85qR7"
const pollingInterval = 5 * time.Second
const jitterInterval = 5 * time.Second

+ type taskRunnerConfig struct {
+ clusterSlug string
+ tracer trace.Tracer
+ log *slog.Logger
+ }
+
+ type TaskRunnerOption interface {
+ applyToRunner(config *taskRunnerConfig)
+ }
+
+ type withClusterSlogOption struct {
+ clusterSlug string
+ }
+
+ func WithCluster(clusterSlug string) TaskRunnerOption {
+ return &withClusterSlogOption{clusterSlug}
+ }
+
+ func (o *withClusterSlogOption) applyToRunner(config *taskRunnerConfig) {
+ config.clusterSlug = o.clusterSlug
+ }
+
+ type withTracerOption struct {
+ tracer trace.Tracer
+ }
+
+ func WithTracer(tracer trace.Tracer) TaskRunnerOption {
+ return &withTracerOption{tracer}
+ }
+
+ func (o *withTracerOption) applyToRunner(config *taskRunnerConfig) {
+ config.tracer = o.tracer
+ }
+
+ type withLoggerOption struct {
+ logger *slog.Logger
+ }
+
+ func WithLogger(logger *slog.Logger) TaskRunnerOption {
+ return &withLoggerOption{logger}
+ }
+
+ func (o *withLoggerOption) applyToRunner(config *taskRunnerConfig) {
+ config.log = o.logger
+ }
+
+ func newTaskRunnerConfig(options []TaskRunnerOption) (*taskRunnerConfig, error) {
+ cfg := &taskRunnerConfig{
+ tracer: otel.Tracer("tilebox.com/observability"),
+ log: slog.Default(),
+ }
+ for _, opt := range options {
+ opt.applyToRunner(cfg)
+ }
+
+ if cfg.clusterSlug == "" {
+ return nil, errors.New("cluster slug is required")
+ }
+
+ return cfg, nil
+ }
+
type TaskRunner struct {
- Client workflowsv1connect.TaskServiceClient
+ client workflowsv1connect.TaskServiceClient
taskDefinitions map[taskIdentifier]ExecutableTask
- tracer trace.Tracer
- log *slog.Logger
+
+ cluster string
+ tracer trace.Tracer
+ logger *slog.Logger
}

- func NewTaskRunner(client workflowsv1connect.TaskServiceClient) *TaskRunner {
+ func NewTaskRunner(client workflowsv1connect.TaskServiceClient, options ...TaskRunnerOption) (*TaskRunner, error) {
+ cfg, err := newTaskRunnerConfig(options)
+ if err != nil {
+ return nil, err
+ }
return &TaskRunner{
- Client: client,
+ client: client,
taskDefinitions: make(map[taskIdentifier]ExecutableTask),
- tracer: otel.Tracer("tilebox.com/observability"),
- log: slog.Default(),
- }
+
+ cluster: cfg.clusterSlug,
+ tracer: cfg.tracer,
+ logger: cfg.log,
+ }, nil
}

func (t *TaskRunner) RegisterTask(task ExecutableTask) error {
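A sketch of how a caller might construct a runner with the new options API. Only NewTaskRunner, WithCluster and WithLogger come from this commit; the package alias workflows, the helper name, imports, and how the TaskServiceClient is obtained are assumptions.

// newRunner is a hypothetical helper: the cluster slug is now configured per
// runner (and required), instead of every request falling back to DefaultClusterSlug.
func newRunner(client workflowsv1connect.TaskServiceClient) (*workflows.TaskRunner, error) {
    return workflows.NewTaskRunner(client,
        workflows.WithCluster("my-cluster-slug"), // required: NewTaskRunner returns an error without it
        workflows.WithLogger(slog.Default()),     // optional: this is already the default
        // workflows.WithTracer(...)              // optional: defaults to the tilebox.com/observability tracer
    )
}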
@@ -117,11 +187,11 @@ func (t *TaskRunner) Run(ctx context.Context) {

for {
if task == nil { // if we don't have a task, let's try work-stealing one
- taskResponse, err := t.Client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{
- NextTaskToRun: &workflowsv1.NextTaskToRun{ClusterSlug: DefaultClusterSlug, Identifiers: identifiers},
+ taskResponse, err := t.client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{
+ NextTaskToRun: &workflowsv1.NextTaskToRun{ClusterSlug: t.cluster, Identifiers: identifiers},
}))
if err != nil {
- t.log.ErrorContext(ctx, "failed to work-steal a task", "error", err)
+ t.logger.ErrorContext(ctx, "failed to work-steal a task", "error", err)
// return // should we even try again, or just stop here?
} else {
task = taskResponse.Msg.GetNextTask()
@@ -130,7 +200,7 @@ func (t *TaskRunner) Run(ctx context.Context) {

if task != nil { // we have a task to execute
if isEmpty(task.GetId()) {
- t.log.ErrorContext(ctx, "got a task without an ID - skipping to the next task")
+ t.logger.ErrorContext(ctx, "got a task without an ID - skipping to the next task")
task = nil
continue
}
@@ -144,7 +214,7 @@ func (t *TaskRunner) Run(ctx context.Context) {
if executionContext != nil && len(executionContext.Subtasks) > 0 {
computedTask.SubTasks = executionContext.Subtasks
}
- nextTaskToRun := &workflowsv1.NextTaskToRun{ClusterSlug: DefaultClusterSlug, Identifiers: identifiers}
+ nextTaskToRun := &workflowsv1.NextTaskToRun{ClusterSlug: t.cluster, Identifiers: identifiers}
select {
case <-ctxSignal.Done():
// if we got a context cancellation, don't request a new task
@@ -155,37 +225,37 @@ func (t *TaskRunner) Run(ctx context.Context) {

task, err = retry.DoWithData(
func() (*workflowsv1.Task, error) {
- taskResponse, err := t.Client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{
+ taskResponse, err := t.client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{
ComputedTask: computedTask, NextTaskToRun: nextTaskToRun,
}))
if err != nil {
- t.log.ErrorContext(ctx, "failed to mark task as computed, retrying", "error", err)
+ t.logger.ErrorContext(ctx, "failed to mark task as computed, retrying", "error", err)
return nil, err
}
return taskResponse.Msg.GetNextTask(), nil
}, retry.Context(ctxSignal), retry.DelayType(retry.CombineDelay(retry.BackOffDelay, retry.RandomDelay)),
)
if err != nil {
- t.log.ErrorContext(ctx, "failed to retry NextTask", "error", err)
+ t.logger.ErrorContext(ctx, "failed to retry NextTask", "error", err)
return // we got a cancellation signal, so let's just stop here
}
} else { // err != nil
- t.log.ErrorContext(ctx, "task execution failed", "error", err)
+ t.logger.ErrorContext(ctx, "task execution failed", "error", err)
err = retry.Do(
func() error {
- _, err := t.Client.TaskFailed(ctx, connect.NewRequest(&workflowsv1.TaskFailedRequest{
+ _, err := t.client.TaskFailed(ctx, connect.NewRequest(&workflowsv1.TaskFailedRequest{
TaskId: task.GetId(),
CancelJob: true,
}))
if err != nil {
- t.log.ErrorContext(ctx, "failed to report task failure", "error", err)
+ t.logger.ErrorContext(ctx, "failed to report task failure", "error", err)
return err
}
return nil
}, retry.Context(ctxSignal), retry.DelayType(retry.CombineDelay(retry.BackOffDelay, retry.RandomDelay)),
)
if err != nil {
- t.log.ErrorContext(ctx, "failed to retry TaskFailed", "error", err)
+ t.logger.ErrorContext(ctx, "failed to retry TaskFailed", "error", err)
return // we got a cancellation signal, so let's just stop here
}
task = nil // reported a task failure, let's work-steal again
@@ -195,7 +265,7 @@ func (t *TaskRunner) Run(ctx context.Context) {
}
} else {
// if we didn't get a task, let's wait for a bit and try work-stealing again
- t.log.DebugContext(ctx, "no task to run")
+ t.logger.DebugContext(ctx, "no task to run")

// instead of time.Sleep we set a timer and select on it, so we still can catch signals like SIGINT
timer := time.NewTimer(pollingInterval + rand.N(jitterInterval))
@@ -212,7 +282,7 @@ func (t *TaskRunner) Run(ctx context.Context) {
func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*taskExecutionContext, error) {
// start a goroutine to extend the lease of the task continuously until the task execution is finished
leaseCtx, stopLeaseExtensions := context.WithCancel(ctx)
- go t.extendTaskLease(leaseCtx, t.Client, task.GetId(), task.GetLease().GetLease().AsDuration(), task.GetLease().GetRecommendedWaitUntilNextExtension().AsDuration())
+ go t.extendTaskLease(leaseCtx, t.client, task.GetId(), task.GetLease().GetLease().AsDuration(), task.GetLease().GetRecommendedWaitUntilNextExtension().AsDuration())
defer stopLeaseExtensions()

// actually execute the task
@@ -226,7 +296,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*
}

return observability.StartJobSpan(ctx, t.tracer, fmt.Sprintf("task/%s", identifier.Name()), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) {
- t.log.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version)
+ t.logger.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version)
taskStruct := reflect.New(reflect.ValueOf(taskPrototype).Elem().Type()).Interface().(ExecutableTask)

_, isProtobuf := taskStruct.(proto.Message)
@@ -242,7 +312,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*
}
}

- executionContext := withTaskExecutionContext(ctx, t.Client, task)
+ executionContext := t.withTaskExecutionContext(ctx, task)
err := taskStruct.Execute(executionContext)
if r := recover(); r != nil {
// recover from panics during task executions, so we can still report the error to the server and continue
@@ -270,21 +340,21 @@ func (t *TaskRunner) extendTaskLease(ctx context.Context, client workflowsv1conn
return
case <-timer.C: // the timer expired, let's try to extend the lease
}
- t.log.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())), "lease", lease, "wait", wait)
+ t.logger.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())), "lease", lease, "wait", wait)
req := &workflowsv1.TaskLeaseRequest{
TaskId: taskID,
RequestedLease: durationpb.New(2 * lease), // double the current lease duration for the next extension
}
extension, err := client.ExtendTaskLease(ctx, connect.NewRequest(req))
if err != nil {
- t.log.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())))
+ t.logger.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())))
// The server probably has an internal error, but there is no point in trying to extend the lease again
// because it will be expired then, so let's just return
return
}
if extension.Msg.GetLease() == nil {
// the server did not return a lease extension, it means that there is no need in trying to extend the lease
- t.log.DebugContext(ctx, "task lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())))
+ t.logger.DebugContext(ctx, "task lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())))
return
}
// will probably be double the previous lease (since we requested that) or capped by the server at maxLeaseDuration
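A worked illustration of the extension policy above, not part of the commit: each request asks for double the current lease, so the granted lease grows geometrically until the server caps it. The cap value below is hypothetical.

package main

import (
    "fmt"
    "time"
)

func main() {
    lease := time.Minute                 // initial lease granted with the task
    maxLeaseDuration := 10 * time.Minute // hypothetical server-side cap
    for i := 1; i <= 5; i++ {
        requested := 2 * lease // mirrors durationpb.New(2 * lease) in extendTaskLease
        granted := requested
        if granted > maxLeaseDuration {
            granted = maxLeaseDuration
        }
        fmt.Printf("extension %d: requested %v, granted %v\n", i, requested, granted)
        lease = granted
    }
    // extension 1: requested 2m0s, granted 2m0s
    // extension 2: requested 4m0s, granted 4m0s
    // ... capped at 10m0s from extension 4 onwards
}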
@@ -295,14 +365,14 @@ func (t *TaskRunner) extendTaskLease(ctx context.Context, client workflowsv1conn

type taskExecutionContext struct {
CurrentTask *workflowsv1.Task
- Client workflowsv1connect.TaskServiceClient
+ runner *TaskRunner
Subtasks []*workflowsv1.TaskSubmission
}

- func withTaskExecutionContext(ctx context.Context, client workflowsv1connect.TaskServiceClient, task *workflowsv1.Task) context.Context {
+ func (t *TaskRunner) withTaskExecutionContext(ctx context.Context, task *workflowsv1.Task) context.Context {
return context.WithValue(ctx, ContextKeyTaskExecution, &taskExecutionContext{
CurrentTask: task,
- Client: client,
+ runner: t,
Subtasks: make([]*workflowsv1.TaskSubmission, 0),
})
}
@@ -341,7 +411,7 @@ func SubmitSubtasks(ctx context.Context, tasks ...Task) error {
}

executionContext.Subtasks = append(executionContext.Subtasks, &workflowsv1.TaskSubmission{
- ClusterSlug: DefaultClusterSlug,
+ ClusterSlug: executionContext.runner.cluster,
Identifier: &workflowsv1.TaskIdentifier{
Name: identifier.Name(),
Version: identifier.Version(),
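With this last hunk, subtasks submitted from inside a running task inherit the runner's configured cluster instead of DefaultClusterSlug. A sketch of such a call site follows; FetchPage is a hypothetical task type, only SubmitSubtasks comes from this package, and any further methods ExecutableTask may require are omitted.

type FetchPage struct {
    URL string
}

func (t *FetchPage) Execute(ctx context.Context) error {
    // ... do the work for this task ...

    // queue follow-up work; submissions are buffered on the task execution context
    // and sent to the server together with the "task computed" report
    return workflows.SubmitSubtasks(ctx,
        &FetchPage{URL: "https://example.com/a"},
        &FetchPage{URL: "https://example.com/b"},
    )
}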