Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BucketTask and CronTask definition #12

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions apis/workflows/v1/trigger.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ message BucketTriggers {
repeated BucketTrigger triggers = 1;
}

// BucketTask is a task that is triggered by a bucket trigger.
message BucketTask {
string bucket = 1; // The bucket that triggered the task
string object = 2; // The object that triggered the task
bytes args = 3; // Additional arguments for the task, to be deserialized by the task
}

// CronTrigger is a trigger that will trigger a task submission on a schedule.
message CronTrigger {
UUID id = 1; // Unique identifier for the trigger
Expand All @@ -71,6 +78,12 @@ message CronTriggers {
repeated CronTrigger triggers = 1;
}

// CronTask is a task that is triggered by a cron trigger.
message CronTask {
google.protobuf.Timestamp trigger_time = 1; // The time the cron trigger fired
bytes args = 2; // Additional arguments for the task, to be deserialized by the task
}

// TriggerService is a service for managing NRT triggers. Currently, we support two types of triggers:
// - Bucket triggers, which trigger on object uploads to a storage bucket
// - Cron triggers, which trigger on a schedule
Expand Down
412 changes: 284 additions & 128 deletions protogen/go/workflows/v1/trigger.pb.go

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions workflows/v1/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewJobService(client workflowsv1connect.JobServiceClient) *JobService {
}
}

func (js *JobService) Submit(ctx context.Context, jobName, clusterSlug string, tasks ...Task) (*workflowsv1.Job, error) {
func (js *JobService) Submit(ctx context.Context, jobName, clusterSlug string, maxRetries int, tasks ...Task) (*workflowsv1.Job, error) {
if len(tasks) == 0 {
return nil, errors.New("no tasks to submit")
}
Expand Down Expand Up @@ -60,11 +60,12 @@ func (js *JobService) Submit(ctx context.Context, jobName, clusterSlug string, t
rootTasks = append(rootTasks, &workflowsv1.TaskSubmission{
ClusterSlug: clusterSlug,
Identifier: &workflowsv1.TaskIdentifier{
Name: identifier.Name,
Version: identifier.Version,
Name: identifier.Name(),
Version: identifier.Version(),
},
Input: subtaskInput,
Display: identifier.Name,
Input: subtaskInput,
Display: identifier.Display(),
MaxRetries: int64(maxRetries),
})
}

Expand Down
2 changes: 1 addition & 1 deletion workflows/v1/jobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestJobService_Submit(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
client := mockJobServiceClient{}
js := NewJobService(&client)
got, err := js.Submit(ctx, tt.args.jobName, tt.args.clusterSlug, tt.args.tasks...)
got, err := js.Submit(ctx, tt.args.jobName, tt.args.clusterSlug, 0, tt.args.tasks...)
if (err != nil) != tt.wantErr {
t.Errorf("Submit() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
27 changes: 16 additions & 11 deletions workflows/v1/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ const jitterInterval = 5 * time.Second

type TaskRunner struct {
Client workflowsv1connect.TaskServiceClient
taskDefinitions map[TaskIdentifier]ExecutableTask
taskDefinitions map[taskIdentifier]ExecutableTask
tracer trace.Tracer
}

func NewTaskRunner(client workflowsv1connect.TaskServiceClient) *TaskRunner {
return &TaskRunner{
Client: client,
taskDefinitions: make(map[TaskIdentifier]ExecutableTask),
taskDefinitions: make(map[taskIdentifier]ExecutableTask),
tracer: otel.Tracer("tilebox.com/observability"),
}
}
Expand All @@ -55,10 +55,15 @@ func (t *TaskRunner) RegisterTask(task ExecutableTask) error {
if err != nil {
return err
}
t.taskDefinitions[identifier] = task
t.taskDefinitions[taskIdentifier{name: identifier.Name(), version: identifier.Version()}] = task
return nil
}

func (t *TaskRunner) GetRegisteredTask(identifier TaskIdentifier) (ExecutableTask, bool) {
registeredTask, found := t.taskDefinitions[taskIdentifier{name: identifier.Name(), version: identifier.Version()}]
return registeredTask, found
}

func (t *TaskRunner) RegisterTasks(tasks ...ExecutableTask) error {
for _, task := range tasks {
err := t.RegisterTask(task)
Expand Down Expand Up @@ -101,8 +106,8 @@ func (t *TaskRunner) Run(ctx context.Context) {
for _, task := range t.taskDefinitions {
identifier := identifierFromTask(task)
identifiers = append(identifiers, &workflowsv1.TaskIdentifier{
Name: identifier.Name,
Version: identifier.Version,
Name: identifier.Name(),
Version: identifier.Version(),
})
}

Expand Down Expand Up @@ -212,13 +217,13 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*
if task.GetIdentifier() == nil {
return nil, errors.New("task has no identifier")
}
identifier := TaskIdentifier{Name: task.GetIdentifier().GetName(), Version: task.GetIdentifier().GetVersion()}
taskPrototype, found := t.taskDefinitions[identifier]
identifier := NewTaskIdentifier(task.GetIdentifier().GetName(), task.GetIdentifier().GetVersion())
taskPrototype, found := t.GetRegisteredTask(identifier)
if !found {
return nil, fmt.Errorf("task %s is not registered on this runner", task.GetIdentifier().GetName())
}

return observability.StartJobSpan(ctx, t.tracer, fmt.Sprintf("task/%s", identifier.Name), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) {
return observability.StartJobSpan(ctx, t.tracer, fmt.Sprintf("task/%s", identifier.Name()), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) {
slog.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version)
taskStruct := reflect.New(reflect.ValueOf(taskPrototype).Elem().Type()).Interface().(ExecutableTask)

Expand Down Expand Up @@ -336,12 +341,12 @@ func SubmitSubtasks(ctx context.Context, tasks ...Task) error {
executionContext.Subtasks = append(executionContext.Subtasks, &workflowsv1.TaskSubmission{
ClusterSlug: DefaultClusterSlug,
Identifier: &workflowsv1.TaskIdentifier{
Name: identifier.Name,
Version: identifier.Version,
Name: identifier.Name(),
Version: identifier.Version(),
},
Input: subtaskInput,
Dependencies: nil,
Display: identifier.Name,
Display: identifier.Display(),
})
}

Expand Down
6 changes: 3 additions & 3 deletions workflows/v1/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type badIdentifierTask struct {
}

func (t *badIdentifierTask) Identifier() TaskIdentifier {
return TaskIdentifier{Name: "", Version: ""}
return NewTaskIdentifier("", "")
}

func TestTaskRunner_RegisterTask(t *testing.T) {
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestTaskRunner_RegisterTask(t *testing.T) {
}

identifier := identifierFromTask(tt.args.task)
_, ok := runner.taskDefinitions[identifier]
_, ok := runner.GetRegisteredTask(identifier)
if ok && tt.wantErr {
t1.Errorf("RegisterTask() task found in taskDefinitions")
}
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestTaskRunner_RegisterTasks(t *testing.T) {
}
for _, task := range tt.wantTasks {
identifier := identifierFromTask(task)
_, ok := runner.taskDefinitions[identifier]
_, ok := runner.GetRegisteredTask(identifier)
if !ok {
t1.Errorf("RegisterTasks() task not found in taskDefinitions")
}
Expand Down
50 changes: 38 additions & 12 deletions workflows/v1/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,46 @@ import (
"strconv"
)

type TaskIdentifier interface {
Name() string
Version() string
Display() string
}

// TaskIdentifier is the struct that defines the unique identifier of a task.
// It is used to uniquely identify a task and specify its version.
type TaskIdentifier struct {
Name string
Version string
type taskIdentifier struct {
name string
version string
}

func NewTaskIdentifier(name, version string) TaskIdentifier {
return TaskIdentifier{
Name: name,
Version: version,
return taskIdentifier{
name: name,
version: version,
}
}

// Name returns the name of the task.
func (t taskIdentifier) Name() string {
return t.name
}

// Version returns the version of the task.
func (t taskIdentifier) Version() string {
return t.version
}

// Display returns a human-readable string representation of the task identifier, to be used in graph visualizations.
// Can be overridden during task execution to provide a more descriptive name.
func (t taskIdentifier) Display() string {
return t.name
}

func (t taskIdentifier) String() string {
return fmt.Sprintf("%s@%s", t.name, t.version)
}

// Task is the interface for a task that can be submitted to the workflow service.
// It doesn't need to be identifiable or executable, but it can be both.
type Task interface {
Expand All @@ -44,21 +70,21 @@ func identifierFromTask(task Task) TaskIdentifier {
if identifiableTask, ok := task.(ExplicitlyIdentifiableTask); ok {
return identifiableTask.Identifier()
}
return TaskIdentifier{
Name: getStructName(task),
Version: "v0.0", // default version if not otherwise specified
return &taskIdentifier{
name: getStructName(task),
version: "v0.0", // default version if not otherwise specified
}
}

// ValidateIdentifier performs client-side validation on a task identifier.
func ValidateIdentifier(identifier TaskIdentifier) error {
if identifier.Name == "" {
if identifier.Name() == "" {
return errors.New("task name is empty")
}
if len(identifier.Name) > 256 {
if len(identifier.Name()) > 256 {
return errors.New("task name is too long")
}
_, _, err := parseVersion(identifier.Version)
_, _, err := parseVersion(identifier.Version())
if err != nil {
return err
}
Expand Down
51 changes: 26 additions & 25 deletions workflows/v1/task_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package workflows

import (
"reflect"
"strings"
"testing"
)
Expand All @@ -22,16 +21,17 @@ func TestNewTaskIdentifier(t *testing.T) {
name: "test",
version: "v0.0",
},
want: TaskIdentifier{
Name: "test",
Version: "v0.0",
want: taskIdentifier{
name: "test",
version: "v0.0",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewTaskIdentifier(tt.args.name, tt.args.version); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewTaskIdentifier() = %v, wantMajor %v", got, tt.want)
got := NewTaskIdentifier(tt.args.name, tt.args.version)
if got.Name() != tt.want.Name() || got.Version() != tt.want.Version() {
t.Errorf("NewTaskIdentifier() = %v, want %v", got, tt.want)
}
})
}
Expand All @@ -42,9 +42,9 @@ type emptyTask struct{}
type identifiableTask struct{}

func (t *identifiableTask) Identifier() TaskIdentifier {
return TaskIdentifier{
Name: "myName",
Version: "v1.2",
return taskIdentifier{
name: "myName",
version: "v1.2",
}
}

Expand All @@ -62,25 +62,26 @@ func Test_identifierFromTask(t *testing.T) {
args: args{
task: &emptyTask{},
},
want: TaskIdentifier{
Name: "emptyTask",
Version: "v0.0",
want: taskIdentifier{
name: "emptyTask",
version: "v0.0",
},
},
{
name: "identifier identifiable task",
args: args{
task: &identifiableTask{},
},
want: TaskIdentifier{
Name: "myName",
Version: "v1.2",
want: taskIdentifier{
name: "myName",
version: "v1.2",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := identifierFromTask(tt.args.task); !reflect.DeepEqual(got, tt.want) {
got := identifierFromTask(tt.args.task)
if got.Name() != tt.want.Name() || got.Version() != tt.want.Version() {
t.Errorf("identifierFromTask() = %v, wantMajor %v", got, tt.want)
}
})
Expand All @@ -100,19 +101,19 @@ func TestValidateIdentifier(t *testing.T) {
{
name: "ValidateIdentifier",
args: args{
identifier: TaskIdentifier{
Name: "test",
Version: "v0.0",
identifier: taskIdentifier{
name: "test",
version: "v0.0",
},
},
wantErr: false,
},
{
name: "ValidateIdentifier name empty",
args: args{
identifier: TaskIdentifier{
Name: "",
Version: "v0.0",
identifier: taskIdentifier{
name: "",
version: "v0.0",
},
},
wantErr: true,
Expand All @@ -121,9 +122,9 @@ func TestValidateIdentifier(t *testing.T) {
{
name: "ValidateIdentifier name too long",
args: args{
identifier: TaskIdentifier{
Name: strings.Repeat("a", 257),
Version: "v0.0",
identifier: taskIdentifier{
name: strings.Repeat("a", 257),
version: "v0.0",
},
},
wantErr: true,
Expand Down
Loading