Skip to content

Commit

Permalink
Allow errors in signal interceptors (#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitarb authored Oct 8, 2021
1 parent b52191f commit 2c0af69
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 39 deletions.
6 changes: 3 additions & 3 deletions internal/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type WorkflowInboundCallsInterceptor interface {

// ProcessSignal is called before the signal is passed to the workflow implementation, note that this function does NOT
// have any flow control and can not modify the signal or prevent it from being passed to the workflow.
ProcessSignal(ctx Context, signalName string, arg interface{})
ProcessSignal(ctx Context, signalName string, arg interface{}) error

// HandleQuery is invoked when query request is received, this function HAS flow control and can alter parameters
// or values returned by the query. Handler that is passed as a parameter MUST be called in order to execute the query.
Expand Down Expand Up @@ -117,8 +117,8 @@ func (w WorkflowInboundCallsInterceptorBase) ExecuteWorkflow(ctx Context, workfl
}

// ProcessSignal process inbound signal notification
func (w WorkflowInboundCallsInterceptorBase) ProcessSignal(ctx Context, signalName string, arg interface{}) {
w.Next.ProcessSignal(ctx, signalName, arg)
func (w WorkflowInboundCallsInterceptorBase) ProcessSignal(ctx Context, signalName string, arg interface{}) error {
return w.Next.ProcessSignal(ctx, signalName, arg)
}

// HandleQuery handles inbound query request
Expand Down
14 changes: 7 additions & 7 deletions internal/internal_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ type (
currentReplayTime time.Time // Indicates current replay time of the command.
currentLocalTime time.Time // Local time when currentReplayTime was updated.

completeHandler completionHandler // events completion handler
cancelHandler func() // A cancel handler to be invoked on a cancel notification
signalHandler func(name string, input *commonpb.Payloads) // A signal handler to be invoked on a signal event
completeHandler completionHandler // events completion handler
cancelHandler func() // A cancel handler to be invoked on a cancel notification
signalHandler func(name string, input *commonpb.Payloads) error // A signal handler to be invoked on a signal event
queryHandler func(queryType string, queryArgs *commonpb.Payloads) (*commonpb.Payloads, error)

logger log.Logger
Expand Down Expand Up @@ -420,7 +420,7 @@ func (wc *workflowEnvironmentImpl) ExecuteChildWorkflow(
tagWorkflowType, params.WorkflowType.Name)
}

func (wc *workflowEnvironmentImpl) RegisterSignalHandler(handler func(name string, input *commonpb.Payloads)) {
func (wc *workflowEnvironmentImpl) RegisterSignalHandler(handler func(name string, input *commonpb.Payloads) error) {
wc.signalHandler = handler
}

Expand Down Expand Up @@ -863,7 +863,7 @@ func (weh *workflowExecutionEventHandlerImpl) ProcessEvent(
// No Operation.

case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED:
weh.handleWorkflowExecutionSignaled(event.GetWorkflowExecutionSignaledEventAttributes())
err = weh.handleWorkflowExecutionSignaled(event.GetWorkflowExecutionSignaledEventAttributes())

case enumspb.EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED:
signalID := event.GetSignalExternalWorkflowExecutionInitiatedEventAttributes().Control
Expand Down Expand Up @@ -1222,8 +1222,8 @@ func (weh *workflowExecutionEventHandlerImpl) ProcessLocalActivityResult(lar *lo
}

func (weh *workflowExecutionEventHandlerImpl) handleWorkflowExecutionSignaled(
attributes *historypb.WorkflowExecutionSignaledEventAttributes) {
weh.signalHandler(attributes.GetSignalName(), attributes.Input)
attributes *historypb.WorkflowExecutionSignaledEventAttributes) error {
return weh.signalHandler(attributes.GetSignalName(), attributes.Input)
}

func (weh *workflowExecutionEventHandlerImpl) handleStartChildWorkflowExecutionFailed(event *historypb.HistoryEvent) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type (
GetLogger() log.Logger
GetMetricsScope() tally.Scope
// Must be called before WorkflowDefinition.Execute returns
RegisterSignalHandler(handler func(name string, input *commonpb.Payloads))
RegisterSignalHandler(handler func(name string, input *commonpb.Payloads) error)
SignalExternalWorkflow(namespace, workflowID, runID, signalName string, input *commonpb.Payloads, arg interface{}, childWorkflowOnly bool, callback ResultHandler)
RegisterQueryHandler(handler func(queryType string, queryArgs *commonpb.Payloads) (*commonpb.Payloads, error))
IsReplaying() bool
Expand Down
8 changes: 6 additions & 2 deletions internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,17 +511,21 @@ func (d *syncWorkflowDefinition) Execute(env WorkflowEnvironment, header *common
d.cancel()
})

getWorkflowEnvironment(d.rootCtx).RegisterSignalHandler(func(name string, input *commonpb.Payloads) {
getWorkflowEnvironment(d.rootCtx).RegisterSignalHandler(func(name string, input *commonpb.Payloads) error {
eo := getWorkflowEnvOptions(d.rootCtx)
// Notify interceptor, note that because channel send operation below is non-blocking, we can not pass a full closure
// that would encapsulate signal processing. Hence this call is just a "forked" notification that signal has been received.
envInterceptor.inboundInterceptor.ProcessSignal(d.rootCtx, name, input)
err := envInterceptor.inboundInterceptor.ProcessSignal(d.rootCtx, name, input)
if err != nil {
return err
}
// We don't want this code to be blocked ever, using sendAsync().
ch := eo.getSignalChannel(d.rootCtx, name).(*channelImpl)
ok := ch.SendAsync(input)
if !ok {
panic(fmt.Sprintf("Exceeded channel buffer size for signal: %v", name))
}
return nil
})

getWorkflowEnvironment(d.rootCtx).RegisterQueryHandler(func(queryType string, queryArgs *commonpb.Payloads) (*commonpb.Payloads, error) {
Expand Down
12 changes: 6 additions & 6 deletions internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ type (
openSessions map[string]*SessionInfo

workflowCancelHandler func()
signalHandler func(name string, input *commonpb.Payloads)
signalHandler func(name string, input *commonpb.Payloads) error
queryHandler func(string, *commonpb.Payloads) (*commonpb.Payloads, error)
startedHandler func(r WorkflowExecution, e error)

Expand Down Expand Up @@ -1982,7 +1982,7 @@ func (env *testWorkflowEnvironmentImpl) RegisterCancelHandler(handler func()) {
env.workflowCancelHandler = handler
}

func (env *testWorkflowEnvironmentImpl) RegisterSignalHandler(handler func(name string, input *commonpb.Payloads)) {
func (env *testWorkflowEnvironmentImpl) RegisterSignalHandler(handler func(name string, input *commonpb.Payloads) error) {
env.signalHandler = handler
}

Expand Down Expand Up @@ -2061,8 +2061,8 @@ func (env *testWorkflowEnvironmentImpl) SignalExternalWorkflow(namespace, workfl
err := newUnknownExternalWorkflowExecutionError()
callback(nil, err)
} else {
childEnv.signalHandler(signalName, input)
callback(nil, nil)
err := childEnv.signalHandler(signalName, input)
callback(nil, err)
}
childEnv.postCallback(func() {}, true) // resume child workflow since a signal is sent.
return
Expand Down Expand Up @@ -2249,7 +2249,7 @@ func (env *testWorkflowEnvironmentImpl) signalWorkflow(name string, input interf
panic(err)
}
env.postCallback(func() {
env.signalHandler(name, data)
_ = env.signalHandler(name, data)
}, startWorkflowTask)
}

Expand All @@ -2264,7 +2264,7 @@ func (env *testWorkflowEnvironmentImpl) signalWorkflowByID(workflowID, signalNam
return serviceerror.NewNotFound(fmt.Sprintf("Workflow %v already completed", workflowID))
}
workflowHandle.env.postCallback(func() {
workflowHandle.env.signalHandler(signalName, data)
_ = workflowHandle.env.signalHandler(signalName, data)
}, true)
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ func NewFuture(ctx Context) (Future, Settable) {
return impl, impl
}

func (wc *workflowEnvironmentInterceptor) ProcessSignal(Context, string, interface{}) {
// no op
func (wc *workflowEnvironmentInterceptor) ProcessSignal(Context, string, interface{}) error {
return nil
}

func (wc *workflowEnvironmentInterceptor) HandleQuery(_ Context, _ string, args *commonpb.Payloads,
Expand Down
4 changes: 2 additions & 2 deletions test/bindings_workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ type SingleActivityWorkflowDefinition struct {

func (d *SingleActivityWorkflowDefinition) Execute(env bindings.WorkflowEnvironment, header *commonpb.Header, input *commonpb.Payloads) {
var signalInput string
env.RegisterSignalHandler(func(name string, input *commonpb.Payloads) {
_ = converter.GetDefaultDataConverter().FromPayloads(input, &signalInput)
env.RegisterSignalHandler(func(name string, input *commonpb.Payloads) error {
return converter.GetDefaultDataConverter().FromPayloads(input, &signalInput)
})
d.callbacks = append(d.callbacks, func() {
env.NewTimer(time.Second, d.addCallback(func(result *commonpb.Payloads, err error) {
Expand Down
96 changes: 80 additions & 16 deletions test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,18 @@ const (
type IntegrationTestSuite struct {
*require.Assertions
suite.Suite
config Config
client client.Client
activities *Activities
workflows *Workflows
worker worker.Worker
seq int64
taskQueueName string
tracer *tracingInterceptor
trafficController *test.SimpleTrafficController
metricsScopeCloser io.Closer
metricsReporter *metrics.CapturingStatsReporter
config Config
client client.Client
activities *Activities
workflows *Workflows
worker worker.Worker
seq int64
taskQueueName string
tracer *tracingInterceptor
inboundSignalInterceptor *signalInterceptor
trafficController *test.SimpleTrafficController
metricsScopeCloser io.Closer
metricsReporter *metrics.CapturingStatsReporter
}

func TestIntegrationSuite(t *testing.T) {
Expand Down Expand Up @@ -148,8 +149,10 @@ func (ts *IntegrationTestSuite) SetupTest() {
ts.activities.clearInvoked()
ts.taskQueueName = fmt.Sprintf("tq-%v-%s", ts.seq, ts.T().Name())
ts.tracer = newTracingInterceptor()
ts.inboundSignalInterceptor = newSignalInterceptor()
workflowInterceptors := []interceptors.WorkflowInterceptor{ts.tracer, ts.inboundSignalInterceptor}
options := worker.Options{
WorkflowInterceptorChainFactories: []interceptors.WorkflowInterceptor{ts.tracer},
WorkflowInterceptorChainFactories: workflowInterceptors,
WorkflowPanicPolicy: worker.FailWorkflow,
}

Expand Down Expand Up @@ -562,15 +565,39 @@ func (ts *IntegrationTestSuite) TestSignalWorkflow() {
ts.tracer.GetTrace("SignalWorkflow"))
}

func (ts *IntegrationTestSuite) TestSignalWorkflowWithInterceptorError() {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
defer cancel()

// Return error 3 times from the interceptor
ts.inboundSignalInterceptor.ReturnErrorTimes = 3
wfOpts := ts.startWorkflowOptions("test-signal-workflow-interceptor-error")
run, err := ts.client.ExecuteWorkflow(ctx, wfOpts, ts.workflows.SignalWorkflow)
ts.Nil(err)
err = ts.client.SignalWorkflow(ctx, "test-signal-workflow-interceptor-error", run.GetRunID(), "string-signal", "string-value")
ts.NoError(err)

wt := &commonpb.WorkflowType{Name: "workflow-type"}
err = ts.client.SignalWorkflow(ctx, "test-signal-workflow-interceptor-error", run.GetRunID(), "proto-signal", wt)
ts.NoError(err)

var protoValue *commonpb.WorkflowType
err = run.Get(ctx, &protoValue)
// Workflow should succeed after retries upon an error in the signal interceptor
ts.NoError(err)
// Expect that interceptors were called as many times as 2 signals plus the number of times error was induced into the chain.
ts.Equal(2+ts.inboundSignalInterceptor.ReturnErrorTimes, ts.inboundSignalInterceptor.TimesInvoked)
}

func (ts *IntegrationTestSuite) TestSignalWorkflowWithStubbornGrpcError() {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
defer cancel()

ts.trafficController.AddError("SignalWorkflowExecution", serviceerror.NewInternal("server failure"), test.FailAllAttempts)
wfOpts := ts.startWorkflowOptions("test-signal-workflow")
wfOpts := ts.startWorkflowOptions("test-signal-workflow-grpc-error")
run, err := ts.client.ExecuteWorkflow(ctx, wfOpts, ts.workflows.SignalWorkflow)
ts.Nil(err)
err = ts.client.SignalWorkflow(ctx, "test-signal-workflow", run.GetRunID(), "string-signal", "string-value")
err = ts.client.SignalWorkflow(ctx, "test-signal-workflow-grpc-error", run.GetRunID(), "string-signal", "string-value")
ts.Error(err)
ts.Equal("context deadline exceeded", err.Error())
}
Expand Down Expand Up @@ -1332,9 +1359,9 @@ func (t *tracingInboundCallsInterceptor) ExecuteWorkflow(ctx workflow.Context, w
return result
}

func (t *tracingInboundCallsInterceptor) ProcessSignal(ctx workflow.Context, signalName string, arg interface{}) {
func (t *tracingInboundCallsInterceptor) ProcessSignal(ctx workflow.Context, signalName string, arg interface{}) error {
t.trace = append(t.trace, "ProcessSignal")
t.Next.ProcessSignal(ctx, signalName, arg)
return t.Next.ProcessSignal(ctx, signalName, arg)
}

func (t *tracingInboundCallsInterceptor) HandleQuery(ctx workflow.Context, queryType string, args *commonpb.Payloads,
Expand All @@ -1345,6 +1372,43 @@ func (t *tracingInboundCallsInterceptor) HandleQuery(ctx workflow.Context, query
return result, err
}

var _ interceptors.WorkflowInterceptor = (*signalInterceptor)(nil)
var _ interceptors.WorkflowInboundCallsInterceptor = (*signalInboundCallsInterceptor)(nil)
var _ interceptors.WorkflowOutboundCallsInterceptor = (*signalOutboundCallsInterceptor)(nil)

type signalInterceptor struct {
ReturnErrorTimes int
TimesInvoked int
}

func newSignalInterceptor() *signalInterceptor {
return &signalInterceptor{}
}

type signalInboundCallsInterceptor struct {
interceptors.WorkflowInboundCallsInterceptorBase
control *signalInterceptor
}

func (t *signalInboundCallsInterceptor) ProcessSignal(ctx workflow.Context, signalName string, arg interface{}) error {
t.control.TimesInvoked++
if t.control.TimesInvoked <= t.control.ReturnErrorTimes {
return fmt.Errorf("interceptor induced failure while processing signal %v", signalName)
}
return t.Next.ProcessSignal(ctx, signalName, arg)
}

type signalOutboundCallsInterceptor struct {
interceptors.WorkflowOutboundCallsInterceptorBase
}

func (t *signalInterceptor) InterceptWorkflow(_ *workflow.Info, next interceptors.WorkflowInboundCallsInterceptor) interceptors.WorkflowInboundCallsInterceptor {
result := &signalInboundCallsInterceptor{}
result.Next = next
result.control = t
return result
}

func (ts *IntegrationTestSuite) metricCount(name string, tagFilterKeyValue ...string) (total int64) {
for _, counter := range ts.metricsReporter.Counts() {
if counter.Name() != name {
Expand Down

0 comments on commit 2c0af69

Please sign in to comment.