From c7a7cbcaf05d0c10b06823bcadab7b9d01c5727b Mon Sep 17 00:00:00 2001 From: Corentin <33163342+corentinmusard@users.noreply.github.com> Date: Thu, 11 Apr 2024 10:23:54 +0200 Subject: [PATCH] Add workflows tests (#6) * Add workflows tests * test golangci-lint * Configure golangci-lint --- .github/workflows/main.yml | 21 ++- .golangci.yaml | 76 +++++++++ generate.go | 2 +- grpc/client_interceptor.go | 3 +- grpc/grpc_connect.go | 7 +- observability/observability.go | 15 +- workflows/v1/jobs.go | 8 +- workflows/v1/jobs_test.go | 89 ++++++++++ workflows/v1/runner.go | 77 ++++----- workflows/v1/runner_test.go | 297 +++++++++++++++++++++++++++++++++ workflows/v1/task.go | 13 +- workflows/v1/task_test.go | 216 ++++++++++++++++++++++++ 12 files changed, 758 insertions(+), 66 deletions(-) create mode 100644 .golangci.yaml create mode 100644 workflows/v1/jobs_test.go create mode 100644 workflows/v1/runner_test.go create mode 100644 workflows/v1/task_test.go diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0aa88ab..2c8dc37 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,6 +7,9 @@ on: merge_group: branches: ["**"] +permissions: + contents: read # for golangci-lint-action + jobs: tests: name: Run tests @@ -29,10 +32,14 @@ jobs: go install github.com/jstemmer/go-junit-report@latest - name: Build run: go build -v ./... - #- name: Run Tests - # run: go test -v ./... | go-junit-report -set-exit-code > test-report.xml - #- name: Test Summary - # uses: test-summary/action@v2 - # with: - # paths: "test-report.xml" - # if: always() + - name: Run Tests + run: go test -v ./... | go-junit-report -set-exit-code > test-report.xml + - name: Test Summary + uses: test-summary/action@v2 + with: + paths: "test-report.xml" + if: always() + - name: Lint + uses: golangci/golangci-lint-action@v4 + with: + version: v1.57 diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..a5df0a1 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,76 @@ +# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +# Inspired from: (MIT license) https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322 +--- +run: + timeout: '5m' +linters-settings: + govet: + enable-all: true + disable: + - fieldalignment # too strict + - shadow # too strict + perfsprint: + strconcat: false +linters: + enable: + - asasalint # checks for pass []any as any in variadic func(...any) + - asciicheck # checks that your code does not contain non-ASCII identifiers + - bidichk # checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - copyloopvar # detects places where loop variables are copied + #- cyclop # checks function and package cyclomatic complexity + - dupl # tool for code clone detection + - durationcheck # checks for two durations multiplied together + - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error + - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 + - execinquery # checks query string in Query function which reads your Go src files and warning it finds + - exhaustive # checks exhaustiveness of enum switch statements + - exportloopref # checks for pointers to enclosing loop variables + - forbidigo # forbids identifiers + - gocheckcompilerdirectives # validates go compiler directive comments (//go:) + #- gochecknoglobals # checks that no global variables exist + - gochecknoinits # checks that no init functions are present in Go code + - gochecksumtype # checks exhaustiveness on Go "sum types" + - goconst # finds repeated strings that could be replaced by a constant + - gocritic # provides diagnostics that check for bugs, performance and style issues + - goimports # in addition to fixing imports, goimports also formats your code in the same style as gofmt + # - gomnd # detects magic numbers + - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod + - gomodguard # allow and block lists linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations + - goprintffuncname # checks that printf-like functions are named with f at the end + - gosec # inspects source code for security problems + - intrange # finds places where for loops could make use of an integer range + - loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap) + - makezero # finds slice declarations with non-zero initial length + - mirror # reports wrong mirror patterns of bytes/strings usage + - musttag # enforces field tags in (un)marshaled structs + - nakedret # finds naked returns in functions greater than a specified function length + - nilerr # finds the code that returns nil even if it checks that the error is not nil + - nilnil # checks that there is no simultaneous return of nil error and an invalid value + - noctx # finds sending http request without context.Context + - nolintlint # reports ill-formed or insufficient nolint directives + - nonamedreturns # reports all named returns + - nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL + - perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative + - prealloc # finds slice declarations that could potentially be preallocated + - predeclared # finds code that shadows one of Go's predeclared identifiers + - promlinter # checks Prometheus metrics naming via promlint + - protogetter # reports direct reads from proto message fields when getters should be used + - reassign # checks that package variables are not reassigned + - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint + - rowserrcheck # checks whether Err of rows is checked successfully + - sloglint # ensure consistent code style when using log/slog + - spancheck # checks for mistakes with OpenTelemetry/Census spans + - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed + - stylecheck # is a replacement for golint + - tagalign # checks that struct tags are well aligned + - tenv # detects using os.Setenv instead of t.Setenv since Go1.17 + - testableexamples # checks if examples are testable (have an expected output) + - testifylint # checks usage of github.com/stretchr/testify + #- testpackage # makes you use a separate _test package + - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # removes unnecessary type conversions + - unparam # reports unused function parameters + - usestdlibvars # detects the possibility to use variables/constants from the Go standard library + - wastedassign # finds wasted assignment statements + - whitespace # detects leading and trailing whitespace diff --git a/generate.go b/generate.go index 72b20e0..637225e 100644 --- a/generate.go +++ b/generate.go @@ -1,3 +1,3 @@ -package tilebox_go +package tilebox //go:generate go run -mod=mod github.com/bufbuild/buf/cmd/buf generate diff --git a/grpc/client_interceptor.go b/grpc/client_interceptor.go index 0e224ef..dfedb76 100644 --- a/grpc/client_interceptor.go +++ b/grpc/client_interceptor.go @@ -1,8 +1,9 @@ package grpc import ( - "connectrpc.com/connect" "context" + + "connectrpc.com/connect" ) type addAuthTokenInterceptor struct { diff --git a/grpc/grpc_connect.go b/grpc/grpc_connect.go index 0014f03..ae2ce2d 100644 --- a/grpc/grpc_connect.go +++ b/grpc/grpc_connect.go @@ -1,15 +1,16 @@ package grpc import ( - "connectrpc.com/connect" "context" "errors" - "github.com/hashicorp/go-retryablehttp" "log/slog" "net/http" "net/url" "strings" "time" + + "connectrpc.com/connect" + "github.com/hashicorp/go-retryablehttp" ) // RetryOnStatusUnavailable provides a retry policy for retrying requests if the server is unavailable. @@ -43,7 +44,7 @@ func RetryOnStatusUnavailable(ctx context.Context, resp *http.Response, err erro return false, err } -func RetryHttpClient() connect.HTTPClient { +func RetryHTTPClient() connect.HTTPClient { retryClient := retryablehttp.NewClient() retryClient.RetryWaitMin = 20 * time.Millisecond retryClient.RetryWaitMax = 5 * time.Second diff --git a/observability/observability.go b/observability/observability.go index e968027..1b1c5ad 100644 --- a/observability/observability.go +++ b/observability/observability.go @@ -2,6 +2,8 @@ package observability import ( "context" + "log/slog" + adapter "github.com/axiomhq/axiom-go/adapters/slog" "github.com/axiomhq/axiom-go/axiom" axiotel "github.com/axiomhq/axiom-go/axiom/otel" @@ -13,7 +15,6 @@ import ( "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.21.0" oteltrace "go.opentelemetry.io/otel/trace" - "log/slog" ) var propagator = propagation.TraceContext{} @@ -37,7 +38,7 @@ func AxiomTraceExporter(ctx context.Context, dataset, token string) (trace.SpanE return axiotel.TraceExporter(ctx, dataset, axiotel.SetToken(token)) } -func SetupOtelTracing(serviceName, serviceVersion string, exporters ...trace.SpanExporter) (shutdown func(ctx context.Context), err error) { +func SetupOtelTracing(serviceName, serviceVersion string, exporters ...trace.SpanExporter) func(ctx context.Context) { tp := tracerProvider(serviceName, serviceVersion, exporters) otel.SetTracerProvider(tp) @@ -45,7 +46,7 @@ func SetupOtelTracing(serviceName, serviceVersion string, exporters ...trace.Spa _ = tp.Shutdown(ctx) } - return shutDownFunc, err + return shutDownFunc } // tracerProvider configures and returns a new OpenTelemetry tracer provider. @@ -75,15 +76,15 @@ func GetTraceParentOfCurrentSpan(ctx context.Context) string { return carrier.Get("traceparent") } -func StartJobSpan[Result any](tracer oteltrace.Tracer, ctx context.Context, spanName string, job *workflowsv1.Job, f func(ctx context.Context) (Result, error)) (Result, error) { - carrier := propagation.MapCarrier{"traceparent": job.TraceParent} +func StartJobSpan[Result any](ctx context.Context, tracer oteltrace.Tracer, spanName string, job *workflowsv1.Job, f func(ctx context.Context) (Result, error)) (Result, error) { + carrier := propagation.MapCarrier{"traceparent": job.GetTraceParent()} ctx = propagator.Extract(ctx, carrier) - return WithSpanResult(tracer, ctx, spanName, f) + return WithSpanResult(ctx, tracer, spanName, f) } // WithSpanResult wraps a function call that returns a result and an error with a tracing span of the given name -func WithSpanResult[Result any](tracer oteltrace.Tracer, ctx context.Context, name string, f func(ctx context.Context) (Result, error)) (Result, error) { +func WithSpanResult[Result any](ctx context.Context, tracer oteltrace.Tracer, name string, f func(ctx context.Context) (Result, error)) (Result, error) { ctx, span := tracer.Start(ctx, name) defer span.End() diff --git a/workflows/v1/jobs.go b/workflows/v1/jobs.go index 6639eba..7641535 100644 --- a/workflows/v1/jobs.go +++ b/workflows/v1/jobs.go @@ -1,10 +1,12 @@ package workflows import ( - "connectrpc.com/connect" "context" "encoding/json" + "errors" "fmt" + + "connectrpc.com/connect" "github.com/tilebox/tilebox-go/observability" workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" @@ -27,7 +29,7 @@ func NewJobService(client workflowsv1connect.JobServiceClient) *JobService { func (js *JobService) Submit(ctx context.Context, jobName, clusterSlug string, tasks ...Task) (*workflowsv1.Job, error) { if len(tasks) == 0 { - return nil, fmt.Errorf("no tasks to submit") + return nil, errors.New("no tasks to submit") } rootTasks := make([]*workflowsv1.TaskSubmission, 0) @@ -66,7 +68,7 @@ func (js *JobService) Submit(ctx context.Context, jobName, clusterSlug string, t }) } - return observability.WithSpanResult(js.tracer, ctx, fmt.Sprintf("job/%s", jobName), func(ctx context.Context) (*workflowsv1.Job, error) { + return observability.WithSpanResult(ctx, js.tracer, fmt.Sprintf("job/%s", jobName), func(ctx context.Context) (*workflowsv1.Job, error) { traceParent := observability.GetTraceParentOfCurrentSpan(ctx) job, err := js.client.SubmitJob(ctx, connect.NewRequest( diff --git a/workflows/v1/jobs_test.go b/workflows/v1/jobs_test.go new file mode 100644 index 0000000..e8a2b8b --- /dev/null +++ b/workflows/v1/jobs_test.go @@ -0,0 +1,89 @@ +package workflows + +import ( + "context" + "reflect" + "testing" + + "connectrpc.com/connect" + workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" + "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" +) + +type mockJobServiceClient struct { + workflowsv1connect.JobServiceClient + reqs []*workflowsv1.SubmitJobRequest +} + +func (m *mockJobServiceClient) SubmitJob(_ context.Context, req *connect.Request[workflowsv1.SubmitJobRequest]) (*connect.Response[workflowsv1.Job], error) { + m.reqs = append(m.reqs, req.Msg) + + return connect.NewResponse(&workflowsv1.Job{ + Name: req.Msg.GetJobName(), + }), nil +} + +func TestJobService_Submit(t *testing.T) { + ctx := context.Background() + + type args struct { + jobName string + clusterSlug string + tasks []Task + } + tests := []struct { + name string + args args + want *workflowsv1.Job + wantReq *workflowsv1.SubmitJobRequest + wantErr bool + }{ + { + name: "Submit Job", + args: args{ + jobName: "test-job", + clusterSlug: "test-cluster", + tasks: []Task{&testTask1{}}, + }, + want: &workflowsv1.Job{ + Name: "test-job", + }, + wantReq: &workflowsv1.SubmitJobRequest{ + Tasks: []*workflowsv1.TaskSubmission{ + { + ClusterSlug: "test-cluster", + Identifier: &workflowsv1.TaskIdentifier{ + Name: "testTask1", + Version: "v0.0", + }, + Input: []byte("{\"ExecutableTask\":null}"), + Display: "testTask1", + }, + }, + JobName: "test-job", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := mockJobServiceClient{} + js := NewJobService(&client) + got, err := js.Submit(ctx, tt.args.jobName, tt.args.clusterSlug, tt.args.tasks...) + if (err != nil) != tt.wantErr { + t.Errorf("Submit() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Submit() got = %v, want %v", got, tt.want) + } + + // Verify the submitted request + if len(client.reqs) != 1 { + t.Fatalf("Submit() expected 1 request, got %d", len(client.reqs)) + } + if !reflect.DeepEqual(client.reqs[0], tt.wantReq) { + t.Errorf("Submit() request = %v, want %v", client.reqs[0], tt.wantReq) + } + }) + } +} diff --git a/workflows/v1/runner.go b/workflows/v1/runner.go index 28e05fd..d075740 100644 --- a/workflows/v1/runner.go +++ b/workflows/v1/runner.go @@ -1,10 +1,18 @@ package workflows import ( - "connectrpc.com/connect" "context" "encoding/json" + "errors" "fmt" + "log/slog" + "math/rand/v2" + "os/signal" + "reflect" + "syscall" + "time" + + "connectrpc.com/connect" "github.com/avast/retry-go/v4" "github.com/google/uuid" "github.com/tilebox/tilebox-go/observability" @@ -14,12 +22,6 @@ import ( "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" - "log/slog" - "math/rand/v2" - "os/signal" - "reflect" - "syscall" - "time" ) type ContextKeyTaskExecutionType string @@ -28,7 +30,7 @@ const ContextKeyTaskExecution ContextKeyTaskExecutionType = "x-tilebox-task-exec const DefaultClusterSlug = "testing-4qgCk4qHH85qR7" -//const DefaultClusterSlug = "workflow-dev-EifhUozDpwAJDL" +// const DefaultClusterSlug = "workflow-dev-EifhUozDpwAJDL" const pollingInterval = 5 * time.Second const jitterInterval = 5 * time.Second @@ -57,8 +59,8 @@ func (t *TaskRunner) RegisterTask(task ExecutableTask) error { return nil } -func (t *TaskRunner) RegisterTasks(task ...ExecutableTask) error { - for _, task := range task { +func (t *TaskRunner) RegisterTasks(tasks ...ExecutableTask) error { + for _, task := range tasks { err := t.RegisterTask(task) if err != nil { return err @@ -67,12 +69,12 @@ func (t *TaskRunner) RegisterTasks(task ...ExecutableTask) error { return nil } -func protobufToUuid(id *workflowsv1.UUID) (uuid.UUID, error) { - if id == nil || len(id.Uuid) == 0 { +func protobufToUUID(id *workflowsv1.UUID) (uuid.UUID, error) { + if id == nil || len(id.GetUuid()) == 0 { return uuid.Nil, nil } - bytes, err := uuid.FromBytes(id.Uuid) + bytes, err := uuid.FromBytes(id.GetUuid()) if err != nil { return uuid.Nil, err } @@ -81,11 +83,11 @@ func protobufToUuid(id *workflowsv1.UUID) (uuid.UUID, error) { } func isEmpty(id *workflowsv1.UUID) bool { - taskId, err := protobufToUuid(id) + taskID, err := protobufToUUID(id) if err != nil { return false } - return taskId == uuid.Nil + return taskID == uuid.Nil } // Run runs the task runner forever, looking for new tasks to run and polling for new tasks when idle. @@ -115,12 +117,12 @@ func (t *TaskRunner) Run(ctx context.Context) { slog.ErrorContext(ctx, "failed to work-steal a task", "error", err) // return // should we even try again, or just stop here? } else { - task = taskResponse.Msg.NextTask + task = taskResponse.Msg.GetNextTask() } } if task != nil { // we have a task to execute - if isEmpty(task.Id) { + if isEmpty(task.GetId()) { slog.ErrorContext(ctx, "got a task without an ID - skipping to the next task") task = nil continue @@ -129,7 +131,7 @@ func (t *TaskRunner) Run(ctx context.Context) { stopExecution := false if err == nil { // in case we got no error, let's mark the task as computed and get the next one computedTask := &workflowsv1.ComputedTask{ - Id: task.Id, + Id: task.GetId(), SubTasks: nil, } if executionContext != nil && len(executionContext.Subtasks) > 0 { @@ -153,7 +155,7 @@ func (t *TaskRunner) Run(ctx context.Context) { slog.ErrorContext(ctx, "failed to mark task as computed, retrying", "error", err) return nil, err } - return taskResponse.Msg.NextTask, nil + return taskResponse.Msg.GetNextTask(), nil }, retry.Context(ctxSignal), retry.DelayType(retry.CombineDelay(retry.BackOffDelay, retry.RandomDelay)), ) if err != nil { @@ -165,7 +167,7 @@ func (t *TaskRunner) Run(ctx context.Context) { err = retry.Do( func() error { _, err := t.Client.TaskFailed(ctx, connect.NewRequest(&workflowsv1.TaskFailedRequest{ - TaskId: task.Id, + TaskId: task.GetId(), CancelJob: true, })) if err != nil { @@ -197,38 +199,37 @@ func (t *TaskRunner) Run(ctx context.Context) { case <-timer.C: // the timer expired, let's try to work-steal a task again } } - } } func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*taskExecutionContext, error) { // start a goroutine to extend the lease of the task continuously until the task execution is finished leaseCtx, stopLeaseExtensions := context.WithCancel(ctx) - go extendTaskLease(leaseCtx, t.Client, task.Id, task.Lease.Lease.AsDuration(), task.Lease.RecommendedWaitUntilNextExtension.AsDuration()) + go extendTaskLease(leaseCtx, t.Client, task.GetId(), task.GetLease().GetLease().AsDuration(), task.GetLease().GetRecommendedWaitUntilNextExtension().AsDuration()) defer stopLeaseExtensions() // actually execute the task - if task.Identifier == nil { - return nil, fmt.Errorf("task has no identifier") + if task.GetIdentifier() == nil { + return nil, errors.New("task has no identifier") } - identifier := TaskIdentifier{Name: task.Identifier.Name, Version: task.Identifier.Version} + identifier := TaskIdentifier{Name: task.GetIdentifier().GetName(), Version: task.GetIdentifier().GetVersion()} taskPrototype, found := t.taskDefinitions[identifier] if !found { - return nil, fmt.Errorf("task %s is not registered on this runner", task.Identifier.Name) + return nil, fmt.Errorf("task %s is not registered on this runner", task.GetIdentifier().GetName()) } - return observability.StartJobSpan(t.tracer, ctx, fmt.Sprintf("task/%s", identifier.Name), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) { + return observability.StartJobSpan(ctx, t.tracer, fmt.Sprintf("task/%s", identifier.Name), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) { slog.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version) taskStruct := reflect.New(reflect.ValueOf(taskPrototype).Elem().Type()).Interface().(ExecutableTask) _, isProtobuf := taskStruct.(proto.Message) if isProtobuf { - err := proto.Unmarshal(task.Input, taskStruct.(proto.Message)) + err := proto.Unmarshal(task.GetInput(), taskStruct.(proto.Message)) if err != nil { return nil, fmt.Errorf("failed to unmarshal protobuf task: %w", err) } } else { - err := json.Unmarshal(task.Input, taskStruct) + err := json.Unmarshal(task.GetInput(), taskStruct) if err != nil { return nil, fmt.Errorf("failed to unmarshal json task: %w", err) } @@ -251,7 +252,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (* // extendTaskLease is a function designed to be run as a goroutine, extending the lease of a task continuously until the // context is cancelled, which indicates that the execution of the task is finished. -func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceClient, taskId *workflowsv1.UUID, initialLease, initialWait time.Duration) { +func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceClient, taskID *workflowsv1.UUID, initialLease, initialWait time.Duration) { wait := initialWait lease := initialLease for { @@ -262,26 +263,26 @@ func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceC return case <-timer.C: // the timer expired, let's try to extend the lease } - slog.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskId.Uuid)), "lease", lease, "wait", wait) + slog.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())), "lease", lease, "wait", wait) req := &workflowsv1.TaskLeaseRequest{ - TaskId: taskId, + TaskId: taskID, RequestedLease: durationpb.New(2 * lease), // double the current lease duration for the next extension } extension, err := client.ExtendTaskLease(ctx, connect.NewRequest(req)) if err != nil { - slog.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskId.Uuid))) + slog.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) // The server probably has an internal error, but there is no point in trying to extend the lease again // because it will be expired then, so let's just return return } - if extension.Msg.Lease == nil { + if extension.Msg.GetLease() == nil { // the server did not return a lease extension, it means that there is no need in trying to extend the lease - slog.DebugContext(ctx, "task lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskId.Uuid))) + slog.DebugContext(ctx, "task lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) return } // will probably be double the previous lease (since we requested that) or capped by the server at maxLeaseDuration - lease = extension.Msg.Lease.AsDuration() - wait = extension.Msg.RecommendedWaitUntilNextExtension.AsDuration() + lease = extension.Msg.GetLease().AsDuration() + wait = extension.Msg.GetRecommendedWaitUntilNextExtension().AsDuration() } } @@ -306,7 +307,7 @@ func getTaskExecutionContext(ctx context.Context) *taskExecutionContext { func SubmitSubtasks(ctx context.Context, tasks ...Task) error { executionContext := getTaskExecutionContext(ctx) if executionContext == nil { - return fmt.Errorf("cannot submit subtask without task execution context") + return errors.New("cannot submit subtask without task execution context") } for _, task := range tasks { diff --git a/workflows/v1/runner_test.go b/workflows/v1/runner_test.go new file mode 100644 index 0000000..9f8f946 --- /dev/null +++ b/workflows/v1/runner_test.go @@ -0,0 +1,297 @@ +package workflows + +import ( + "context" + "reflect" + "testing" + + workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" + "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" +) + +type mockClient struct { + workflowsv1connect.TaskServiceClient +} + +type testTask1 struct { + ExecutableTask +} + +type testTask2 struct { + ExecutableTask +} + +type badIdentifierTask struct { + ExecutableTask +} + +func (t *badIdentifierTask) Identifier() TaskIdentifier { + return TaskIdentifier{Name: "", Version: ""} +} + +func TestTaskRunner_RegisterTask(t *testing.T) { + type args struct { + task ExecutableTask + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Register Task", + args: args{ + task: &testTask1{}, + }, + }, + { + name: "Register Task bad identifier", + args: args{ + task: &badIdentifierTask{}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t1 *testing.T) { + runner := NewTaskRunner(mockClient{}) + + err := runner.RegisterTask(tt.args.task) + if (err != nil) != tt.wantErr { + t1.Errorf("RegisterTask() error = %v, wantErr %v", err, tt.wantErr) + } + + identifier := identifierFromTask(tt.args.task) + _, ok := runner.taskDefinitions[identifier] + if ok && tt.wantErr { + t1.Errorf("RegisterTask() task found in taskDefinitions") + } + if !ok && !tt.wantErr { + t1.Errorf("RegisterTask() task not found in taskDefinitions") + } + }) + } +} + +func TestTaskRunner_RegisterTasks(t *testing.T) { + type args struct { + tasks []ExecutableTask + } + tests := []struct { + name string + args args + wantErr bool + wantTasks []ExecutableTask + }{ + { + name: "Register Tasks no tasks", + args: args{ + tasks: []ExecutableTask{}, + }, + wantTasks: []ExecutableTask{}, + }, + { + name: "Register Tasks duplicated task", + args: args{ + tasks: []ExecutableTask{ + &testTask1{}, + &testTask1{}, + }, + }, + wantTasks: []ExecutableTask{ + &testTask1{}, + }, + }, + { + name: "Register Tasks bad identifier", + args: args{ + tasks: []ExecutableTask{ + &testTask1{}, + &badIdentifierTask{}, + &testTask2{}, + }, + }, + wantErr: true, + wantTasks: []ExecutableTask{ + &testTask1{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t1 *testing.T) { + runner := NewTaskRunner(mockClient{}) + + err := runner.RegisterTasks(tt.args.tasks...) + if (err != nil) != tt.wantErr { + t1.Errorf("RegisterTasks() error = %v, wantErr %v", err, tt.wantErr) + } + + if len(runner.taskDefinitions) != len(tt.wantTasks) { + t1.Errorf("RegisterTasks() taskDefinitions length = %v, want %v", len(runner.taskDefinitions), len(tt.wantTasks)) + } + for _, task := range tt.wantTasks { + identifier := identifierFromTask(task) + _, ok := runner.taskDefinitions[identifier] + if !ok { + t1.Errorf("RegisterTasks() task not found in taskDefinitions") + } + } + }) + } +} + +func Test_isEmpty(t *testing.T) { + type args struct { + id *workflowsv1.UUID + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "isEmpty nil", + args: args{ + id: nil, + }, + want: true, + }, + { + name: "isEmpty not nil but invalid id", + args: args{ + id: &workflowsv1.UUID{Uuid: []byte{1, 2, 3}}, + }, + want: false, + }, + { + name: "isEmpty not nil but valid id", + args: args{ + id: &workflowsv1.UUID{Uuid: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isEmpty(tt.args.id); got != tt.want { + t.Errorf("isEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_withTaskExecutionContextRoundtrip(t *testing.T) { + type args struct { + ctx context.Context + client workflowsv1connect.TaskServiceClient + task *workflowsv1.Task + } + tests := []struct { + name string + args args + want *taskExecutionContext + }{ + { + name: "withTaskExecutionContext", + args: args{ + ctx: context.Background(), + client: mockClient{}, + task: nil, + }, + want: &taskExecutionContext{ + CurrentTask: nil, + Client: mockClient{}, + Subtasks: []*workflowsv1.TaskSubmission{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + updatedCtx := withTaskExecutionContext(tt.args.ctx, tt.args.client, tt.args.task) + got := getTaskExecutionContext(updatedCtx) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("withTaskExecutionContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSubmitSubtasks(t *testing.T) { + type args struct { + tasks []Task + } + tests := []struct { + name string + args args + wantErr bool + wantSubtasks []*workflowsv1.TaskSubmission + }{ + { + name: "Submit no Subtasks", + args: args{ + tasks: []Task{}, + }, + wantSubtasks: []*workflowsv1.TaskSubmission{}, + }, + { + name: "Submit one Subtasks", + args: args{ + tasks: []Task{ + &testTask1{}, + }, + }, + wantSubtasks: []*workflowsv1.TaskSubmission{ + { + ClusterSlug: DefaultClusterSlug, + Identifier: &workflowsv1.TaskIdentifier{Name: "testTask1", Version: "v0.0"}, + Input: []byte("{\"ExecutableTask\":null}"), + Display: "testTask1", + Dependencies: nil, + MaxRetries: 0, + }, + }, + }, + { + name: "Submit bad identifier Subtasks", + args: args{ + tasks: []Task{ + &testTask1{}, + &badIdentifierTask{}, + &testTask2{}, + }, + }, + wantErr: true, + wantSubtasks: []*workflowsv1.TaskSubmission{ + { + ClusterSlug: DefaultClusterSlug, + Identifier: &workflowsv1.TaskIdentifier{Name: "testTask1", Version: "v0.0"}, + Input: []byte("{\"ExecutableTask\":null}"), + Display: "testTask1", + Dependencies: nil, + MaxRetries: 0, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := withTaskExecutionContext(context.Background(), nil, nil) + + err := SubmitSubtasks(ctx, tt.args.tasks...) + if (err != nil) != tt.wantErr { + t.Errorf("SubmitSubtasks() error = %v, wantErr %v", err, tt.wantErr) + } + + te := getTaskExecutionContext(ctx) + if len(te.Subtasks) != len(tt.wantSubtasks) { + t.Errorf("SubmitSubtasks() Subtasks length = %v, want %v", len(te.Subtasks), len(tt.wantSubtasks)) + } + + for i, task := range tt.wantSubtasks { + if !reflect.DeepEqual(te.Subtasks[i], task) { + t.Errorf("SubmitSubtasks() Subtask %v = %v, want %v", i, te.Subtasks[i], task) + } + } + }) + } +} diff --git a/workflows/v1/task.go b/workflows/v1/task.go index bd16bf4..946c277 100644 --- a/workflows/v1/task.go +++ b/workflows/v1/task.go @@ -2,6 +2,7 @@ package workflows import ( "context" + "errors" "fmt" "reflect" "regexp" @@ -36,7 +37,7 @@ type ExplicitlyIdentifiableTask interface { // ExecutableTask is the interface for a task that can be executed, and therefore be registered with a task runner. type ExecutableTask interface { - Execute(context.Context) error + Execute(ctx context.Context) error } func identifierFromTask(task Task) TaskIdentifier { @@ -52,10 +53,10 @@ func identifierFromTask(task Task) TaskIdentifier { // ValidateIdentifier performs client-side validation on a task identifier. func ValidateIdentifier(identifier TaskIdentifier) error { if identifier.Name == "" { - return fmt.Errorf("task name is empty") + return errors.New("task name is empty") } if len(identifier.Name) > 256 { - return fmt.Errorf("task name is too long") + return errors.New("task name is too long") } _, _, err := parseVersion(identifier.Version) if err != nil { @@ -87,9 +88,9 @@ func parseVersion(version string) (int64, int64, error) { // getStructName returns the name of the struct type of a task. If the task is a pointer, the name of the pointed-to type is returned. // This function is used to generate a default identifier name for a task if it doesn't provide an explicit identifier. func getStructName(task interface{}) string { - if t := reflect.TypeOf(task); t.Kind() == reflect.Ptr { + t := reflect.TypeOf(task) + if t.Kind() == reflect.Ptr { return t.Elem().Name() - } else { - return t.Name() } + return t.Name() } diff --git a/workflows/v1/task_test.go b/workflows/v1/task_test.go new file mode 100644 index 0000000..d7d524c --- /dev/null +++ b/workflows/v1/task_test.go @@ -0,0 +1,216 @@ +package workflows + +import ( + "reflect" + "strings" + "testing" +) + +func TestNewTaskIdentifier(t *testing.T) { + type args struct { + name string + version string + } + tests := []struct { + name string + args args + want TaskIdentifier + }{ + { + name: "NewTaskIdentifier", + args: args{ + name: "test", + version: "v0.0", + }, + want: TaskIdentifier{ + Name: "test", + Version: "v0.0", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewTaskIdentifier(tt.args.name, tt.args.version); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTaskIdentifier() = %v, wantMajor %v", got, tt.want) + } + }) + } +} + +type emptyTask struct{} + +type identifiableTask struct{} + +func (t *identifiableTask) Identifier() TaskIdentifier { + return TaskIdentifier{ + Name: "myName", + Version: "v1.2", + } +} + +func Test_identifierFromTask(t *testing.T) { + type args struct { + task Task + } + tests := []struct { + name string + args args + want TaskIdentifier + }{ + { + name: "identifier empty task", + args: args{ + task: &emptyTask{}, + }, + want: TaskIdentifier{ + Name: "emptyTask", + Version: "v0.0", + }, + }, + { + name: "identifier identifiable task", + args: args{ + task: &identifiableTask{}, + }, + want: TaskIdentifier{ + Name: "myName", + Version: "v1.2", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := identifierFromTask(tt.args.task); !reflect.DeepEqual(got, tt.want) { + t.Errorf("identifierFromTask() = %v, wantMajor %v", got, tt.want) + } + }) + } +} + +func TestValidateIdentifier(t *testing.T) { + type args struct { + identifier TaskIdentifier + } + tests := []struct { + name string + args args + wantErr bool + wantErrMessage string + }{ + { + name: "ValidateIdentifier", + args: args{ + identifier: TaskIdentifier{ + Name: "test", + Version: "v0.0", + }, + }, + wantErr: false, + }, + { + name: "ValidateIdentifier name empty", + args: args{ + identifier: TaskIdentifier{ + Name: "", + Version: "v0.0", + }, + }, + wantErr: true, + wantErrMessage: "task name is empty", + }, + { + name: "ValidateIdentifier name too long", + args: args{ + identifier: TaskIdentifier{ + Name: strings.Repeat("a", 257), + Version: "v0.0", + }, + }, + wantErr: true, + wantErrMessage: "task name is too long", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateIdentifier(tt.args.identifier) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateIdentifier() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + if !strings.Contains(err.Error(), tt.wantErrMessage) { + t.Errorf("CreateCluster() error = %v, wantErrMessage %v", err, tt.wantErrMessage) + } + return + } + }) + } +} + +func Test_parseVersion(t *testing.T) { + type args struct { + version string + } + tests := []struct { + name string + args args + wantMajor int64 + wantMinor int64 + wantErr bool + }{ + { + name: "parseVersion v0.0", + args: args{ + version: "v0.0", + }, + wantMajor: 0, + wantMinor: 0, + wantErr: false, + }, + { + name: "parseVersion v2.3", + args: args{ + version: "v2.3", + }, + wantMajor: 2, + wantMinor: 3, + wantErr: false, + }, + { + name: "parseVersion wrong format", + args: args{ + version: "2.3", + }, + wantErr: true, + }, + { + name: "parseVersion wrong major", + args: args{ + version: "vA.3", + }, + wantErr: true, + }, + { + name: "parseVersion wrong minor", + args: args{ + version: "v2.A", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotMajor, gotMinor, err := parseVersion(tt.args.version) + if (err != nil) != tt.wantErr { + t.Errorf("parseVersion() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotMajor != tt.wantMajor { + t.Errorf("parseVersion() gotMajor = %v, wantMajor %v", gotMajor, tt.wantMajor) + } + if gotMinor != tt.wantMinor { + t.Errorf("parseVersion() gotMinor = %v, wantMajor %v", gotMinor, tt.wantMinor) + } + }) + } +}