manupa-arm commented on a change in pull request #10563:
URL: https://github.com/apache/tvm/pull/10563#discussion_r824645289
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
- // operand tensor in a binary op
+ // operand tensor in a binary op (add or multiply supported via CMSIS-NN
path). This applies only
+ // to 1st and 2nd arguments of the ops.
Call ReplaceScalarWithTensorVariable(Call call) {
const OpNode* opnode = call->op.as<OpNode>();
- if (opnode == nullptr) {
+ if (opnode == nullptr || (opnode->name != "qnn.add" && opnode->name !=
"qnn.mul")) {
Review comment:
nit: Will there be more of these ops? If so, keeping them in a static set and
checking membership against it would be more maintainable.
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -67,8 +67,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator
{
Expr final_call = post;
call = post.as<CallNode>();
- // Create a new variable argument that is of the same shape as the
neighbouring argument
- // in the binary op. This needs to be done only when one of the arguments
is a scalar.
+ // Substitute scalar variable with a tensor variable in qnn.add and
qnn.mul ops
Review comment:
nit: I feel we could just say "Substitute scalar variable with a
tensor variable where appropriate" here, because which ops get replaced seems to be an
implementation detail of ReplaceScalarWithTensorVariable.
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
Review comment:
[Unrelated to this PR, feel free to ignore]: typo "neibouring" should be "neighbouring".
##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -86,63 +85,62 @@ class ScalarToTensorConstantMutator : public
MixedModeMutator {
final_call = Call(global_var, call->args);
}
- // Substitute scalar constant with a tensor constant in the call to
composite function
- // comprising partitioned binary ops. Shape of the new constant should be
same as its
- // neighbouring tensor's shape.
+ // Substitute scalar constant with tensor constant in the call to
composite function.
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
- auto func_name = func->GetAttr<String>(attr::kComposite);
- if (func_name.defined() &&
- (func_name == "cmsis-nn.qnn_add" || func_name ==
"cmsis-nn.qnn_mul")) {
- final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
- }
+ final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
}
return final_call;
}
// Replaces scalar variable with a tensor variable with same shape as that
of the neibouring
- // operand tensor in a binary op
+ // operand tensor in a binary op (add or multiply supported via CMSIS-NN
path). This applies only
+ // to 1st and 2nd arguments of the ops.
Call ReplaceScalarWithTensorVariable(Call call) {
const OpNode* opnode = call->op.as<OpNode>();
- if (opnode == nullptr) {
+ if (opnode == nullptr || (opnode->name != "qnn.add" && opnode->name !=
"qnn.mul")) {
Review comment:
For my knowledge: why do we need two versions of
ReplaceScalarWithTensorVariable? I thought CMSIS-NN only considers the
composite functions.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]