Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dag branch #75

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 70 additions & 19 deletions compose/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,47 @@ import (
)

func dagChannelBuilder(dependencies []string) channel {
waitList := make(map[string]bool, len(dependencies))
for _, dep := range dependencies {
waitList[dep] = false
}
return &dagChannel{
values: make(map[string]any),
waitList: dependencies,
waitList: waitList,
}
}

type waitPred struct {
key string
skipped bool
}

type dagChannel struct {
values map[string]any
waitList []string
waitList map[string]bool
value any
skipped bool
}

func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error {
if ch.skipped {
return nil
}

for k, v := range ins {
if _, ok := ch.values[k]; ok {
return fmt.Errorf("dag channel update, calculate node repeatedly: %s", k)
}
ch.values[k] = v
}

for i := range ch.waitList {
if _, ok := ch.values[ch.waitList[i]]; !ok {
return nil
}
}

if len(ch.waitList) == 1 {
ch.value = ch.values[ch.waitList[0]]
return nil
}
v, err := mergeValues(mapToList(ch.values))
if err != nil {
return fmt.Errorf("dag channel merge value fail: %w", err)
}
ch.value = v

return nil
return ch.tryUpdateValue()
}

func (ch *dagChannel) get(ctx context.Context) (any, error) {
if ch.skipped {
return nil, fmt.Errorf("dag channel has been skipped")
}
if ch.value == nil {
return nil, fmt.Errorf("dag channel not ready, value is nil")
}
Expand All @@ -71,5 +72,55 @@ func (ch *dagChannel) get(ctx context.Context) (any, error) {
}

func (ch *dagChannel) ready(ctx context.Context) bool {
if ch.skipped {
return false
}
return ch.value != nil
}

func (ch *dagChannel) reportSkip(keys []string) (bool, error) {
for _, k := range keys {
if _, ok := ch.waitList[k]; ok {
ch.waitList[k] = true
}
}

allSkipped := true
for _, skipped := range ch.waitList {
if !skipped {
allSkipped = false
break
}
}
ch.skipped = allSkipped

var err error
if !allSkipped {
err = ch.tryUpdateValue()
}

return allSkipped, err
}

func (ch *dagChannel) tryUpdateValue() error {
var validList []string
for key, skipped := range ch.waitList {
if _, ok := ch.values[key]; !ok && !skipped {
return nil
} else if !skipped {
validList = append(validList, key)
}
}

if len(validList) == 1 {
ch.value = ch.values[validList[0]]
return nil
}
v, err := mergeValues(mapToList(ch.values))
if err != nil {
return err
}
ch.value = v
return nil

}
43 changes: 37 additions & 6 deletions compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
if isWorkflow(g.cmp) {
eager = true
}
if !eager && opt != nil && opt.getStateEnabled {
if !isWorkflow(g.cmp) && opt != nil && opt.getStateEnabled {
return nil, fmt.Errorf("shouldn't set WithGetStateEnable outside of the Workflow")
}
forbidGetState := true
Expand All @@ -745,11 +745,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
}
}

// dag doesn't support branch
if runType == runTypeDAG && len(g.branches) > 0 {
return nil, fmt.Errorf("dag doesn't support branch for now")
}

for key := range g.fieldMappingRecords {
// not allowed to map multiple fields to the same field
toMap := make(map[string]bool)
Expand Down Expand Up @@ -806,6 +801,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa

}
}
for start, branches := range g.branches {
for _, branch := range branches {
for end := range branch.endNodes {
if _, ok := invertedEdges[end]; !ok {
invertedEdges[end] = []string{start}
} else {
invertedEdges[end] = append(invertedEdges[end], start)
}
}
}
}

inputChannels := &chanCall{
writeTo: g.edges[START],
Expand Down Expand Up @@ -833,6 +839,12 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
edgeHandlerManager: &edgeHandlerManager{h: g.handlerOnEdges},
}

successors := make(map[string][]string)
for ch := range r.chanSubscribeTo {
successors[ch] = getSuccessors(r.chanSubscribeTo[ch])
}
r.successors = successors

if g.stateGenerator != nil {
r.runCtx = func(ctx context.Context) context.Context {
return context.WithValue(ctx, stateKey{}, &internalState{
Expand Down Expand Up @@ -868,6 +880,17 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
return r.toComposableRunnable(), nil
}

func getSuccessors(c *chanCall) []string {
ret := make([]string, len(c.writeTo))
copy(ret, c.writeTo)
for _, branch := range c.writeToBranches {
for node := range branch.endNodes {
ret = append(ret, node)
}
}
return ret
}

type subGraphCompileCallback struct {
closure func(ctx context.Context, info *GraphInfo)
}
Expand Down Expand Up @@ -1043,6 +1066,14 @@ func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string]
}
m[subNode]--
}
for _, subBranch := range chanSubscribeTo[node].writeToBranches {
for subNode := range subBranch.endNodes {
if subNode == END {
continue
}
m[subNode]--
}
}
m[node] = -1
}
}
Expand Down
37 changes: 35 additions & 2 deletions compose/graph_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type channel interface {
update(context.Context, map[string]any) error
get(context.Context) (any, error)
ready(context.Context) bool
reportSkip([]string) (bool, error)
}

type edgeHandlerManager struct {
Expand Down Expand Up @@ -108,8 +109,9 @@ func (p *preBranchHandlerManager) handle(nodeKey string, idx int, value any, isS
}

type channelManager struct {
isStream bool
channels map[string]channel
isStream bool
successors map[string][]string
channels map[string]channel

edgeHandlerManager *edgeHandlerManager
preNodeHandlerManager *preNodeHandlerManager
Expand Down Expand Up @@ -163,6 +165,37 @@ func (c *channelManager) updateAndGet(ctx context.Context, values map[string]map
return c.getFromReadyChannels(ctx, isStream)
}

func (c *channelManager) reportBranch(from string, skippedNodes []string) error {
var nKeys []string
for _, node := range skippedNodes {
skipped, err := c.channels[node].reportSkip([]string{from})
if err != nil {
return err
}
if skipped {
nKeys = append(nKeys, node)
}
}

for i := 0; i < len(nKeys); i++ {
key := nKeys[i]
if _, ok := c.successors[key]; !ok {
return fmt.Errorf("unknown node: %s", key)
}
for _, successor := range c.successors[key] {
skipped, err := c.channels[successor].reportSkip([]string{key})
if err != nil {
return err
}
if skipped {
nKeys = append(nKeys, successor)
}
// todo: detect if end node has been skipped?
}
}
return nil
}

type task struct {
ctx context.Context
nodeKey string
Expand Down
40 changes: 34 additions & 6 deletions compose/graph_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type chanBuilder func(d []string) channel
type runner struct {
chanSubscribeTo map[string]*chanCall
invertedEdges map[string][]string
successors map[string][]string
inputChannels *chanCall

chanBuilder chanBuilder // could be nil
Expand Down Expand Up @@ -176,7 +177,7 @@ func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Opti
}

// 1. Calculate active edges and resolve their values.
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream)
writeChannelValues, err := r.resolveCompletedTasks(ctx, completedTasks, isStream, cm)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -233,13 +234,13 @@ func (r *runner) createTasks(ctx context.Context, nodeMap map[string]any, optMap
return nextTasks, nil
}

func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool) (map[string]map[string]any, error) {
func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*task, isStream bool, cm *channelManager) (map[string]map[string]any, error) {
writeChannelValues := make(map[string]map[string]any)
for _, t := range completedTasks {
// update channel & new_next_tasks
vs := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
nextNodeKeys, err := r.calculateNext(ctx, t.nodeKey, t.call,
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream)
vs[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream, cm)
if err != nil {
return nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err)
}
Expand All @@ -253,7 +254,7 @@ func (r *runner) resolveCompletedTasks(ctx context.Context, completedTasks []*ta
return writeChannelValues, nil
}

func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool) ([]string, error) {
func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, input []any, isStream bool, cm *channelManager) ([]string, error) {
if len(input) < len(startChan.writeToBranches) {
// unreachable
return nil, errors.New("calculate next input length is shorter than branches")
Expand All @@ -266,6 +267,7 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
ret := make([]string, 0, len(startChan.writeTo))
ret = append(ret, startChan.writeTo...)

skippedNodes := make(map[string]struct{})
for i, branch := range startChan.writeToBranches {
// check branch input type if needed
var err error
Expand Down Expand Up @@ -305,8 +307,33 @@ func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan
return nil, errors.New("invoke branch result isn't string")
}
}

for node := range branch.endNodes {
if node != w {
skippedNodes[node] = struct{}{}
}
}

ret = append(ret, w)
}

// When a node has multiple branches,
// there may be a situation where a succeeding node is selected by some branches and discarded by the other branches,
// in which case the succeeding node should not be skipped.
var skippedNodeList []string
for _, selected := range ret {
if _, ok := skippedNodes[selected]; ok {
delete(skippedNodes, selected)
}
}
for skipped := range skippedNodes {
skippedNodeList = append(skippedNodeList, skipped)
}

err := cm.reportBranch(curNodeKey, skippedNodeList)
if err != nil {
return nil, err
}
return ret, nil
}

Expand Down Expand Up @@ -337,8 +364,9 @@ func (r *runner) initChannelManager(isStream bool) *channelManager {
chs[END] = builder(r.invertedEdges[END])

return &channelManager{
isStream: isStream,
channels: chs,
isStream: isStream,
channels: chs,
successors: r.successors,

edgeHandlerManager: r.edgeHandlerManager,
preNodeHandlerManager: r.preNodeHandlerManager,
Expand Down
Loading
Loading