Skip to content

Commit

Permalink
Merge branch 'master' into expose-metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource authored Jan 17, 2025
2 parents 925c6f0 + 454a8a8 commit 7b4ccfe
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 15 deletions.
10 changes: 9 additions & 1 deletion internal/internal_workflow_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1972,9 +1972,10 @@ func (w *workflowClientInterceptor) SignalWorkflow(ctx context.Context, in *Clie
return err
}

links, _ := ctx.Value(NexusOperationLinksKey).([]*commonpb.Link)

request := &workflowservice.SignalWorkflowExecutionRequest{
Namespace: w.client.namespace,
RequestId: uuid.New(),
WorkflowExecution: &commonpb.WorkflowExecution{
WorkflowId: in.WorkflowID,
RunId: in.RunID,
Expand All @@ -1983,6 +1984,13 @@ func (w *workflowClientInterceptor) SignalWorkflow(ctx context.Context, in *Clie
Input: input,
Identity: w.client.identity,
Header: header,
Links: links,
}

if requestID, ok := ctx.Value(NexusOperationRequestIDKey).(string); ok && requestID != "" {
request.RequestId = requestID
} else {
request.RequestId = uuid.New()
}

grpcCtx, cancel := newGRPCContext(ctx, defaultGrpcRetryParameters(ctx))
Expand Down
8 changes: 8 additions & 0 deletions internal/nexus_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ type isWorkflowRunOpContextKeyType struct{}
// panic as we don't want to expose a partial client to sync operations.
var IsWorkflowRunOpContextKey = isWorkflowRunOpContextKeyType{}

type nexusOperationRequestIDKeyType struct{}

var NexusOperationRequestIDKey = nexusOperationRequestIDKeyType{}

type nexusOperationLinksKeyType struct{}

var NexusOperationLinksKey = nexusOperationLinksKeyType{}

// NexusOperationContextFromGoContext gets the [NexusOperationContext] associated with the given [context.Context].
func NexusOperationContextFromGoContext(ctx context.Context) (nctx *NexusOperationContext, ok bool) {
nctx, ok = ctx.Value(nexusOperationContextKey).(*NexusOperationContext)
Expand Down
75 changes: 61 additions & 14 deletions temporalnexus/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import (
"github.com/nexus-rpc/sdk-go/nexus"
"go.temporal.io/api/common/v1"
"go.temporal.io/api/enums/v1"

"go.temporal.io/sdk/client"
"go.temporal.io/sdk/internal"
"go.temporal.io/sdk/internal/common/metrics"
Expand Down Expand Up @@ -94,6 +93,46 @@ func NewSyncOperation[I any, O any](
}
}

// SignalWorkflowInput encapsulates the values required to send a signal to a workflow.
//
// NOTE: Experimental
type SignalWorkflowInput struct {
// WorkflowID is the ID of the workflow which will receive the signal. Required.
WorkflowID string
// RunID is the run ID of the workflow which will receive the signal. Optional. If empty, the signal will be
// delivered to the running execution of the indicated workflow ID.
RunID string
// SignalName is the name of the signal. Required.
SignalName string
// Arg is the payload attached to the signal. Optional.
Arg any
}

// NewWorkflowSignalOperation is a helper for creating a synchronous nexus.Operation to deliver a signal, linking the
// signal to a Nexus operation. Request ID from the Nexus options is propagated to the workflow to ensure idempotency.
//
// NOTE: Experimental
func NewWorkflowSignalOperation[T any](
name string,
getSignalInput func(context.Context, T, nexus.StartOperationOptions) SignalWorkflowInput,
) nexus.Operation[T, nexus.NoValue] {
return NewSyncOperation(name, func(ctx context.Context, c client.Client, in T, options nexus.StartOperationOptions) (nexus.NoValue, error) {
signalInput := getSignalInput(ctx, in, options)

if options.RequestID != "" {
ctx = context.WithValue(ctx, internal.NexusOperationRequestIDKey, options.RequestID)
}

links, err := convertNexusLinks(options.Links, GetLogger(ctx))
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, internal.NexusOperationLinksKey, links)

return nil, c.SignalWorkflow(ctx, signalInput.WorkflowID, signalInput.RunID, signalInput.SignalName, signalInput.Arg)
})
}

func (o *syncOperation[I, O]) Name() string {
return o.name
}
Expand Down Expand Up @@ -360,8 +399,26 @@ func ExecuteUntypedWorkflow[R any](
})
}

links, err := convertNexusLinks(nexusOptions.Links, nctx.Log)
if err != nil {
return nil, err
}
internal.SetLinksOnStartWorkflowOptions(&startWorkflowOptions, links)

run, err := nctx.Client.ExecuteWorkflow(ctx, startWorkflowOptions, workflowType, args...)
if err != nil {
return nil, err
}
return workflowHandle[R]{
namespace: nctx.Namespace,
id: run.GetID(),
runID: run.GetRunID(),
}, nil
}

func convertNexusLinks(nexusLinks []nexus.Link, log log.Logger) ([]*common.Link, error) {
var links []*common.Link
for _, nexusLink := range nexusOptions.Links {
for _, nexusLink := range nexusLinks {
switch nexusLink.Type {
case string((&common.Link_WorkflowEvent{}).ProtoReflect().Descriptor().FullName()):
link, err := ConvertNexusLinkToLinkWorkflowEvent(nexusLink)
Expand All @@ -374,18 +431,8 @@ func ExecuteUntypedWorkflow[R any](
},
})
default:
nctx.Log.Warn("ignoring unsupported link data type: %q", nexusLink.Type)
log.Warn("ignoring unsupported link data type: %q", nexusLink.Type)
}
}
internal.SetLinksOnStartWorkflowOptions(&startWorkflowOptions, links)

run, err := nctx.Client.ExecuteWorkflow(ctx, startWorkflowOptions, workflowType, args...)
if err != nil {
return nil, err
}
return workflowHandle[R]{
namespace: nctx.Namespace,
id: run.GetID(),
runID: run.GetRunID(),
}, nil
return links, nil
}
121 changes: 121 additions & 0 deletions test/nexus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ func waitForCancelWorkflow(ctx workflow.Context, ownID string) (string, error) {
return "", workflow.Await(ctx, func() bool { return false })
}

func waitForSignalWorkflow(ctx workflow.Context, _ string) (string, error) {
ch := workflow.GetSignalChannel(ctx, "nexus-signal")
var val string
ch.Receive(ctx, &val)
return val, ctx.Err()
}

var workflowOp = temporalnexus.NewWorkflowRunOperation(
"workflow-op",
waitForCancelWorkflow,
Expand Down Expand Up @@ -550,6 +557,120 @@ func TestSyncOperationFromWorkflow(t *testing.T) {
})
}

func TestSignalOperationFromWorkflow(t *testing.T) {
receiverID := "nexus-signal-receiver-" + uuid.NewString()

op := temporalnexus.NewWorkflowSignalOperation("signal-operation", func(_ context.Context, input string, _ nexus.StartOperationOptions) temporalnexus.SignalWorkflowInput {
return temporalnexus.SignalWorkflowInput{
WorkflowID: receiverID,
SignalName: "nexus-signal",
Arg: input,
}
})

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
tc := newTestContext(t, ctx)

senderWF := func(ctx workflow.Context) error {
c := workflow.NewNexusClient(tc.endpoint, "test")
fut := c.ExecuteOperation(ctx, op, "nexus", workflow.NexusOperationOptions{})

var exec workflow.NexusOperationExecution
if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil {
return fmt.Errorf("expected start to succeed: %w", err)
}
if exec.OperationID != "" {
return fmt.Errorf("expected empty operation ID")
}

return fut.Get(ctx, nil)
}

w := worker.New(tc.client, tc.taskQueue, worker.Options{})
service := nexus.NewService("test")
require.NoError(t, service.Register(op))
w.RegisterNexusService(service)
w.RegisterWorkflow(waitForSignalWorkflow)
w.RegisterWorkflow(senderWF)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)

receiver, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{
ID: receiverID,
TaskQueue: tc.taskQueue,
// The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task
// timeout to speed up the attempts.
WorkflowTaskTimeout: time.Second,
}, waitForSignalWorkflow, "successful")
require.NoError(t, err)

sender, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{
TaskQueue: tc.taskQueue,
// The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task
// timeout to speed up the attempts.
WorkflowTaskTimeout: time.Second,
}, senderWF)
require.NoError(t, err)
require.NoError(t, sender.Get(ctx, nil))

iter := tc.client.GetWorkflowHistory(
ctx,
sender.GetID(),
sender.GetRunID(),
false,
enums.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT,
)
var nexusOperationScheduleEventID int64
var targetEvent *historypb.HistoryEvent
for iter.HasNext() {
event, err := iter.Next()
require.NoError(t, err)
if event.GetEventType() == enums.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED {
nexusOperationScheduleEventID = event.GetEventId()
require.NotEmpty(t, event.GetNexusOperationScheduledEventAttributes().GetRequestId())
break
}
}

var out string
require.NoError(t, receiver.Get(ctx, &out))
require.Equal(t, "nexus", out)

iter = tc.client.GetWorkflowHistory(
ctx,
receiver.GetID(),
receiver.GetRunID(),
false,
enums.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT,
)
for iter.HasNext() {
event, err := iter.Next()
require.NoError(t, err)
if event.GetEventType() == enums.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED {
targetEvent = event
break
}
}
require.NotNil(t, targetEvent)
require.NotNil(t, targetEvent.GetWorkflowExecutionSignaledEventAttributes())
require.Len(t, targetEvent.GetLinks(), 1)
require.True(t, proto.Equal(
&common.Link_WorkflowEvent{
Namespace: tc.testConfig.Namespace,
WorkflowId: sender.GetID(),
RunId: sender.GetRunID(),
Reference: &common.Link_WorkflowEvent_EventRef{
EventRef: &common.Link_WorkflowEvent_EventReference{
EventId: nexusOperationScheduleEventID,
EventType: enums.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED,
},
},
},
targetEvent.GetLinks()[0].GetWorkflowEvent(),
))
}

func TestAsyncOperationFromWorkflow(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
Expand Down

0 comments on commit 7b4ccfe

Please sign in to comment.