diff --git a/compose/dag.go b/compose/dag.go
index 5d99a74..2ccb7fd 100644
--- a/compose/dag.go
+++ b/compose/dag.go
@@ -22,19 +22,32 @@ import (
 )
 
+// dagChannelBuilder creates the channel of a DAG node. Every dependency starts
+// as "not skipped"; reportSkip flips the flag once that predecessor is skipped.
 func dagChannelBuilder(dependencies []string) channel {
+	waitList := make(map[string]bool, len(dependencies))
+	for _, dep := range dependencies {
+		waitList[dep] = false
+	}
 	return &dagChannel{
 		values:   make(map[string]any),
-		waitList: dependencies,
+		waitList: waitList,
 	}
 }
 
+// dagChannel gathers the outputs of a node's predecessors. waitList maps each
+// predecessor key to whether it was skipped; skipped is set once all of them are.
 type dagChannel struct {
 	values   map[string]any
-	waitList []string
+	waitList map[string]bool
 	value    any
+	skipped  bool
 }
 
 func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error {
+	if ch.skipped {
+		return nil
+	}
+
 	for k, v := range ins {
 		if _, ok := ch.values[k]; ok {
 			return fmt.Errorf("dag channel update, calculate node repeatedly: %s", k)
@@ -42,26 +55,13 @@ func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error {
 		ch.values[k] = v
 	}
 
-	for i := range ch.waitList {
-		if _, ok := ch.values[ch.waitList[i]]; !ok {
-			return nil
-		}
-	}
-
-	if len(ch.waitList) == 1 {
-		ch.value = ch.values[ch.waitList[0]]
-		return nil
-	}
-	v, err := mergeValues(mapToList(ch.values))
-	if err != nil {
-		return fmt.Errorf("dag channel merge value fail: %w", err)
-	}
-	ch.value = v
-
-	return nil
+	return ch.tryUpdateValue()
 }
 
 func (ch *dagChannel) get(ctx context.Context) (any, error) {
+	if ch.skipped {
+		return nil, fmt.Errorf("dag channel has been skipped")
+	}
 	if ch.value == nil {
 		return nil, fmt.Errorf("dag channel not ready, value is nil")
 	}
@@ -71,5 +71,61 @@ func (ch *dagChannel) get(ctx context.Context) (any, error) {
 }
 
 func (ch *dagChannel) ready(ctx context.Context) bool {
+	if ch.skipped {
+		return false
+	}
 	return ch.value != nil
 }
+
+// reportSkip marks the given predecessor keys as skipped. It returns true when
+// every predecessor of this channel has been skipped, meaning the node itself
+// must be skipped. Otherwise the channel value is re-evaluated, because the
+// remaining predecessors may already have delivered their outputs.
+func (ch *dagChannel) reportSkip(keys []string) (bool, error) {
+	for _, k := range keys {
+		if _, ok := ch.waitList[k]; ok {
+			ch.waitList[k] = true
+		}
+	}
+
+	allSkipped := true
+	for _, skipped := range ch.waitList {
+		if !skipped {
+			allSkipped = false
+			break
+		}
+	}
+	ch.skipped = allSkipped
+	if allSkipped {
+		return true, nil
+	}
+
+	return false, ch.tryUpdateValue()
+}
+
+// tryUpdateValue sets ch.value once every predecessor that has not been skipped
+// has delivered its output; until then it is a no-op.
+func (ch *dagChannel) tryUpdateValue() error {
+	var validList []string
+	for key, skipped := range ch.waitList {
+		if skipped {
+			continue
+		}
+		if _, ok := ch.values[key]; !ok {
+			// still waiting for a live predecessor
+			return nil
+		}
+		validList = append(validList, key)
+	}
+
+	if len(validList) == 1 {
+		ch.value = ch.values[validList[0]]
+		return nil
+	}
+	v, err := mergeValues(mapToList(ch.values))
+	if err != nil {
+		return fmt.Errorf("dag channel merge value fail: %w", err)
+	}
+	ch.value = v
+	return nil
+}
diff --git a/compose/graph.go b/compose/graph.go
index 3ccbadc..b8f9383 100644
--- a/compose/graph.go
+++ b/compose/graph.go
@@ -723,7 +723,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
 	if isWorkflow(g.cmp) {
 		eager = true
 	}
-	if !eager && opt != nil && opt.getStateEnabled {
+	if !isWorkflow(g.cmp) && opt != nil && opt.getStateEnabled {
 		return nil, fmt.Errorf("shouldn't set WithGetStateEnable outside of the Workflow")
 	}
 	forbidGetState := true
@@ -745,11 +745,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
 		}
 	}
 
-	// dag doesn't support branch
-	if runType == runTypeDAG && len(g.branches) > 0 {
-		return nil, fmt.Errorf("dag doesn't support branch for now")
-	}
-
 	for key := range g.fieldMappingRecords {
 		// not allowed to map multiple fields to the same field
 		toMap := make(map[string]bool)
@@ -806,6 +801,14 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
 		}
 
 	}
+	for start, branches := range g.branches {
+		for _, branch := range branches {
+			for end := range branch.endNodes {
+				// every branch target counts the branching node as a predecessor
+				invertedEdges[end] = append(invertedEdges[end], start)
+			}
+		}
+	}
 
 	inputChannels := &chanCall{
 		writeTo:         g.edges[START],
@@ -833,6 +836,13 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
 		edgeHandlerManager:      &edgeHandlerManager{h: g.handlerOnEdges},
 	}
 
+	// successors lets the channel manager propagate skips downstream at run time.
+	successors := make(map[string][]string, len(r.chanSubscribeTo))
+	for key, call := range r.chanSubscribeTo {
+		successors[key] = getSuccessors(call)
+	}
+	r.successors = successors
+
 	if g.stateGenerator != nil {
 		r.runCtx = func(ctx context.Context) context.Context {
 			return context.WithValue(ctx, stateKey{}, &internalState{
@@ -868,6 +878,29 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
 	return r.toComposableRunnable(), nil
 }
 
+// getSuccessors returns every node that c may write to, through direct edges
+// or through any branch, without duplicates.
+func getSuccessors(c *chanCall) []string {
+	seen := make(map[string]struct{}, len(c.writeTo))
+	ret := make([]string, 0, len(c.writeTo))
+	add := func(node string) {
+		if _, ok := seen[node]; ok {
+			return
+		}
+		seen[node] = struct{}{}
+		ret = append(ret, node)
+	}
+	for _, node := range c.writeTo {
+		add(node)
+	}
+	for _, branch := range c.writeToBranches {
+		for node := range branch.endNodes {
+			add(node)
+		}
+	}
+	return ret
+}
+
 type subGraphCompileCallback struct {
 	closure func(ctx context.Context, info *GraphInfo)
 }
@@ -1043,6 +1076,14 @@ func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string]
 					}
 					m[subNode]--
 				}
+				for _, subBranch := range chanSubscribeTo[node].writeToBranches {
+					for subNode := range subBranch.endNodes {
+						if subNode == END {
+							continue
+						}
+						m[subNode]--
+					}
+				}
 				m[node] = -1
 			}
 		}
diff --git a/compose/graph_manager.go b/compose/graph_manager.go
index 0401504..e508b6b 100644
--- a/compose/graph_manager.go
+++ b/compose/graph_manager.go
@@ -30,6 +30,7 @@ type channel interface {
 	update(context.Context, map[string]any) error
 	get(context.Context) (any, error)
 	ready(context.Context) bool
+	reportSkip([]string) (bool, error)
 }
 
 type edgeHandlerManager struct {
@@ -108,8 +109,9 @@ func (p *preBranchHandlerManager) handle(nodeKey string, idx int, value any, isS
 }
 
 type channelManager struct {
-	isStream bool
-	channels map[string]channel
+	isStream   bool
+	successors map[string][]string
+	channels   map[string]channel
 
 	edgeHandlerManager      *edgeHandlerManager
 	preNodeHandlerManager   *preNodeHandlerManager
@@ -163,6 +165,51 @@ func (c *channelManager) updateAndGet(ctx context.Context, values map[string]map
 	return c.getFromReadyChannels(ctx, isStream)
 }
 
+// reportBranch records that node `from` did not select skippedNodes and
+// propagates the skip: a node whose predecessors are all skipped is skipped
+// too, and its own successors are then notified in turn.
+func (c *channelManager) reportBranch(from string, skippedNodes []string) error {
+	// report marks pred as skipped for node and tells whether node itself is now skipped.
+	report := func(node, pred string) (bool, error) {
+		ch, ok := c.channels[node]
+		if !ok {
+			return false, fmt.Errorf("unknown node: %s", node)
+		}
+		return ch.reportSkip([]string{pred})
+	}
+
+	var nKeys []string
+	for _, node := range skippedNodes {
+		skipped, err := report(node, from)
+		if err != nil {
+			return err
+		}
+		if skipped {
+			nKeys = append(nKeys, node)
+		}
+	}
+
+	// nKeys works as a queue and grows while skips keep propagating downstream.
+	for i := 0; i < len(nKeys); i++ {
+		key := nKeys[i]
+		successors, ok := c.successors[key]
+		if !ok {
+			return fmt.Errorf("unknown node: %s", key)
+		}
+		for _, successor := range successors {
+			skipped, err := report(successor, key)
+			if err != nil {
+				return err
+			}
+			if skipped {
+				nKeys = append(nKeys, successor)
+			}
+			// todo: detect if end node has been skipped?
+		}
+	}
+	return nil
+}
+
 type task struct {
 	ctx     context.Context
 	nodeKey string
diff --git a/compose/graph_run.go b/compose/graph_run.go
index 6984fda..0bc41e7 100644
--- a/compose/graph_run.go
+++ b/compose/graph_run.go
@@ -64,6 +64,7 @@ type chanBuilder func(d []string) channel
 type runner struct {
 	chanSubscribeTo map[string]*chanCall
 	invertedEdges   map[string][]string
+	successors      map[string][]string
 	inputChannels   *chanCall
 
 	chanBuilder chanBuilder // could be nil
@@ -176,7 +177,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
 		}
 
 		// 1. Calculate active edges and resolve their values.
-		writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream)
+		writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream, cm)
 		if err != nil {
 			return nil, err
 		}
@@ -233,13 +234,13 @@ func (r *runner) createTasks(ctx context.Context, nodeMap map[string]any, optMap
 	return nextTasks, nil
 }
 
-func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool) (map[string]map[string]any, error) {
+func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager) (map[string]map[string]any, error) {
 	writeChannelValues := make(map[string]map[string]any)
 	for _, t := range completedTasks {
 		// update channel & new_next_tasks
 		vs := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
 		nextNodeKeys, err := r.calculateNext(ctx, t.nodeKey, t.call,
-			vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream)
+			vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream, cm)
 		if err != nil {
 			return nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err)
 		}
@@ -253,7 +254,7 @@ func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*ta
 	return writeChannelValues, nil
 }
 
-func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool) ([]string, error) {
+func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool, cm *channelManager) ([]string, error) {
 	if len(input) < len(startChan.writeToBranches) {
 		// unreachable
 		return nil, errors.New("calculate next input length is shorter than branches")
@@ -266,6 +267,8 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
 	ret := make([]string, 0, len(startChan.writeTo))
 	ret = append(ret, startChan.writeTo...)
 
+	// skippedNodes collects branch targets that were not selected by their branch.
+	skippedNodes := make(map[string]struct{})
 	for i, branch := range startChan.writeToBranches {
 		// check branch input type if needed
 		var err error
@@ -305,8 +308,29 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
 				return nil, errors.New("invoke branch result isn't string")
 			}
 		}
+
+		for node := range branch.endNodes {
+			if node != w {
+				skippedNodes[node] = struct{}{}
+			}
+		}
+
 		ret = append(ret, w)
 	}
+
+	// When a node has multiple branches, a succeeding node may be selected by
+	// some branches and discarded by others; such a node must not be skipped.
+	for _, selected := range ret {
+		delete(skippedNodes, selected)
+	}
+	skippedNodeList := make([]string, 0, len(skippedNodes))
+	for skipped := range skippedNodes {
+		skippedNodeList = append(skippedNodeList, skipped)
+	}
+
+	if err := cm.reportBranch(curNodeKey, skippedNodeList); err != nil {
+		return nil, err
+	}
 	return ret, nil
 }
 
@@ -337,8 +361,9 @@ func (r *runner) initChannelManager(isStream bool) *channelManager {
 	chs[END] = builder(r.invertedEdges[END])
 
 	return &channelManager{
-		isStream: isStream,
-		channels: chs,
+		isStream:   isStream,
+		channels:   chs,
+		successors: r.successors,
 
 		edgeHandlerManager:      r.edgeHandlerManager,
 		preNodeHandlerManager:   r.preNodeHandlerManager,
diff --git a/compose/graph_test.go b/compose/graph_test.go
index 5a35d61..37ca8cb 100644
--- a/compose/graph_test.go
+++ b/compose/graph_test.go
@@ -1422,3 +1422,182 @@ func TestDAGStart(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, map[string]any{"start": "start", "1": "1"}, result)
 }
+
+// concatLambda returns a lambda that appends s to its input.
+func concatLambda(s string) *Lambda {
+	return InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return input + s, nil })
+}
+
+// mapLambda returns a lambda that ignores its input and outputs {k: v}.
+func mapLambda(k, v string) *Lambda {
+	return InvokableLambda(func(ctx context.Context, input map[string]string) (output map[string]string, err error) {
+		return map[string]string{
+			k: v,
+		}, nil
+	})
+}
+
+// TestBaseDAGBranch checks a single branch from START in AllPredecessor mode.
+func TestBaseDAGBranch(t *testing.T) {
+	g := NewGraph[string, string]()
+
+	err := g.AddLambdaNode("1", concatLambda("1"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("2", concatLambda("2"))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
+		if len(in) > 3 {
+			return "2", nil
+		}
+		return "1", nil
+	}, map[string]bool{"1": true, "2": true}))
+	assert.NoError(t, err)
+	err = g.AddEdge("1", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("2", END)
+	assert.NoError(t, err)
+
+	ctx := context.Background()
+	r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor))
+	assert.NoError(t, err)
+	result, err := r.Invoke(ctx, "hi")
+	assert.NoError(t, err)
+	assert.Equal(t, "hi1", result)
+}
+
+// TestMultiDAGBranch checks two branches from START with disjoint targets.
+func TestMultiDAGBranch(t *testing.T) {
+	g := NewGraph[map[string]string, map[string]string]()
+
+	err := g.AddLambdaNode("1", mapLambda("1", "1"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("2", mapLambda("2", "2"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("3", mapLambda("3", "3"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("4", mapLambda("4", "4"))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) {
+		if len(in["input"]) > 3 {
+			return "2", nil
+		}
+		return "1", nil
+	}, map[string]bool{"1": true, "2": true}))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) {
+		if len(in["input"]) > 3 {
+			return "4", nil
+		}
+		return "3", nil
+	}, map[string]bool{"3": true, "4": true}))
+	assert.NoError(t, err)
+
+	err = g.AddEdge("1", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("2", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("3", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("4", END)
+	assert.NoError(t, err)
+
+	ctx := context.Background()
+	r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor))
+	assert.NoError(t, err)
+	result, err := r.Invoke(ctx, map[string]string{"input": "hi"})
+	assert.NoError(t, err)
+	assert.Equal(t, map[string]string{
+		"1": "1",
+		"3": "3",
+	}, result)
+}
+
+// TestCrossDAGBranch checks two branches from START that share a target node.
+func TestCrossDAGBranch(t *testing.T) {
+	g := NewGraph[map[string]string, map[string]string]()
+
+	err := g.AddLambdaNode("1", mapLambda("1", "1"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("2", mapLambda("2", "2"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("3", mapLambda("3", "3"))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) {
+		if len(in["input"]) > 3 {
+			return "2", nil
+		}
+		return "1", nil
+	}, map[string]bool{"1": true, "2": true}))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in map[string]string) (endNode string, err error) {
+		if len(in["input"]) > 3 {
+			return "3", nil
+		}
+		return "2", nil
+	}, map[string]bool{"2": true, "3": true}))
+	assert.NoError(t, err)
+
+	err = g.AddEdge("1", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("2", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("3", END)
+	assert.NoError(t, err)
+
+	ctx := context.Background()
+	r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor))
+	assert.NoError(t, err)
+	result, err := r.Invoke(ctx, map[string]string{"input": "hi"})
+	assert.NoError(t, err)
+	assert.Equal(t, map[string]string{
+		"1": "1",
+		"2": "2",
+	}, result)
+}
+
+// TestNestedDAGBranch checks a branch whose target itself branches again.
+func TestNestedDAGBranch(t *testing.T) {
+	g := NewGraph[string, string]()
+
+	err := g.AddLambdaNode("1", concatLambda("1"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("2", concatLambda("2"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("3", concatLambda("3"))
+	assert.NoError(t, err)
+	err = g.AddLambdaNode("4", concatLambda("4"))
+	assert.NoError(t, err)
+	err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
+		if len(in) > 3 {
+			return "2", nil
+		}
+		return "1", nil
+	}, map[string]bool{"1": true, "2": true}))
+	assert.NoError(t, err)
+	err = g.AddBranch("2", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) {
+		if len(in) > 10 {
+			return "4", nil
+		}
+		return "3", nil
+	}, map[string]bool{"3": true, "4": true}))
+	assert.NoError(t, err)
+	err = g.AddEdge("1", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("3", END)
+	assert.NoError(t, err)
+	err = g.AddEdge("4", END)
+	assert.NoError(t, err)
+
+	ctx := context.Background()
+	r, err := g.Compile(ctx, WithNodeTriggerMode(AllPredecessor))
+	assert.NoError(t, err)
+	result, err := r.Invoke(ctx, "hello")
+	assert.NoError(t, err)
+	assert.Equal(t, "hello23", result)
+	result, err = r.Invoke(ctx, "hi")
+	assert.NoError(t, err)
+	assert.Equal(t, "hi1", result)
+	result, err = r.Invoke(ctx, "hellohello")
+	assert.NoError(t, err)
+	assert.Equal(t, "hellohello24", result)
+}
diff --git a/compose/pregel.go b/compose/pregel.go
index 0e9a850..33051e1 100644
--- a/compose/pregel.go
+++ b/compose/pregel.go
@@ -21,7 +21,7 @@ import (
 	"fmt"
 )
 
-func pregelChannelBuilder(dependencies []string) channel {
+func pregelChannelBuilder(_ []string) channel {
 	return &pregelChannel{}
 }
 
@@ -68,3 +68,8 @@ func (ch *pregelChannel) get(_ context.Context) (any, error) {
 func (ch *pregelChannel) ready(_ context.Context) bool {
 	return ch.value != nil
 }
+
+// reportSkip is a no-op: this channel does not track predecessors and never reports itself as skipped.
+func (ch *pregelChannel) reportSkip(_ []string) (bool, error) {
+	return false, nil
+}