tmoreau89 commented on a change in pull request #6125:
URL: https://github.com/apache/incubator-tvm/pull/6125#discussion_r477013469
##########
File path: vta/python/vta/top/graphpack.py
##########
@@ -174,6 +174,103 @@ def _operator_idx_inc(expr, count_meta,
operator_current_idx):
operator_current_idx = operator_current_idx + 1
return operator_current_idx
+
+class ExprDeviceAnnot(ExprMutator):
+ """Visitor to perform graph annotation on an AST.
+
+ Parameters
+ ----------
+ start: int
+ the start location to mark run on vta (inclusive)
+ end: int
+ the end location to mark run on vta (exclusive)
+
+ Returns
+ ---------
+ None
+ """
+ def __init__(self, start=-1, end=-1):
+ self.ext_ctx = tvm.context("ext_dev")
+ self.cpu_ctx = tvm.context("cpu")
+ self.cast = op.op.get("cast")
+ self.counter = -1
+ self.start = start
+ self.end = end
+ super().__init__()
+
+ def visit_call(self, call):
+ """ Visit the children. """
+ # First visit the children.
+ args = [self.visit(arg) for arg in call.args]
+
+ self.counter += 1
+ if self.counter == self.start:
+ ret = relay.Call(call.op, args, call.attrs)
+ ret = relay.annotation.on_device(ret, self.ext_ctx)
+ return ret
+
+ if self.counter == self.end:
+ ret = relay.Call(call.op, args, call.attrs)
+ ret = relay.annotation.on_device(ret, self.cpu_ctx)
+ return ret
+
+ if self.counter > self.start and self.counter < self.end:
+ ret = relay.Call(call.op, args, call.attrs)
+
+ # skip the float op, i.e., float->int cast
+ if self.is_float_op(call):
+ return ret
+
+ return relay.annotation.on_device(ret, self.ext_ctx)
+
+ return relay.Call(self.visit(call.op), args, call.attrs)
+
+ def is_float_op(self, call):
+ """check if this op belongs to a float op
+ in general, float op's odtype is float;
+ a special case is float->int cast, which follow this op sequence:
+ multiply(float) -> round(float) -> clip(float) -> cast(int);
+ """
+ args = call.args
+ odtype = _get_tensor_type(call)
+
+ if odtype == "float32":
+ return True
+
+ if call.op == self.cast:
+ idtype = _get_tensor_type(args[0])
+ if idtype == "float32":
+ return True
+
+ return False
+
+
+class ExprLocater(ExprMutator):
Review comment:
Typo in the class name: `ExprLocater` -> `ExprLocator`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]