huajsj commented on code in PR #11334:
URL: https://github.com/apache/tvm/pull/11334#discussion_r875538130
##########
tests/python/relay/test_pipeline_executor.py:
##########
@@ -22,12 +22,193 @@
import tvm
import tvm.testing
from tvm import relay
-from tvm.relay import transform
+from tvm.relay import transform, build_module
+from tvm.relay.testing import run_opt_pass
from tvm.contrib import graph_executor, pipeline_executor,
pipeline_executor_build
from tvm._ffi import get_global_func
from tvm.contrib import cc as _cc
+"""Splitting graph into a list of subgraph"""
+
+
+def graph_split(expr, split_conf, params=None):
+ def get_dep_var(sub_var_dep):
+ return [var for var, _ in sub_var_dep[len(sub_var_dep) -
1]["ref_nodes"].items()]
+
+ def parse_dependency(value, snode_dep, new_input_idx):
+ new_args = []
+ need_update = False
+ for var in value.args:
+ is_free_var = False
+ for i in range(0, len(snode_dep) - 1):
+ dep = snode_dep[i]
+ if var in dep["nodes"]:
+ # Mark the previous subgraph node as a dependency.
+ dep["nodes"][var] = dep["nodes"][var] + 1
+ dep["ref_nodes"][var] = dep["nodes"][var]
+ # The var of this call is a free_var
+ is_free_var = True
+ # if the var of this call is free_var, recreate it and give it a
fixed input name.
+ if is_free_var:
+ need_update = True
+ new_args.append(relay.var(f"data_n_{new_input_idx}",
var.checked_type))
+ new_input_idx = new_input_idx + 1
+ else:
+ new_args.append(var)
+ # if the call have a free_var, recreate it.
+ if need_update:
+ value = tvm.relay.expr.Call(
+ value.op, new_args, value.attrs, value.type_args, value.span
+ )
+ return value, snode_dep, new_input_idx
+
+ def merge_constant_expr(constant_expr, expr):
+ # merge constant express with a express
+ if not isinstance(constant_expr.body, tvm.relay.expr.Let):
+ return tvm.relay.expr.Let(constant_expr.var, constant_expr.value,
expr)
+
+ return tvm.relay.expr.Let(
+ constant_expr.var, constant_expr.value,
merge_constant_expr(constant_expr.body, expr)
+ )
+
+ def _recursion(anf, pipeline_mods, split_conf, constant_expr):
+ # Enumrate all operators of compute graph, then split the compute
graph into a group of
+ # subgraph.
+ nonlocal operator_index_map
+ nonlocal new_input_idx
+ nonlocal snode_dep
+ cur_node_dep = snode_dep[len(snode_dep) - 1]
+ if isinstance(anf, tvm.relay.Function):
+ return tvm.relay.Function(
+ anf.params,
+ _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
+ anf.ret_type,
+ anf.type_params,
+ anf.attrs,
+ )
+ if isinstance(anf, tvm.relay.expr.Let):
+ value = anf.value
+ # record the constant expr to make sure all sugraph can find
correct constant.
+ if isinstance(value, tvm.relay.expr.Constant):
+ if not constant_expr:
+ constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
+ else:
+ constant_expr = tvm.relay.expr.Let(anf.var, value,
constant_expr)
+ if isinstance(value, tvm.relay.expr.Call):
+ new_args = []
+ # build current var list
+ cur_node_dep["nodes"][anf.var] = 0
+ # Get the dependency information of the nodes.
+ value, snode_dep, new_input_idx = parse_dependency(value,
snode_dep, new_input_idx)
+ if isinstance(value.op, tvm.ir.Op):
+ if value.op.name in operator_index_map:
+ operator_index_map[value.op.name] =
operator_index_map[value.op.name] + 1
Review Comment:
This does not work with some Python versions; to stay compatible with older
versions, I will keep it as is.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]