AndrewZhaoLuo commented on code in PR #14382:
URL: https://github.com/apache/tvm/pull/14382#discussion_r1147885531
##########
src/relay/transforms/div_to_mul.cc:
##########
@@ -26,42 +26,61 @@
namespace tvm {
namespace relay {
+/*!
+ * \brief Check whether any of the first `size` elements of a constant's data
+ *        buffer equals one of the candidate values.
+ * \param size Number of elements to inspect; must not exceed the element count
+ *        of `const_node`'s tensor.
+ * \param const_node Constant whose raw data buffer is read as an array of T.
+ * \param values Candidate values to search for. Taken by const reference so that
+ *        braced lists, temporaries and named vectors can all be passed.
+ * \return true if at least one element compares equal (operator==) to a candidate.
+ * \note T must match the storage type of the tensor's dtype (e.g. uint16_t bit
+ *       patterns for float16). For float/double, `== 0` also matches -0.0, but
+ *       for raw bit patterns each encoding must be listed explicitly.
+ */
+template <typename T>
+inline bool const_has_values(size_t size, const ConstantNode* const_node,
+                             const std::vector<T>& values) {
+  // Reinterpret the untyped buffer once, outside the scan loop.
+  const T* elements = static_cast<const T*>(const_node->data->data);
+  for (size_t i = 0; i < size; ++i) {
+    for (const T& candidate : values) {
+      if (elements[i] == candidate) return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief Total number of elements held by a constant, computed as the product
+ *        of the extents in its data's shape.
+ * \param const_node Constant whose shape is inspected.
+ * \return Product of all shape extents; a rank-0 (scalar) shape yields 1.
+ */
+inline size_t get_num_elements_const(const ConstantNode* const_node) {
+  size_t total = 1;
+  for (const auto& extent : const_node->data.Shape()) {
+    total *= extent;
+  }
+  return total;
+}
+
class DivToMulRewrite : public MixedModeMutator {
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
if (const CallNode* call_node = post.as<CallNode>()) {
if (call_node->op == Op::Get("divide")) {
auto rhs = call_node->args[1].as<ConstantNode>();
if (rhs != nullptr) {
- auto inv =
- runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(),
rhs->data->device);
+ auto one = runtime::NDArray::Empty({}, rhs->data.DataType(),
rhs->data->device);
+ size_t num_ele = get_num_elements_const(rhs);
std::string dtype = DLDataType2String(rhs->data.DataType());
+
+ bool const_has_zero_flag = false;
if (dtype == "float32") {
- float rhs_val = static_cast<float*>(rhs->data->data)[0];
- // Check for division by zero
- if (rhs_val == 0.) {
- return post;
- }
- static_cast<float*>(inv->data)[0] = 1. / rhs_val;
+ static_cast<float*>(one->data)[0] = 1.;
+ const_has_zero_flag = const_has_values<float>(num_ele, rhs, {0.});
} else if (dtype == "float64") {
- double rhs_val = static_cast<double*>(rhs->data->data)[0];
- // Check for division by zero
- if (rhs_val == 0.) {
- return post;
- }
- static_cast<double*>(inv->data)[0] = 1. / rhs_val;
+ static_cast<double*>(one->data)[0] = 1.;
+ const_has_zero_flag = const_has_values<double>(num_ele, rhs, {0.});
} else if (dtype == "float16") {
- // Do f16 math in f32
- float rhs_val =
__gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
- // Check for division by zero
- if (rhs_val == 0.) {
- return post;
- }
- static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. /
rhs_val);
+ static_cast<uint16_t*>(one->data)[0] = __gnu_f2h_ieee(1.);
+ // have to handle both + and - zero semantics manually here
+ const_has_zero_flag = const_has_values<uint16_t>(num_ele, rhs,
{0x0000, 0x8000});
} else {
- // Cannot do 1/int because it will truncate
+ LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " <<
rhs->data.DataType();
return post;
}
- return Multiply(call_node->args[0], Constant(inv));
+
+ if (const_has_zero_flag) {
+ return post;
+ }
+
+ // rely on constant folding to fold things
Review Comment:
Hmm, good point. I think there is a way to specify a required pass that must run
after this one, so I will try to figure that out.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org