bartekkuncer commented on code in PR #21103:
URL: https://github.com/apache/incubator-mxnet/pull/21103#discussion_r927125960


##########
src/operator/nn/dnnl/dnnl_where.cc:
##########
@@ -95,6 +104,25 @@ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape,
   return broadcastable_in_shape;
 }
 
+/*!
+ * \brief Create shape vector basing on two input shapes
+ * \param first_shape first input shape
+ * \param second_shape second input shape
+ * \return deducted broadcasted shape basing on first_shape and second_shape

Review Comment:
   ```suggestion
    * \return deduced broadcasted shape based on first_shape and second_shape
   ```
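   
   For readers following along, here is a minimal sketch of the broadcast-shape deduction this doc comment describes. It is an illustration only, not the MXNet implementation: plain `std::vector<int>` stands in for `mxnet::TShape`, and `DeduceBroadcastShape` is a hypothetical name. It assumes both shapes were already padded to the same rank, as `GetBroadcastableShape` does earlier in this file.
   
   ```cpp
   #include <algorithm>
   #include <cassert>
   #include <cstddef>
   #include <vector>
   
   // Hypothetical sketch of broadcast-shape deduction. For each axis the two
   // sizes must be equal or one of them must be 1; the deduced output size is
   // the larger of the two.
   std::vector<int> DeduceBroadcastShape(const std::vector<int>& first_shape,
                                         const std::vector<int>& second_shape) {
     assert(first_shape.size() == second_shape.size());  // padded to same rank
     std::vector<int> out_shape(first_shape.size());
     for (std::size_t i = 0; i < first_shape.size(); ++i) {
       assert(first_shape[i] == second_shape[i] || first_shape[i] == 1 ||
              second_shape[i] == 1);
       out_shape[i] = std::max(first_shape[i], second_shape[i]);
     }
     return out_shape;
   }
   ```
   
   Under these assumptions, shapes {2, 1, 3} and {1, 4, 3} would deduce to {2, 4, 3}.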



##########
src/operator/nn/dnnl/dnnl_where.cc:
##########
@@ -34,7 +34,16 @@ namespace op {
 
 // Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
 bool SupportDNNLWhere(const std::vector<NDArray>& inputs) {
-  return SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::Mixed>(inputs);
+  if (inputs[0].dtype() == mshadow::kBool) {
+  if (inputs[0].dtype() == mshadow::kBool) {
+    // oneDNN natively doesn't support bool data type, however this operator was written
+    // to allow using bool datatype for 'condition' tensor - data will be treated as uint8
+    bool bool_support = SupportDNNLShape<1, 12>(inputs[0].shape());
+    return bool_support &&
+           SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::AllSame>({inputs[1], inputs[2]});

Review Comment:
   ```suggestion
       return SupportDNNLShape<1, 12>(inputs[0].shape()) &&
              SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::AllSame>({inputs[1], inputs[2]});
   ```
   If you insist on keeping the variable, please make it const.
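   
   For reference, a sketch of the const variant that last sentence asks for, assuming the author keeps the named variable (same calls as in the diff above):
   
   ```cpp
   // Sketch only: the flagged lines with the variable kept but marked const.
   const bool bool_support = SupportDNNLShape<1, 12>(inputs[0].shape());
   return bool_support &&
          SupportDNNL<DNNLTypeMode::NoInt32, DNNLTensorsDtypes::AllSame>({inputs[1], inputs[2]});
   ```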



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
