diff --git a/README.md b/README.md index bd2b160..21a5e50 100644 --- a/README.md +++ b/README.md @@ -56,15 +56,10 @@ import ( "log/slog" "os" - "connectrpc.com/connect" "github.com/google/uuid" - "github.com/tilebox/tilebox-go/grpc" - "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" "github.com/tilebox/tilebox-go/workflows/v1" ) -const serverURL = "https://api.tilebox.com" - type HelloTask struct { Name string } @@ -72,10 +67,13 @@ type HelloTask struct { func main() { ctx := context.Background() - jobsClient := clientFromConfig(serverURL, os.Getenv("TILEBOX_API_KEY")) - jobs := workflows.NewJobService(jobsClient) + jobs := workflows.NewJobService( + workflows.NewJobClient( + workflows.WithAPIKey(os.Getenv("TILEBOX_API_KEY")), + ), + ) - job, err := jobs.Submit(ctx, "hello-world", workflows.DefaultClusterSlug, + job, err := jobs.Submit(ctx, "hello-world", "testing-4qgCk4qHH85qR7", 0, &HelloTask{ Name: "Tilebox", }, @@ -87,16 +85,6 @@ func main() { slog.InfoContext(ctx, "Job submitted", "job_id", uuid.Must(uuid.FromBytes(job.GetId().GetUuid()))) } - -func clientFromConfig(serverURL, authToken string) workflowsv1connect.JobServiceClient { - return workflowsv1connect.NewJobServiceClient( - grpc.RetryHTTPClient(), serverURL, connect.WithInterceptors( - grpc.NewAddAuthTokenInterceptor(func() string { - return authToken - })), - ) -} - ``` ### Running a Worker @@ -111,14 +99,9 @@ import ( "log/slog" "os" - "connectrpc.com/connect" - "github.com/tilebox/tilebox-go/grpc" - "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" "github.com/tilebox/tilebox-go/workflows/v1" ) -const serverURL = "https://api.tilebox.com" - type HelloTask struct { Name string } @@ -130,10 +113,18 @@ func (t *HelloTask) Execute(context.Context) error { } func main() { - taskClient := clientFromConfig(serverURL, os.Getenv("TILEBOX_API_KEY")) - runner := workflows.NewTaskRunner(taskClient) + runner, err := workflows.NewTaskRunner( + workflows.NewTaskClient( + workflows.WithAPIKey(os.Getenv("TILEBOX_API_KEY")), + ), + workflows.WithCluster("testing-4qgCk4qHH85qR7"), + ) + if err != nil { + slog.Error("failed to create task runner", "error", err) + return + } - err := runner.RegisterTasks( + err = runner.RegisterTasks( &HelloTask{}, ) if err != nil { @@ -143,13 +134,4 @@ func main() { runner.Run(context.Background()) } - -func clientFromConfig(serverURL, authToken string) workflowsv1connect.TaskServiceClient { - return workflowsv1connect.NewTaskServiceClient( - grpc.RetryHTTPClient(), serverURL, connect.WithInterceptors( - grpc.NewAddAuthTokenInterceptor(func() string { - return authToken - })), - ) -} ``` diff --git a/go.mod b/go.mod index 324cb69..3c19e6e 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/klauspost/compress v1.17.7 // indirect + github.com/remychantenay/slog-otel v1.3.0 // indirect github.com/stretchr/testify v1.9.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.48.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect diff --git a/go.sum b/go.sum index ee0a61d..13b7a65 100644 --- a/go.sum +++ b/go.sum @@ -102,6 +102,8 @@ github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remychantenay/slog-otel v1.3.0 h1:mppL97agkmwR416lKzltRQ9QRhrPdxwVidt0AnI3Ts4= +github.com/remychantenay/slog-otel v1.3.0/go.mod h1:L2VAe6WOMAk/kRzzuv2B/rWe/IDXAhUNae0919b4kHU= github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo= github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= diff --git a/observability/observability.go b/observability/observability.go index d601bd1..4283594 100644 --- a/observability/observability.go +++ b/observability/observability.go @@ -9,8 +9,8 @@ import ( adapter "github.com/axiomhq/axiom-go/adapters/slog" "github.com/axiomhq/axiom-go/axiom" axiotel "github.com/axiomhq/axiom-go/axiom/otel" + slogotel "github.com/remychantenay/slog-otel" workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" - "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/resource" @@ -21,53 +21,49 @@ import ( var propagator = propagation.TraceContext{} -// AxiomLogHandler returns an Axiom handler for slog. -func AxiomLogHandler(dataset, token string, level slog.Level) (*adapter.Handler, error) { +// NewAxiomLogger returns a slog.Logger that logs to Axiom. +// It also returns a shutdown function that should be called when the logger is no longer needed, to ensure +// all logs are flushed. +func NewAxiomLogger(dataset, token string, level slog.Level) (*slog.Logger, func(), error) { + noShutdown := func() {} client, err := axiom.NewClient(axiom.SetToken(token)) if err != nil { - return nil, err + return nil, noShutdown, err } - return adapter.New( + axiomHandler, err := adapter.New( adapter.SetDataset(dataset), adapter.SetClient(client), adapter.SetLevel(level), ) -} + if err != nil { + return nil, noShutdown, err + } -// AxiomTraceExporter returns an Axiom OpenTelemetry trace exporter. -func AxiomTraceExporter(ctx context.Context, dataset, token string) (trace.SpanExporter, error) { - return axiotel.TraceExporter(ctx, dataset, axiotel.SetToken(token)) + return slog.New(slogotel.OtelHandler{Next: axiomHandler}), axiomHandler.Close, nil } -func SetupOtelTracing(serviceName, serviceVersion string, exporters ...trace.SpanExporter) func(ctx context.Context) { - tp := tracerProvider(serviceName, serviceVersion, exporters) - otel.SetTracerProvider(tp) - - shutDownFunc := func(ctx context.Context) { - _ = tp.Shutdown(ctx) +func NewAxiomTracerProvider(ctx context.Context, dataset, token, serviceName, serviceVersion string) (oteltrace.TracerProvider, func(), error) { + noShutdown := func() {} + exporter, err := axiotel.TraceExporter(ctx, dataset, axiotel.SetToken(token)) + if err != nil { + return nil, noShutdown, err } - return shutDownFunc -} - -// tracerProvider configures and returns a new OpenTelemetry tracer provider. 
-func tracerProvider(serviceName, serviceVersion string, exporters []trace.SpanExporter) *trace.TracerProvider {
-	rs := resource.NewWithAttributes(
+	traceResource := resource.NewWithAttributes(
 		semconv.SchemaURL,
 		semconv.ServiceNameKey.String(serviceName),
 		semconv.ServiceVersionKey.String(serviceVersion),
 	)
-	opts := []trace.TracerProviderOption{
-		trace.WithResource(rs),
-	}
-
-	for _, exporter := range exporters {
-		opts = append(opts, trace.WithBatcher(exporter, trace.WithMaxQueueSize(10*1024)))
+	provider := trace.NewTracerProvider(
+		trace.WithResource(traceResource),
+		trace.WithBatcher(exporter, trace.WithMaxQueueSize(10*1024)),
+	)
+	shutdown := func() {
+		_ = provider.Shutdown(ctx)
 	}
-
-	return trace.NewTracerProvider(opts...)
+	return provider, shutdown, nil
 }
 
 // generateTraceParent generates a random traceparent.
diff --git a/workflows/v1/client.go b/workflows/v1/client.go
new file mode 100644
index 0000000..e31301f
--- /dev/null
+++ b/workflows/v1/client.go
@@ -0,0 +1,103 @@
+package workflows
+
+import (
+	"context"
+	"net"
+	"net/http"
+	"strings"
+
+	"connectrpc.com/connect"
+	"github.com/tilebox/tilebox-go/grpc"
+	"github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect"
+)
+
+// clientConfig contains the configuration for a gRPC client to a workflows service.
+type clientConfig struct {
+	httpClient     connect.HTTPClient
+	url            string
+	apiKey         string
+	connectOptions []connect.ClientOption
+}
+
+// ClientOption is a functional option for configuring a client. Such option helpers are a
+// common pattern in Go, as they allow for optional parameters in constructors.
+// The implementation here is inspired by how libraries such as axiom-go and connect handle
+// their configuration.
+type ClientOption func(*clientConfig)
+
+func WithHTTPClient(httpClient connect.HTTPClient) ClientOption {
+	return func(cfg *clientConfig) {
+		cfg.httpClient = httpClient
+	}
+}
+
+func WithURL(url string) ClientOption {
+	return func(cfg *clientConfig) {
+		cfg.url = url
+	}
+}
+
+func WithAPIKey(apiKey string) ClientOption {
+	return func(cfg *clientConfig) {
+		cfg.apiKey = apiKey
+	}
+}
+
+func WithConnectClientOptions(options ...connect.ClientOption) ClientOption {
+	return func(cfg *clientConfig) {
+		cfg.connectOptions = append(cfg.connectOptions, options...)
+ } +} + +func newClientConfig(options []ClientOption) *clientConfig { + cfg := &clientConfig{ + url: "https://api.tilebox.com", + } + for _, option := range options { + option(cfg) + } + + // if no http client is set by the user, we use a default one + if cfg.httpClient == nil { + // if the URL looks like an HTTP URL, we use a retrying HTTP client + if strings.HasPrefix(cfg.url, "https://") || strings.HasPrefix(cfg.url, "http://") { + cfg.httpClient = grpc.RetryHTTPClient() + } else { // we connect to a unix socket + dial := func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", cfg.url) + } + transport := &http.Transport{DialContext: dial} + cfg.httpClient = &http.Client{Transport: transport} + } + } + + return cfg +} + +func NewTaskClient(options ...ClientOption) workflowsv1connect.TaskServiceClient { + cfg := newClientConfig(options) + + return workflowsv1connect.NewTaskServiceClient( + cfg.httpClient, + cfg.url, + connect.WithClientOptions(cfg.connectOptions...), + connect.WithInterceptors( + grpc.NewAddAuthTokenInterceptor(func() string { + return cfg.apiKey + })), + ) +} + +func NewJobClient(options ...ClientOption) workflowsv1connect.JobServiceClient { + cfg := newClientConfig(options) + + return workflowsv1connect.NewJobServiceClient( + cfg.httpClient, + cfg.url, + connect.WithClientOptions(cfg.connectOptions...), + connect.WithInterceptors( + grpc.NewAddAuthTokenInterceptor(func() string { + return cfg.apiKey + })), + ) +} diff --git a/workflows/v1/jobs.go b/workflows/v1/jobs.go index 607b46d..e377812 100644 --- a/workflows/v1/jobs.go +++ b/workflows/v1/jobs.go @@ -15,15 +15,47 @@ import ( "google.golang.org/protobuf/proto" ) +type jobServiceConfig struct { + tracerProvider trace.TracerProvider + tracerName string +} + +type JobServiceOption func(*jobServiceConfig) + +func WithJobServiceTracerProvider(tracerProvider trace.TracerProvider) JobServiceOption { + return func(cfg *jobServiceConfig) { + cfg.tracerProvider = tracerProvider + } +} + +func WithJobServiceTracerName(tracerName string) JobServiceOption { + return func(cfg *jobServiceConfig) { + cfg.tracerName = tracerName + } +} + type JobService struct { client workflowsv1connect.JobServiceClient tracer trace.Tracer } -func NewJobService(client workflowsv1connect.JobServiceClient) *JobService { +func newJobServiceConfig(options []JobServiceOption) *jobServiceConfig { + cfg := &jobServiceConfig{ + tracerProvider: otel.GetTracerProvider(), // use the global tracer provider by default + tracerName: "tilebox.com/observability", // the default tracer name we use + } + for _, option := range options { + option(cfg) + } + + return cfg +} + +func NewJobService(client workflowsv1connect.JobServiceClient, options ...JobServiceOption) *JobService { + cfg := newJobServiceConfig(options) return &JobService{ client: client, - tracer: otel.Tracer("tilebox.com/observability"), + tracer: cfg.tracerProvider.Tracer(cfg.tracerName), } } diff --git a/workflows/v1/runner.go b/workflows/v1/runner.go index f657327..93b0a51 100644 --- a/workflows/v1/runner.go +++ b/workflows/v1/runner.go @@ -12,13 +12,14 @@ import ( "syscall" "time" + "go.opentelemetry.io/otel" + "connectrpc.com/connect" "github.com/avast/retry-go/v4" "github.com/google/uuid" "github.com/tilebox/tilebox-go/observability" workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" - "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" 
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" @@ -28,25 +29,81 @@ type ContextKeyTaskExecutionType string const ContextKeyTaskExecution ContextKeyTaskExecutionType = "x-tilebox-task-execution-object" -const DefaultClusterSlug = "testing-4qgCk4qHH85qR7" - -// const DefaultClusterSlug = "workflow-dev-EifhUozDpwAJDL" - const pollingInterval = 5 * time.Second const jitterInterval = 5 * time.Second +type taskRunnerConfig struct { + clusterSlug string + tracerProvider trace.TracerProvider + tracerName string + logger *slog.Logger +} + +type TaskRunnerOption func(*taskRunnerConfig) + +func WithCluster(clusterSlug string) TaskRunnerOption { + return func(cfg *taskRunnerConfig) { + cfg.clusterSlug = clusterSlug + } +} + +func WithRunnerTracerProvider(tracerProvider trace.TracerProvider) TaskRunnerOption { + return func(cfg *taskRunnerConfig) { + cfg.tracerProvider = tracerProvider + } +} + +func WithRunnerTracerName(tracerName string) TaskRunnerOption { + return func(cfg *taskRunnerConfig) { + cfg.tracerName = tracerName + } +} + +func WithRunnerLogger(logger *slog.Logger) TaskRunnerOption { + return func(cfg *taskRunnerConfig) { + cfg.logger = logger + } +} + +func newTaskRunnerConfig(options []TaskRunnerOption) (*taskRunnerConfig, error) { + cfg := &taskRunnerConfig{ + tracerProvider: otel.GetTracerProvider(), // use the global tracer provider by default + tracerName: "tilebox.com/observability", // the default tracer name we use + logger: slog.Default(), + } + for _, option := range options { + option(cfg) + } + + if cfg.clusterSlug == "" { + return nil, errors.New("cluster slug is required") + } + + return cfg, nil +} + type TaskRunner struct { - Client workflowsv1connect.TaskServiceClient + client workflowsv1connect.TaskServiceClient taskDefinitions map[taskIdentifier]ExecutableTask - tracer trace.Tracer + + cluster string + tracer trace.Tracer + logger *slog.Logger } -func NewTaskRunner(client workflowsv1connect.TaskServiceClient) *TaskRunner { +func NewTaskRunner(client workflowsv1connect.TaskServiceClient, options ...TaskRunnerOption) (*TaskRunner, error) { + cfg, err := newTaskRunnerConfig(options) + if err != nil { + return nil, err + } return &TaskRunner{ - Client: client, + client: client, taskDefinitions: make(map[taskIdentifier]ExecutableTask), - tracer: otel.Tracer("tilebox.com/observability"), - } + + cluster: cfg.clusterSlug, + tracer: cfg.tracerProvider.Tracer(cfg.tracerName), + logger: cfg.logger, + }, nil } func (t *TaskRunner) RegisterTask(task ExecutableTask) error { @@ -115,11 +172,11 @@ func (t *TaskRunner) Run(ctx context.Context) { for { if task == nil { // if we don't have a task, let's try work-stealing one - taskResponse, err := t.Client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{ - NextTaskToRun: &workflowsv1.NextTaskToRun{ClusterSlug: DefaultClusterSlug, Identifiers: identifiers}, + taskResponse, err := t.client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{ + NextTaskToRun: &workflowsv1.NextTaskToRun{ClusterSlug: t.cluster, Identifiers: identifiers}, })) if err != nil { - slog.ErrorContext(ctx, "failed to work-steal a task", "error", err) + t.logger.ErrorContext(ctx, "failed to work-steal a task", "error", err) // return // should we even try again, or just stop here? 
} else { task = taskResponse.Msg.GetNextTask() @@ -128,7 +185,7 @@ func (t *TaskRunner) Run(ctx context.Context) { if task != nil { // we have a task to execute if isEmpty(task.GetId()) { - slog.ErrorContext(ctx, "got a task without an ID - skipping to the next task") + t.logger.ErrorContext(ctx, "got a task without an ID - skipping to the next task") task = nil continue } @@ -142,7 +199,7 @@ func (t *TaskRunner) Run(ctx context.Context) { if executionContext != nil && len(executionContext.Subtasks) > 0 { computedTask.SubTasks = executionContext.Subtasks } - nextTaskToRun := &workflowsv1.NextTaskToRun{ClusterSlug: DefaultClusterSlug, Identifiers: identifiers} + nextTaskToRun := &workflowsv1.NextTaskToRun{ClusterSlug: t.cluster, Identifiers: identifiers} select { case <-ctxSignal.Done(): // if we got a context cancellation, don't request a new task @@ -153,37 +210,37 @@ func (t *TaskRunner) Run(ctx context.Context) { task, err = retry.DoWithData( func() (*workflowsv1.Task, error) { - taskResponse, err := t.Client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{ + taskResponse, err := t.client.NextTask(ctx, connect.NewRequest(&workflowsv1.NextTaskRequest{ ComputedTask: computedTask, NextTaskToRun: nextTaskToRun, })) if err != nil { - slog.ErrorContext(ctx, "failed to mark task as computed, retrying", "error", err) + t.logger.ErrorContext(ctx, "failed to mark task as computed, retrying", "error", err) return nil, err } return taskResponse.Msg.GetNextTask(), nil }, retry.Context(ctxSignal), retry.DelayType(retry.CombineDelay(retry.BackOffDelay, retry.RandomDelay)), ) if err != nil { - slog.ErrorContext(ctx, "failed to retry NextTask", "error", err) + t.logger.ErrorContext(ctx, "failed to retry NextTask", "error", err) return // we got a cancellation signal, so let's just stop here } } else { // err != nil - slog.ErrorContext(ctx, "task execution failed", "error", err) + t.logger.ErrorContext(ctx, "task execution failed", "error", err) err = retry.Do( func() error { - _, err := t.Client.TaskFailed(ctx, connect.NewRequest(&workflowsv1.TaskFailedRequest{ + _, err := t.client.TaskFailed(ctx, connect.NewRequest(&workflowsv1.TaskFailedRequest{ TaskId: task.GetId(), CancelJob: true, })) if err != nil { - slog.ErrorContext(ctx, "failed to report task failure", "error", err) + t.logger.ErrorContext(ctx, "failed to report task failure", "error", err) return err } return nil }, retry.Context(ctxSignal), retry.DelayType(retry.CombineDelay(retry.BackOffDelay, retry.RandomDelay)), ) if err != nil { - slog.ErrorContext(ctx, "failed to retry TaskFailed", "error", err) + t.logger.ErrorContext(ctx, "failed to retry TaskFailed", "error", err) return // we got a cancellation signal, so let's just stop here } task = nil // reported a task failure, let's work-steal again @@ -193,7 +250,7 @@ func (t *TaskRunner) Run(ctx context.Context) { } } else { // if we didn't get a task, let's wait for a bit and try work-stealing again - slog.DebugContext(ctx, "no task to run") + t.logger.DebugContext(ctx, "no task to run") // instead of time.Sleep we set a timer and select on it, so we still can catch signals like SIGINT timer := time.NewTimer(pollingInterval + rand.N(jitterInterval)) @@ -210,7 +267,7 @@ func (t *TaskRunner) Run(ctx context.Context) { func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (*taskExecutionContext, error) { // start a goroutine to extend the lease of the task continuously until the task execution is finished leaseCtx, stopLeaseExtensions := 
context.WithCancel(ctx) - go extendTaskLease(leaseCtx, t.Client, task.GetId(), task.GetLease().GetLease().AsDuration(), task.GetLease().GetRecommendedWaitUntilNextExtension().AsDuration()) + go t.extendTaskLease(leaseCtx, t.client, task.GetId(), task.GetLease().GetLease().AsDuration(), task.GetLease().GetRecommendedWaitUntilNextExtension().AsDuration()) defer stopLeaseExtensions() // actually execute the task @@ -224,7 +281,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (* } return observability.StartJobSpan(ctx, t.tracer, fmt.Sprintf("task/%s", identifier.Name()), task.GetJob(), func(ctx context.Context) (*taskExecutionContext, error) { - slog.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version) + t.logger.DebugContext(ctx, "executing task", "task", identifier.Name, "version", identifier.Version) taskStruct := reflect.New(reflect.ValueOf(taskPrototype).Elem().Type()).Interface().(ExecutableTask) _, isProtobuf := taskStruct.(proto.Message) @@ -240,7 +297,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (* } } - executionContext := withTaskExecutionContext(ctx, t.Client, task) + executionContext := t.withTaskExecutionContext(ctx, task) err := taskStruct.Execute(executionContext) if r := recover(); r != nil { // recover from panics during task executions, so we can still report the error to the server and continue @@ -257,7 +314,7 @@ func (t *TaskRunner) executeTask(ctx context.Context, task *workflowsv1.Task) (* // extendTaskLease is a function designed to be run as a goroutine, extending the lease of a task continuously until the // context is cancelled, which indicates that the execution of the task is finished. -func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceClient, taskID *workflowsv1.UUID, initialLease, initialWait time.Duration) { +func (t *TaskRunner) extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceClient, taskID *workflowsv1.UUID, initialLease, initialWait time.Duration) { wait := initialWait lease := initialLease for { @@ -268,21 +325,21 @@ func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceC return case <-timer.C: // the timer expired, let's try to extend the lease } - slog.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())), "lease", lease, "wait", wait) + t.logger.DebugContext(ctx, "extending task lease", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid())), "lease", lease, "wait", wait) req := &workflowsv1.TaskLeaseRequest{ TaskId: taskID, RequestedLease: durationpb.New(2 * lease), // double the current lease duration for the next extension } extension, err := client.ExtendTaskLease(ctx, connect.NewRequest(req)) if err != nil { - slog.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) + t.logger.ErrorContext(ctx, "failed to extend task lease", "error", err, "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) // The server probably has an internal error, but there is no point in trying to extend the lease again // because it will be expired then, so let's just return return } if extension.Msg.GetLease() == nil { // the server did not return a lease extension, it means that there is no need in trying to extend the lease - slog.DebugContext(ctx, "task lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) + t.logger.DebugContext(ctx, "task 
lease extension not granted", "task_id", uuid.Must(uuid.FromBytes(taskID.GetUuid()))) return } // will probably be double the previous lease (since we requested that) or capped by the server at maxLeaseDuration @@ -293,14 +350,14 @@ func extendTaskLease(ctx context.Context, client workflowsv1connect.TaskServiceC type taskExecutionContext struct { CurrentTask *workflowsv1.Task - Client workflowsv1connect.TaskServiceClient + runner *TaskRunner Subtasks []*workflowsv1.TaskSubmission } -func withTaskExecutionContext(ctx context.Context, client workflowsv1connect.TaskServiceClient, task *workflowsv1.Task) context.Context { +func (t *TaskRunner) withTaskExecutionContext(ctx context.Context, task *workflowsv1.Task) context.Context { return context.WithValue(ctx, ContextKeyTaskExecution, &taskExecutionContext{ CurrentTask: task, - Client: client, + runner: t, Subtasks: make([]*workflowsv1.TaskSubmission, 0), }) } @@ -339,7 +396,7 @@ func SubmitSubtasks(ctx context.Context, tasks ...Task) error { } executionContext.Subtasks = append(executionContext.Subtasks, &workflowsv1.TaskSubmission{ - ClusterSlug: DefaultClusterSlug, + ClusterSlug: executionContext.runner.cluster, Identifier: &workflowsv1.TaskIdentifier{ Name: identifier.Name(), Version: identifier.Version(), diff --git a/workflows/v1/runner_test.go b/workflows/v1/runner_test.go index 172d87d..e2f6452 100644 --- a/workflows/v1/runner_test.go +++ b/workflows/v1/runner_test.go @@ -5,6 +5,8 @@ import ( "reflect" "testing" + "github.com/google/uuid" + workflowsv1 "github.com/tilebox/tilebox-go/protogen/go/workflows/v1" "github.com/tilebox/tilebox-go/protogen/go/workflows/v1/workflowsv1connect" ) @@ -54,9 +56,13 @@ func TestTaskRunner_RegisterTask(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t1 *testing.T) { - runner := NewTaskRunner(mockClient{}) + runner, err := NewTaskRunner(mockClient{}, WithCluster("testing-4qgCk4qHH85qR7")) + if err != nil { + t1.Fatalf("Failed to create TaskRunner: %v", err) + return + } - err := runner.RegisterTask(tt.args.task) + err = runner.RegisterTask(tt.args.task) if (err != nil) != tt.wantErr { t1.Errorf("RegisterTask() error = %v, wantErr %v", err, tt.wantErr) } @@ -119,9 +125,13 @@ func TestTaskRunner_RegisterTasks(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t1 *testing.T) { - runner := NewTaskRunner(mockClient{}) + runner, err := NewTaskRunner(mockClient{}, WithCluster("testing-4qgCk4qHH85qR7")) + if err != nil { + t1.Fatalf("Failed to create TaskRunner: %v", err) + return + } - err := runner.RegisterTasks(tt.args.tasks...) + err = runner.RegisterTasks(tt.args.tasks...) 
if (err != nil) != tt.wantErr { t1.Errorf("RegisterTasks() error = %v, wantErr %v", err, tt.wantErr) } @@ -181,42 +191,57 @@ func Test_isEmpty(t *testing.T) { } func Test_withTaskExecutionContextRoundtrip(t *testing.T) { + cluster := "testing-4qgCk4qHH85qR7" + + runner, err := NewTaskRunner(mockClient{}, WithCluster(cluster)) + if err != nil { + t.Fatalf("Failed to create TaskRunner: %v", err) + } + + taskID := uuid.New() + type args struct { - ctx context.Context - client workflowsv1connect.TaskServiceClient - task *workflowsv1.Task + ctx context.Context + task *workflowsv1.Task } tests := []struct { name string args args - want *taskExecutionContext }{ { name: "withTaskExecutionContext", args: args{ - ctx: context.Background(), - client: mockClient{}, - task: nil, - }, - want: &taskExecutionContext{ - CurrentTask: nil, - Client: mockClient{}, - Subtasks: []*workflowsv1.TaskSubmission{}, + ctx: context.Background(), + task: &workflowsv1.Task{ + Id: &workflowsv1.UUID{Uuid: taskID[:]}, + }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - updatedCtx := withTaskExecutionContext(tt.args.ctx, tt.args.client, tt.args.task) + updatedCtx := runner.withTaskExecutionContext(tt.args.ctx, tt.args.task) got := getTaskExecutionContext(updatedCtx) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("withTaskExecutionContext() = %v, want %v", got, tt.want) + if len(got.Subtasks) != 0 { + t.Errorf("withTaskExecutionContext() Subtasks length = %v, want 0", len(got.Subtasks)) + } + if got.CurrentTask.GetId() != tt.args.task.GetId() { + t.Errorf("withTaskExecutionContext() CurrentTask = %v, want %v", got.CurrentTask, tt.args.task) } }) } } func TestSubmitSubtasks(t *testing.T) { + cluster := "testing-4qgCk4qHH85qR7" + + runner, err := NewTaskRunner(mockClient{}, WithCluster(cluster)) + if err != nil { + t.Fatalf("Failed to create TaskRunner: %v", err) + } + + currentTaskID := uuid.New() + type args struct { tasks []Task } @@ -242,7 +267,7 @@ func TestSubmitSubtasks(t *testing.T) { }, wantSubtasks: []*workflowsv1.TaskSubmission{ { - ClusterSlug: DefaultClusterSlug, + ClusterSlug: cluster, Identifier: &workflowsv1.TaskIdentifier{Name: "testTask1", Version: "v0.0"}, Input: []byte("{\"ExecutableTask\":null}"), Display: "testTask1", @@ -263,7 +288,7 @@ func TestSubmitSubtasks(t *testing.T) { wantErr: true, wantSubtasks: []*workflowsv1.TaskSubmission{ { - ClusterSlug: DefaultClusterSlug, + ClusterSlug: cluster, Identifier: &workflowsv1.TaskIdentifier{Name: "testTask1", Version: "v0.0"}, Input: []byte("{\"ExecutableTask\":null}"), Display: "testTask1", @@ -275,7 +300,9 @@ func TestSubmitSubtasks(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := withTaskExecutionContext(context.Background(), nil, nil) + ctx := runner.withTaskExecutionContext(context.Background(), &workflowsv1.Task{ + Id: &workflowsv1.UUID{Uuid: currentTaskID[:]}, + }) err := SubmitSubtasks(ctx, tt.args.tasks...) if (err != nil) != tt.wantErr {