huajsj commented on code in PR #11334:
URL: https://github.com/apache/tvm/pull/11334#discussion_r875537165


##########
tests/python/relay/test_pipeline_executor.py:
##########
@@ -22,12 +22,193 @@
 import tvm
 import tvm.testing
 from tvm import relay
-from tvm.relay import transform
+from tvm.relay import transform, build_module
+from tvm.relay.testing import run_opt_pass
 from tvm.contrib import graph_executor, pipeline_executor, 
pipeline_executor_build
 from tvm._ffi import get_global_func
 from tvm.contrib import cc as _cc
 
 
+"""Splitting graph into a list of subgraph"""
+
+
+def graph_split(expr, split_conf, params=None):
+    def get_dep_var(sub_var_dep):
+        return [var for var, _ in sub_var_dep[len(sub_var_dep) - 
1]["ref_nodes"].items()]
+
+    def parse_dependency(value, snode_dep, new_input_idx):
+        new_args = []
+        need_update = False
+        for var in value.args:
+            is_free_var = False
+            for i in range(0, len(snode_dep) - 1):
+                dep = snode_dep[i]
+                if var in dep["nodes"]:
+                    # Mark the previous subgraph node as a dependency.
+                    dep["nodes"][var] = dep["nodes"][var] + 1
+                    dep["ref_nodes"][var] = dep["nodes"][var]
+                    # The var of this call is a free_var
+                    is_free_var = True
+            # if the var of this call is free_var, recreate it and give it a 
fixed input name.
+            if is_free_var:
+                need_update = True
+                new_args.append(relay.var(f"data_n_{new_input_idx}", 
var.checked_type))
+                new_input_idx = new_input_idx + 1
+            else:
+                new_args.append(var)
+        # if the call have a free_var, recreate it.
+        if need_update:
+            value = tvm.relay.expr.Call(
+                value.op, new_args, value.attrs, value.type_args, value.span
+            )
+        return value, snode_dep, new_input_idx
+
+    def merge_constant_expr(constant_expr, expr):
+        # merge constant express with a express
+        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
+            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, 
expr)
+
+        return tvm.relay.expr.Let(
+            constant_expr.var, constant_expr.value, 
merge_constant_expr(constant_expr.body, expr)
+        )
+
+    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
+        # Enumrate all operators of compute graph, then split the compute 
graph into a group of
+        # subgraph.
+        nonlocal operator_index_map
+        nonlocal new_input_idx
+        nonlocal snode_dep
+        cur_node_dep = snode_dep[len(snode_dep) - 1]
+        if isinstance(anf, tvm.relay.Function):
+            return tvm.relay.Function(
+                anf.params,
+                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
+                anf.ret_type,
+                anf.type_params,
+                anf.attrs,
+            )
+        if isinstance(anf, tvm.relay.expr.Let):
+            value = anf.value
+            # record the constant expr to make sure all sugraph can find 
correct constant.
+            if isinstance(value, tvm.relay.expr.Constant):
+                if not constant_expr:
+                    constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
+                else:
+                    constant_expr = tvm.relay.expr.Let(anf.var, value, 
constant_expr)
+            if isinstance(value, tvm.relay.expr.Call):
+                new_args = []
+                # build current var list
+                cur_node_dep["nodes"][anf.var] = 0
+                # Get the dependency information of the nodes.
+                value, snode_dep, new_input_idx = parse_dependency(value, 
snode_dep, new_input_idx)
+                if isinstance(value.op, tvm.ir.Op):
+                    if value.op.name in operator_index_map:
+                        operator_index_map[value.op.name] = 
operator_index_map[value.op.name] + 1
+                    else:
+                        operator_index_map[value.op.name] = 0
+                    split_operator_name = split_conf[0]["op_name"] if 
split_conf else ""
+                    split_operator_index = split_conf[0]["op_index"] if 
split_conf else ""
+                    if (
+                        split_conf
+                        and split_operator_name in operator_index_map
+                        and operator_index_map[split_operator_name] >= 
split_operator_index

Review Comment:
   This is an integer comparison, not a string comparison. I added comments
to avoid confusion.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to