This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cf15e0a478 [BUGFIX] Fix remove Cast fuse (#21086)
cf15e0a478 is described below

commit cf15e0a478ea45166b2f397ff440e30601ea2ec4
Author: bartekkuncer <[email protected]>
AuthorDate: Fri Jul 15 13:47:18 2022 +0200

    [BUGFIX] Fix remove Cast fuse (#21086)
    
    * Fix remove Cast fuse
    
    * Fix Reset()
    
    * Fix review
    
    * Fix sanity
---
 src/operator/subgraph/dnnl/dnnl_remove_casts_property.h | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
index d796718949..143a2a4e49 100644
--- a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
@@ -95,7 +95,7 @@ class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 {
   }
 
   void Reset() override {
-    status_   = kFail;
+    status_   = kExpand;
     castDtype = -1;
   }
 };
@@ -105,7 +105,7 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
   SgDNNLRemoveCastsProperty() {}
 
   static SubgraphPropertyPtr Create() {
-    static const std::string& name = "Remove casts optimization pass";
+    static const std::string& name = "Remove Casts optimization pass";
     auto property                  = std::make_shared<SgDNNLRemoveCastsProperty>();
     property->SetAttr<std::string>("property_name", name);
     property->SetAttr<bool>("inference_only", true);
@@ -137,6 +137,14 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
     auto selector = std::make_shared<SgDNNLRemoveCastsSelector>();
     return selector;
   }
+
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
+                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
+    // Connect all extern output entries to output[0]
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      *output_entries->at(i) = nnvm::NodeEntry{subgraph_node, 0, 0};
+    }
+  }
 };
 
 }  // namespace op

Reply via email to