mbs-octoml commented on a change in pull request #10352:
URL: https://github.com/apache/tvm/pull/10352#discussion_r813204830
##########
File path: src/printer/relay_text_printer.cc
##########
@@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
+ if (!var->virtual_device()->IsFullyUnconstrained()) {
Review comment:
Is there any issue with this being used for both param-bound and let-bound
vars, even though we don't parse annotations for let-bound vars?
##########
File path: src/relay/ir/expr_functor.cc
##########
@@ -476,41 +476,31 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>&
args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
- std::vector<VirtualDevice> new_param_virtual_devices;
+ bool params_unchanged = true;
for (size_t i = 0; i < func->params.size(); ++i) {
if (!args_map.count(func->params[i])) {
new_params.push_back(func->params[i]);
-
new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
+ } else if (const auto var = args_map[func->params[i]].as<VarNode>()) {
+ // If we're mapping a variable to a variable and not a normal expr,
then we want to
+ // put the substitution in the new parameters.
+ params_unchanged = false;
+ new_params.push_back(GetRef<Var>(var));
}
}
- if (new_body.same_as(func->body) && new_params.size() ==
func->params.size()) {
+ if (new_body.same_as(func->body) && new_params.size() ==
func->params.size() &&
+ params_unchanged) {
return expr;
}
auto ret =
Function(new_params, new_body, func->ret_type, func->type_params,
func->attrs, func->span);
- ret =
- MaybeFunctionOnDevice(ret, new_param_virtual_devices,
GetFunctionResultVirtualDevice(func));
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
Review comment:
Bind should really be called SubstAndBind, since it assumes any free
vars introduced in the right-hand side of the substitution should be implicitly
bound in the result function. I think you'll need to keep these semantics.
For your purposes, I think directly using the ExprBinder is what you
want, since you only need the Subst aspect. I forgot about this hackery
when I pointed you to this function.
##########
File path: src/printer/relay_text_printer.cc
##########
@@ -329,7 +333,10 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool
meta, bool try_inline, bo
// Should only be triggered when op is a free variable being visited for the
// first time.
-Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return
AllocVar(GetRef<Var>(op)); }
+Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {
Review comment:
nit: guess you can revert this now.
##########
File path: src/parser/parser.cc
##########
@@ -1107,11 +1111,26 @@ class Parser {
[&]() {
auto token = Match(TokenType::kLocal);
auto string = token.ToString();
+
+ // The fake attributes where the virtual device is specified.
+ VirtualDevice virtual_device;
+ if (WhenMatch(TokenType::kLCurly)) {
+ Map<String, ObjectRef> fake_attrs = ParseAttrs();
+ VLOG(1) << "Fake attributes for function parameter: " <<
fake_attrs;
Review comment:
meganit: I've been using VLOG(9) for these super-duper verbose ones
since I often just set TVM_LOG_DEBUG="DEFAULT=1" to get an overall debug trace
that's still vaguely readable.
##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -252,21 +252,16 @@ class VMFunctionCompiler :
DeviceAwareExprFunctor<void(const Expr& n)> {
// Do that flattening on-the-fly here.
Function inner_func = Downcast<Function>(func->body);
std::vector<Var> params;
- std::vector<VirtualDevice> param_virtual_devices;
params.reserve(func->params.size() + inner_func->params.size());
- param_virtual_devices.reserve(func->params.size() +
inner_func->params.size());
param_device_indexes.reserve(func->params.size() +
inner_func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
params.emplace_back(func->params[i]);
- VirtualDevice param_virtual_device =
GetFunctionParamVirtualDevice(func.get(), i);
- param_virtual_devices.push_back(param_virtual_device);
- param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
+
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
Review comment:
we get some payoff at last!
##########
File path: src/relay/backend/vm/lambda_lift.cc
##########
@@ -111,22 +111,21 @@ class LambdaLifter : public
transform::DeviceAwareExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);
Array<Var> captured_vars;
- std::vector<VirtualDevice> captured_var_virtual_devices;
bool recursive = false;
for (const auto& var : free_vars) {
if (!letrec_.empty() && var == letrec_.back()) {
recursive = true;
continue;
}
captured_vars.push_back(var);
- captured_var_virtual_devices.push_back(GetVirtualDevice(var));
Review comment:
These vars may be let-bound as well as param-bound, so I don't think
this is sound.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]