bgawrych commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r836188222



##########
File path: src/operator/subgraph/dnnl/dnnl_transformer.cc
##########
@@ -559,7 +583,7 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
   using namespace dnnl;
 
   const auto attn_tensor = inputs[0];

Review comment:
    ```suggestion
      const auto attn_tensor = inputs[0].Reorder2Default();
    ```

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,371 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
+/*!
+ * \brief Before the model conversion, node entries of the original graph are mapped to the
+ * equivalent node entries in the new graph that will then be converted to a mixed precision graph.
+ * This class wraps a mapped NodeEntry from the new graph, providing a transparent interface for
+ * acquiring versions of the wrapped entry with a specific dtype, adding casting nodes to the
+ * graph when needed (one for each unique dtype that was requested).
+ */
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
 
-// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& mirror_map) {
-  static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+  /*!
+   * \brief Converts the dtype of this NodeEntry. This should be called after a node has been
+   * converted and dtypes of its outputs may have changed
+   */
+  void UpdateDTypeAfterConversion(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(entry.node->op());
+    dtype = new_dtype;
+  }
+
+  /*!
+   * \brief If dtype of this NodeEntry was not changed, returns the mapped entry. Otherwise returns
+   * a NodeEntry to the node which casts to the original dtype of this NodeEntry.
+   */
+  const NodeEntry& AsOriginal() {
+    return AsType(original_dtype);
+  }
+
+  /*!
+   * \brief If dtype of this NodeEntry matches the specified dtype, returns the mapped entry.
+   * Otherwise returns a NodeEntry to the node which casts to that type.
+   */
+  const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = true) {
+    if (dtype == target_dtype || target_dtype == -1) {
+      return entry;
+    }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      CHECK(can_add_cast);
+      cast_entry = Cast(target_dtype);
+      CHECK(cast_entry.node);
     }
+    return cast_entry;
+  }
+
+  /*! \brief Returns whether this entry has the specified dtype or an existing cast to that dtype */
+  bool HasDTypeEntry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  /*!
+   * \brief Returns whether this entry can be cast to a specific dtype. This should be called on
+   * input entries of a node before its conversion.
+   */
+  bool CanBeCastTo(const int target_dtype) {
+    static const auto& amp_cast_op = Op::Get("amp_cast");
+    static const auto& infertype   = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType")[amp_cast_op];
+    nnvm::NodeAttrs dummy_atts;
+    dummy_atts.dict["dtype"] = mxnet::op::type_string(target_dtype);
+    amp_cast_op->attr_parser(&dummy_atts);
+
+    std::vector<int> in_types  = {dtype};
+    std::vector<int> out_types = {-1};
+    return infertype(dummy_atts, &in_types, &out_types);
+  }
+
+  /*! \brief Returns whether this NodeEntry (of a parameter) can be cast offline */
+  bool CanBeCastOfflineTo(const int target_dtype) const {
+    CHECK(entry.node->is_variable());
+    return casts.count(target_dtype) > 0;
+  }
+
+ private:
+  ObjectPtr CreateCastNode(const std::string& op_name, const std::string& node_name) {
+    CHECK_GT(op_name.size(), 0);
+
+    ObjectPtr node   = Node::Create();
+    node->attrs.name = node_name;
+    node->attrs.op   = Op::Get(op_name);
+    node->inputs.emplace_back(entry);
+    return node;
+  }
+
+  NodeEntry Cast(const int target_dtype) {
+    CHECK(CanBeCastTo(target_dtype));
+
+    const std::string dt_name        = mxnet::op::type_string(target_dtype);
+    const std::string suffix         = "_" + std::to_string(entry.index);
+    const std::string cast_node_name = entry.node->attrs.name + suffix + "_amp_cast_" + dt_name;
+    ObjectPtr cast_node              = CreateCastNode("amp_cast", cast_node_name);
+    cast_node->attrs.dict["dtype"]   = dt_name;
+    cast_node->op()->attr_parser(&(cast_node->attrs));
+    return NodeEntry{std::move(cast_node), 0, 0};
+  }
+
+ public:
+  const NodeEntry entry;
+  const int original_dtype;  // original dtype of the entry
+
+ private:
+  int dtype;  // current dtype of the entry
+  std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t     = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t      = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t     = std::unordered_map<Node*, std::unordered_map<Node*, NodeEntry>>;
+
+/*! \brief Makes sure the node in the new graph will work with the same precision as in the original
+ * graph */
+static void KeepOriginalNode(const ObjectPtr& old_node,
+                             const NodeMap_t& node_map,
+                             EntryMap_t* const entry_map) {
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (const auto& old_ne : old_node->inputs) {
+    new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
   }
-  return suffix;
 }
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
+/*! \brief Tries to convert the node to low precision. Returns whether the node has been
+ * successfully converted
+ */
+static bool TryLowPrecision(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  static const auto& infertype      = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+  static const auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  in_types[0] = target_dtype;
+  if (infertype.count(old_node->op()) == 0 ||
+      infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == false) {
+    return false;
+  }
+
+  if (fmutate_inputs.count(old_node->op()) != 0) {
+    std::vector<uint32_t> mutable_inputs = fmutate_inputs[old_node->op()](old_node->attrs);
+    for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+      if (in_types[i] == target_dtype) {
+        if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != mutable_inputs.end()) {
+          return false;
+        }
+      }
+    }
+  }
+
+  for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+    MappedNodeEntry& mapped_ne = entry_map->at(old_node->inputs[i]);
+    // if this tensor needs a cast, check whether MappedNodeEntry can actually cast it
+    if (in_types[i] != -1 && !mapped_ne.HasDTypeEntry(in_types[i]) &&
+        !mapped_ne.CanBeCastTo(in_types[i])) {
+      return false;
+    }
+  }
+
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+    new_node->inputs.push_back(entry_map->at(old_node->inputs[i]).AsType(in_types[i]));
+  }
+
+  for (const NodeEntry& old_ne : nodes_entries.at(old_node.get())) {
+    entry_map->at(old_ne).UpdateDTypeAfterConversion(out_types[old_ne.index]);
+  }
+
+  return true;
 }
 
-// add amp_multicast node between curr_node and inputs
-static void AddMultiCastNode(const std::vector<NodeEntry>& inputs,
-                             const std::string& node_name,
-                             const std::unordered_map<Node*, ObjectPtr>& mirror_map,
-                             ObjectPtr curr_node) {
-  ObjectPtr node =
-      CreateNode("amp_multicast", inputs[0].node->attrs.name + node_name + "_amp_multicast");
-  for (const auto& node_entry : inputs) {
-    ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
-    NodeEntry mirror_entry =
-        NodeEntry{std::move(mirror_node), node_entry.index, node_entry.version};
-    node->inputs.emplace_back(mirror_entry);
+/*! \brief Tries to convert the node to low precision if all of its inputs already have the correct
+ * dtype. Otherwise keeps the node unchanged.
+ */
+static void HandleWidestDtypeNode(const int target_dtype,
+                                  const ObjectPtr& old_node,
+                                  const NodeMap_t& node_map,
+                                  const NodesEntries_t& nodes_entries,
+                                  EntryMap_t* const entry_map) {
+  static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+
+  std::vector<int> in_types(old_node->inputs.size(), target_dtype);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  const bool inferred = (infertype.count(old_node->op()) > 0 &&
+                         infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
+
+  bool has_lp_inputs = inferred;
+  for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
+    const NodeEntry& input = old_node->inputs[i];
+    has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+  }
+
+  if (!has_lp_inputs ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  node->attrs.dict["num_outputs"] = std::to_string(inputs.size());
-  node->op()->attr_parser(&(node->attrs));
-  for (uint32_t i = 0; i < inputs.size(); ++i) {
-    const auto& e = inputs[i];
-    curr_node->inputs.emplace_back(NodeEntry{node, static_cast<uint32_t>(i), e.version});
+}
+/*!
+ * \brief Tries to convert the node to low precision if some of its inputs are already converted.
+ * Otherwise keeps the node unchanged.
+ */
+void HandleDTypeNeutralNode(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  const auto& is_lp = [&](const auto& old_ne) {
+    return entry_map->at(old_ne).HasDTypeEntry(target_dtype);
+  };
+  if (!std::any_of(old_node->inputs.begin(), old_node->inputs.end(), is_lp) ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  return;
 }
 
-static bool CheckConditionalFP32(
-    const std::unordered_map<std::string,
-                             std::unordered_map<std::string, std::vector<std::string>>>&
-        conditional_fp32_ops,
-    const std::unordered_set<std::string>& excluded_syms,
-    ObjectPtr node) {
-  if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) ||
-      conditional_fp32_ops.count(node->op()->name) == 0) {
-    return false;
-  } else {
-    // Iterate through all conditional ops
-    auto it = conditional_fp32_ops.find(node->op()->name);
-    if (it != conditional_fp32_ops.end()) {
-      auto it_params = it->second;
-      // For each param name, iterate through param values to check
-      // if the provided param name is equal to any of the values
-      for (auto& it_param : it_params) {
-        auto param_key = node->attrs.dict.find(it_param.first);
-        if (param_key != node->attrs.dict.end()) {
-          auto it_param_vals = it_param.second;
-          if (std::find(it_param_vals.begin(), it_param_vals.end(), param_key->second) !=
-              it_param_vals.end()) {
-            return true;
+/* \brief Decides which parameters can be cast offline and removes redundant cast nodes from the
+ * graph */
+static void RemoveParamCasts(const int target_dtype,
+                             const std::string& offline_param_cast_attr,
+                             const NodeMap_t& node_map,
+                             const DstNodes_t& old_param_dst_nodes,
+                             EntryMap_t* entry_map) {
+  for (const auto& [old_param, old_param_dsts] : old_param_dst_nodes) {
+    const ObjectPtr& new_param      = node_map.at(old_param);
+    const auto& can_be_cast_offline = [&](const auto& old_node_x_ne_pair) {

Review comment:
    I don't understand the 'old_node_x_ne_pair' variable name - also, I would add more indentation in the lambda body.
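    For example, a rough sketch of what I have in mind - the `dst_and_entry` name and the structured binding below are just my illustration, not code from this PR:

    ```cpp
    const auto& can_be_cast_offline = [&](const auto& dst_and_entry) {
        // unpack the (destination node, node entry) pair for readability
        const auto& [dst_node, node_entry] = dst_and_entry;
        // ... the existing checks would go here, indented one level deeper ...
    };
    ```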

##########
File path: tests/python/dnnl/test_amp.py
##########
@@ -56,111 +53,95 @@ def test_amp_coverage():
     assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."
 
     # Check the coverage
-    py_str = lambda x: x.decode('utf-8')
-
-    plist = ctypes.POINTER(ctypes.c_char_p)()
-    size = ctypes.c_uint()
-
-    mx.base._LIB.MXListAllOpNames(ctypes.byref(size),
-                                     ctypes.byref(plist))
-    op_names = []
-    for i in range(size.value):
-        s = py_str(plist[i])
-        if not s.startswith("_backward") \
-           and not s.startswith("_contrib_backward_"):
-            op_names.append(s)
-
-    ret1 = set(op_names) - set(t)
-
-    if ret1 != set():
-        warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in "
-                       "python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
-                       """Please follow these guidelines for choosing a proper list:
-                       - if your operator is not to be used in a computational graph
-                         (e.g. image manipulation operators, optimizers) or does not have
-                         inputs, put it in BF16_FP32_FUNCS list,
-                       - if your operator requires FP32 inputs or is not safe to use with lower
-                         precision, put it in FP32_FUNCS list,
-                       - if your operator supports both FP32 and lower precision, has
-                         multiple inputs and expects all inputs to be of the same
-                         type, put it in WIDEST_TYPE_CASTS list,
-                       - if your operator supports both FP32 and lower precision and has
-                         either a single input or supports inputs of different type,
-                         put it in BF16_FP32_FUNCS list,
-                       - if your operator is both safe to use in lower precision and
-                         it is highly beneficial to use it in lower precision, then
-                         put it in BF16_FUNCS (this is unlikely for new operators)
-                       - If you are not sure which list to choose, FP32_FUNCS is the
-                         safest option""")
-
-def test_bf16_casting():
-    data = mx.sym.var("data")
-    out1 = mx.sym.amp_cast(data, dtype=bfloat16)
-    out2 = mx.sym.amp_cast(data, dtype="float32")
-    out3 = mx.sym.amp_cast(data, dtype=bfloat16)
-    # When two ops from data, with different dtypes,
-    # data should be float32
-    res = mx.sym.Group([out1, out2])
-    final_res = amp.convert_symbol(res, data_names=[], target_dtype="bfloat16", cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # When two ops from data, both casted to bfloat16,
-    # data should be bfloat16
-    res = mx.sym.Group([out1, out3])
-    final_res = amp.convert_symbol(res, data_names=[], target_dtype="bfloat16", cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == bfloat16
-
-    # AMP Multicast test where one node is float32, another is bfloat16
-    data = mx.sym.var("data", dtype="float32")
-    data2 = mx.sym.var("data2", dtype=bfloat16)
-    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
-    final_res = amp.convert_symbol(out4, target_dtype="bfloat16", cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data2=(1, 2), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == bfloat16
-
-    # AMP Multicast test where two non input nodes are bfloat16,
-    # and one input node is float32
-    data = mx.sym.var("data", dtype="float32")
-    data2 = mx.sym.var("data2", dtype=bfloat16)
-    data3 = mx.sym.var("data3", dtype=bfloat16)
-    out5 = mx.sym.amp_multicast(data,
-                                mx.sym.elemwise_add(data2, data3),
-                                num_outputs=2)
-    final_res = amp.convert_symbol(out5, target_dtype_ops=[], target_dtype="bfloat16",
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # AMP Multicast test where three input nodes one bf16, one fp32
-    # one unknown
-    data = mx.sym.var("data", dtype=bfloat16)
-    data2 = mx.sym.var("data2", dtype="float32")
-    data3 = mx.sym.var("data3")
-    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
-    final_res = amp.convert_symbol(out6, target_dtype_ops=[], target_dtype="bfloat16",
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2),
-                                data3=(1, 2))
-    assert exe.arg_arrays[2].dtype == np.float32
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already bf16, it should still be bf16
-    data = mx.sym.var("data", dtype=bfloat16)
-    data2 = mx.sym.var("data2", dtype="float32")
-    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype=bfloat16)])
-    final_res = amp.convert_symbol(out7, target_dtype_ops=[], target_dtype="bfloat16",
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == bfloat16
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp32, it should be changed to bf16
-    data = mx.sym.var("data", dtype="float32")
-    data2 = mx.sym.var("data2", dtype=bfloat16)
-    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype=bfloat16)])
-    final_res = amp.convert_symbol(out8, target_dtype_ops=[], target_dtype="bfloat16",
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.cpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == bfloat16
+    covered = set(t)
+    ops = get_all_registered_operators_grouped()
+    required = set(k for k in ops
+                   if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
+                   not k.endswith("_backward"))
+
+    extra = covered - required
+    assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"
+
+    guidelines = """Please follow these guidelines for choosing a proper list:
+    - if your operator is not to be used in a computational graph
+      (e.g. image manipulation operators, optimizers) or does not have
+      inputs, put it in BF16_FP32_FUNCS list,
+    - if your operator requires FP32 inputs or is not safe to use with lower
+      precision, put it in FP32_FUNCS list,
+    - if your operator supports both FP32 and lower precision, has
+      multiple inputs and expects all inputs to be of the same
+      type, put it in WIDEST_TYPE_CASTS list,
+    - if your operator supports both FP32 and lower precision and has
+      either a single input or supports inputs of different type,
+      put it in BF16_FP32_FUNCS list,
+    - if your operator is both safe to use in lower precision and
+      it is highly beneficial to use it in lower precision, then
+      put it in BF16_FUNCS (this is unlikely for new operators)
+    - If you are not sure which list to choose, FP32_FUNCS is the
+      safest option"""
+    diff = required - covered
+
+    if len(diff) > 0:
+      warnings.warn(f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in "
+                    f"python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
+                    f"\n{guidelines}")
+
+
[email protected]_np
+def test_bf16_offline_casting():
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Conv2D(4, 3)
+      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
+      self.fp32_op = nn.Dense(4)
+
+    def forward(self, x):
+      x = self.lp16_op1(x)
+      x = self.lp16_op2(x)
+      x = x.reshape(x.shape[0], -1)
+      x = self.fp32_op(x)
+      return x
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
+  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=bfloat16,
+                                    target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
+                                    cast_params_offline=True, device=mx.current_context())
+  lp_net(data_example)
+  for name, data in lp_net.collect_params().items():
+    assert data.dtype == (np.float32 if 'fp32_op' in name else bfloat16)
+
+
[email protected]_np
+def test_bf16_offline_casting_shared_params():
+  COMMON_SIZE = 4
+
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Dense(COMMON_SIZE)
+      self.lp16_op2 = nn.Dense(COMMON_SIZE)
+      self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
+      self.fp32_op = nn.Conv1D(COMMON_SIZE, 3)
+      self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})
+
+    def forward(self, x):
+      x = self.lp16_op1(x)
+      x1 = self.lp16_op2(x)
+      x2 = mx.np.expand_dims(x, 1)
+      x2 = self.fp32_op(x2)
+      x2 = nn.Flatten()(x2)

Review comment:
    ```suggestion
          x2 = npx.batch_flatten(x2)
    ```
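    (This assumes `npx` is already imported in this test module, e.g. via `from mxnet import npx`; if it is not, the import would need to be added together with this change.)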

##########
File path: tests/python/gpu/test_amp.py
##########
@@ -111,84 +111,61 @@ def test_amp_conversion_rnn(amp_tests):
         mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
 
 
-def test_fp16_casting(amp_tests):
-    data = mx.sym.var("data")
-    out1 = mx.sym.amp_cast(data, dtype="float16")
-    out2 = mx.sym.amp_cast(data, dtype="float32")
-    out3 = mx.sym.amp_cast(data, dtype="float16")
-    # When two ops from data, with different dtypes,
-    # data should be float32
-    res = mx.sym.Group([out1, out2])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # When two ops from data, both casted to float16,
-    # data should be float16
-    res = mx.sym.Group([out1, out3])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where one node is float32, another is float16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
-    final_res = amp.convert_symbol(out4, cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where two non input nodes are float16,
-    # and one input node is float32
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    data3 = mx.sym.var("data3", dtype=np.float16)
-    out5 = mx.sym.amp_multicast(data,
-                                mx.sym.elemwise_add(data2, data3),
-                                num_outputs=2)
-    final_res = amp.convert_symbol(out5, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # AMP Multicast test where three input nodes one fp16, one fp32
-    # one unknown
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    data3 = mx.sym.var("data3")
-    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
-    final_res = amp.convert_symbol(out6, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2),
-                                data3=(1, 2))
-    assert exe.arg_arrays[2].dtype == np.float32
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp16, it should still be fp16
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out7, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp32, it should be changed to fp16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out8, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Check for symbol which has slice channel
-    data = mx.sym.var("data")
-    data2 = mx.sym.var("data2")
-    data._set_attr(__dtype__="-1")
-    data2._set_attr(__dtype__="-1")
-    concat_res = mx.sym.concat(data, data2)
-    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
-    final_res = amp.convert_symbol(out)
-
[email protected]_np
+def test_bf16_offline_casting():
+  class TestNet(nn.HybridBlock):
+    def __init__(self):
+      super().__init__()
+      self.lp16_op1 = nn.Conv2D(4, 3)
+      self.lp16_op2 = nn.Conv2DTranspose(4, 3)
+      self.fp32_op = nn.Dense(4)
+
+    def forward(self, x):
+      x = self.lp16_op1(x)
+      x = self.lp16_op2(x)
+      x = x.reshape(x.shape[0], -1)
+      x = self.fp32_op(x)
+      return x
+
+  net = TestNet()
+  net.initialize()
+  data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
+  lp_net = amp.convert_hybrid_block(net, data_example, target_dtype='float16',
+                                    target_dtype_ops=['Convolution'], fp32_ops=['FullyConnected'],
+                                    cast_params_offline=True, device=mx.current_context())
+  lp_net(data_example)
+  for name, data in lp_net.collect_params().items():
+    assert data.dtype == (np.float32 if 'fp32_op' in name else 'float16')
+
+
[email protected]_np
+def test_bf16_offline_casting_shared_params():

Review comment:
    ```suggestion
    def test_fp16_offline_casting_shared_params():
    ```

##########
File path: tests/python/gpu/test_amp.py
##########
@@ -111,84 +111,61 @@ def test_amp_conversion_rnn(amp_tests):
         mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
 
 
-def test_fp16_casting(amp_tests):
-    data = mx.sym.var("data")
-    out1 = mx.sym.amp_cast(data, dtype="float16")
-    out2 = mx.sym.amp_cast(data, dtype="float32")
-    out3 = mx.sym.amp_cast(data, dtype="float16")
-    # When two ops from data, with different dtypes,
-    # data should be float32
-    res = mx.sym.Group([out1, out2])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # When two ops from data, both casted to float16,
-    # data should be float16
-    res = mx.sym.Group([out1, out3])
-    final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where one node is float32, another is float16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
-    final_res = amp.convert_symbol(out4, cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # AMP Multicast test where two non input nodes are float16,
-    # and one input node is float32
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    data3 = mx.sym.var("data3", dtype=np.float16)
-    out5 = mx.sym.amp_multicast(data,
-                                mx.sym.elemwise_add(data2, data3),
-                                num_outputs=2)
-    final_res = amp.convert_symbol(out5, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float32
-
-    # AMP Multicast test where three input nodes one fp16, one fp32
-    # one unknown
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    data3 = mx.sym.var("data3")
-    out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
-    final_res = amp.convert_symbol(out6, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2),
-                                data3=(1, 2))
-    assert exe.arg_arrays[2].dtype == np.float32
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp16, it should still be fp16
-    data = mx.sym.var("data", dtype=np.float16)
-    data2 = mx.sym.var("data2", dtype=np.float32)
-    out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out7, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Input node to amp_multicast and amp_cast, if dtypes conflict
-    # and input node is already fp32, it should be changed to fp16
-    data = mx.sym.var("data", dtype=np.float32)
-    data2 = mx.sym.var("data2", dtype=np.float16)
-    out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
-    final_res = amp.convert_symbol(out8, target_dtype_ops=[],
-                                   fp32_ops=[], cast_optional_params=True)
-    exe = final_res._simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
-    assert exe.arg_arrays[0].dtype == np.float16
-
-    # Check for symbol which has slice channel
-    data = mx.sym.var("data")
-    data2 = mx.sym.var("data2")
-    data._set_attr(__dtype__="-1")
-    data2._set_attr(__dtype__="-1")
-    concat_res = mx.sym.concat(data, data2)
-    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
-    final_res = amp.convert_symbol(out)
-
[email protected]_np
+def test_bf16_offline_casting():

Review comment:
    ```suggestion
    def test_fp16_offline_casting():
    ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

