ashutosh-arm commented on a change in pull request #10563:
URL: https://github.com/apache/tvm/pull/10563#discussion_r824676537
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
- // operand tensor in a binary op
+ // operand tensor in a binary op (add or multiply supported via CMSIS-NN
path). This applies only
+ // to 1st and 2nd arguments of the ops.
Call ReplaceScalarWithTensorVariable(Call call) {
const OpNode* opnode = call->op.as<OpNode>();
- if (opnode == nullptr) {
+ if (opnode == nullptr || (opnode->name != "qnn.add" && opnode->name !=
"qnn.mul")) {
Review comment:
Do you mean why do we have two separate functions
ReplaceScalarWithTensorConstant() and ReplaceScalarWithTensorVariable()?
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
Review comment:
ACK
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
- // operand tensor in a binary op
+ // operand tensor in a binary op (add or multiply supported via CMSIS-NN
path). This applies only
+ // to 1st and 2nd arguments of the ops.
Call ReplaceScalarWithTensorVariable(Call call) {
const OpNode* opnode = call->op.as<OpNode>();
- if (opnode == nullptr) {
+ if (opnode == nullptr || (opnode->name != "qnn.add" && opnode->name !=
"qnn.mul")) {
Review comment:
At the moment, no, but it's quite possible in the future. I will move it inside a
function.
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -67,8 +67,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator
{
Expr final_call = post;
call = post.as<CallNode>();
- // Create a new variable argument that is of the same shape as the
neighbouring argument
- // in the binary op. This needs to be done only when one of the arguments
is a scalar.
+ // Substitute scalar variable with a tensor variable in qnn.add and
qnn.mul ops
Review comment:
Makes sense.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]