RafLit commented on a change in pull request #20835: URL: https://github.com/apache/incubator-mxnet/pull/20835#discussion_r795424225
########## File path: src/operator/quantization/quantized_reshape-inl.h ########## @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_reshape-inl.h + * \author: Adam Grabowski, [email protected] + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_ + +#include <string> +#include <vector> +#include "operator/tensor/matrix_op-inl.h" +#include "operator/numpy/np_matrix_op-inl.h" + +namespace mxnet { +namespace op { + +struct QuantizedReshapeParam : public dmlc::Parameter<QuantizedReshapeParam> { + mxnet::TShape newshape; + mxnet::Tuple<int> shape; + bool reverse, keep_highest, is_numpy_op; + std::string order; Review comment: I don't think merging both parameter types is a good idea. Duplicating type information in a variable is counter-intuitive. What do you think about separating the operator into two versions: numpy and ndarray one? 
########## File path: tests/python/quantization/test_quantization.py ########## @@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype): check_quantized_bn((32, 3, 224, 224), qdtype) +def test_quantized_reshape(): + test_cases = [((2, 3, 5, 5), (-2, -1), False, (2, 75)), + ((2, 3, 5, 5), (-2, -2, -1), False, (2, 3, 25)), + ((5, 3, 4, 5), (-2, -1, -2), False, (5, 15, 4)), + ((2, 3, 5, 4), (-1, -2, -2), False, (8, 3, 5)), + ((2, 3, 5, 5), (-2, -2, -2, -2), False, (2, 3, 5, 5)), + ((2, 1, 4, 5), (-2, -3, -2, -2), False, (2, 4, 5)), + ((1, 1, 4, 1), (-3, -3, -2, -2), False, (4, 1)), + ((1, 1, 1, 1), (-3, -3, -3, -3), False, ()), + ((2, 4, 5, 3), (-1, 2, 2, 1), False, (30, 2, 2, 1)), + ((2, 3, 5, 6), (-4,), False, (2, 3, 5, 6)), + ((2, 3, 5, 6), (6, 1, -4), False, (6, 1, 5, 6)), + ((2, 3, 5, 6), (-5, -5), False, (6, 30)), + ((2, 3, 5, 6), (-5, -1), False, (6, 30)), + ((64,), (-6, 16, 4), False, (16, 4)), + ((64,), (-6, 16, -1), False, (16, 4)), + ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)), + ((8, 5, 4, 6), (-4, -1, 3, -6), True, (8, 5, 4, 2, 3))] + + def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape): + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + qdata = mx.np.random.uniform(low=data_low, high=data_high, size=shape).astype(qdtype) + min_data = mx.np.array([-1023.343], dtype='float32') + max_data = mx.np.array([2343.324275], dtype='float32') + qoutput, min_output, max_output = npx.quantized_reshape(qdata, min_data, max_data, newshape=newshape, reverse=reverse) + assert qoutput.shape == expected_ret_shape + assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten()) Review comment: Is the `flatten()` call necessary here? 
########## File path: tests/python/dnnl/subgraphs/test_conv_subgraph.py ########## @@ -73,6 +73,28 @@ def forward(self, x): check_fusion(net, data_shape, attr) +@mx.util.use_np +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('use_bias', [True, False]) +def test_conv_reshape_conv(use_bias, data_shape): + + class Conv_Reshape_Conv(nn.HybridBlock): + def __init__(self, **kwargs): + super(Conv_Reshape_Conv, self).__init__(**kwargs) + self.conv0 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias) + self.conv1 = nn.Conv2D(channels=32, kernel_size=(5, 5), strides=1, use_bias=use_bias) + + def forward(self, x): + out = self.conv0(x) + out = mx.npx.reshape(out, newshape=(-1, int(out.shape[1]/4), out.shape[2]*2, out.shape[3]*2)) + out = self.conv1(out) + return out + + attr = {'conv': []} + net = Conv_Reshape_Conv() + check_fusion(net, data_shape, attr) Review comment: Wouldn't it be better to use check_quantize explicitly? Is anything fused in this example? 
########## File path: tests/python/quantization/test_quantization.py ########## @@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype): check_quantized_bn((32, 3, 224, 224), qdtype) +def test_quantized_reshape(): + test_cases = [((2, 3, 5, 5), (-2, -1), False, (2, 75)), + ((2, 3, 5, 5), (-2, -2, -1), False, (2, 3, 25)), + ((5, 3, 4, 5), (-2, -1, -2), False, (5, 15, 4)), + ((2, 3, 5, 4), (-1, -2, -2), False, (8, 3, 5)), + ((2, 3, 5, 5), (-2, -2, -2, -2), False, (2, 3, 5, 5)), + ((2, 1, 4, 5), (-2, -3, -2, -2), False, (2, 4, 5)), + ((1, 1, 4, 1), (-3, -3, -2, -2), False, (4, 1)), + ((1, 1, 1, 1), (-3, -3, -3, -3), False, ()), + ((2, 4, 5, 3), (-1, 2, 2, 1), False, (30, 2, 2, 1)), + ((2, 3, 5, 6), (-4,), False, (2, 3, 5, 6)), + ((2, 3, 5, 6), (6, 1, -4), False, (6, 1, 5, 6)), + ((2, 3, 5, 6), (-5, -5), False, (6, 30)), + ((2, 3, 5, 6), (-5, -1), False, (6, 30)), + ((64,), (-6, 16, 4), False, (16, 4)), + ((64,), (-6, 16, -1), False, (16, 4)), + ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)), + ((8, 5, 4, 6), (-4, -1, 3, -6), True, (8, 5, 4, 2, 3))] + + def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape): + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 Review comment: can this be increased to 255? ########## File path: src/operator/quantization/quantized_reshape-inl.h ########## @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_reshape-inl.h + * \author: Adam Grabowski, [email protected] + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_ + +#include <string> +#include <vector> +#include "operator/tensor/matrix_op-inl.h" +#include "operator/numpy/np_matrix_op-inl.h" + +namespace mxnet { +namespace op { + +struct QuantizedReshapeParam : public dmlc::Parameter<QuantizedReshapeParam> { + mxnet::TShape newshape; + mxnet::Tuple<int> shape; + bool reverse, keep_highest, is_numpy_op; + std::string order; + + DMLC_DECLARE_PARAMETER(QuantizedReshapeParam) { + DMLC_DECLARE_FIELD(newshape).set_default(mxnet::TShape(0, -1)); + DMLC_DECLARE_FIELD(shape).set_default(mxnet::Tuple<int>()); + DMLC_DECLARE_FIELD(reverse).set_default(false); + DMLC_DECLARE_FIELD(order).set_default("C"); + DMLC_DECLARE_FIELD(keep_highest).set_default(false); + DMLC_DECLARE_FIELD(is_numpy_op).set_default(true); Review comment: I think redundant type information should be avoided ########## File path: src/operator/quantization/quantized_reshape.cc ########## @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_reshape.cc + * \author: Adam Grabowski, [email protected] + */ + +#include <utility> +#include "quantized_reshape-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(QuantizedReshapeParam); + +void QuantizedReshapeCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 3U); + CHECK_EQ(req.size(), 3U); + + if (req[0] != kWriteInplace) + UnaryOp::IdentityCompute<cpu>(attrs, ctx, inputs, req, outputs); + + *outputs[1].dptr<float>() = *inputs[1].dptr<float>(); + *outputs[2].dptr<float>() = *inputs[2].dptr<float>(); +} + +NNVM_REGISTER_OP(_contrib_quantized_reshape) + .add_alias("_npx_quantized_reshape") + .set_num_inputs(3) + .set_num_outputs(3) + .set_attr_parser(ParamParser<QuantizedReshapeParam>) + .set_attr<nnvm::FListInputNames>( + "FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"data", "min_data", "max_data"}; + }) + .set_attr<nnvm::FListOutputNames>( + "FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector<std::string>{"output", "min_output", "max_output"}; + }) + .set_attr<nnvm::FInplaceOption>( + "FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector<std::pair<int, int> >{{0, 0}, {1, 1}, {2, 2}}; + }) + .set_attr<FCompute>("FCompute<cpu>", QuantizedReshapeCompute) + .set_attr<FResourceRequest>("FResourceRequest", + [](const NodeAttrs& n) { + return 
std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; + }) + .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape) + .set_attr<nnvm::FInferType>("FInferType", QuantizedReshapeType) + .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) + .set_attr<FQuantizable>("FQuantizable", + [](const NodeAttrs& attrs) { return QuantizeType::kSupport; }) + .add_argument("data", "NDArray-or-Symbol", "Array to be reshaped.") + .add_argument("min_data", + "NDArray-or-Symbol", + "The minimum scalar value " + "possibly produced for the data") + .add_argument("max_data", + "NDArray-or-Symbol", + "The maximum scalar value " + "possibly produced for the data") + .add_arguments(QuantizedReshapeParam::__FIELDS__()); + +template <bool is_numpy_op> +nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) { + QuantizedReshapeParam param; + if (is_numpy_op) { Review comment: this will be checked at runtime for both the numpy version and ndarray. It would be more intuitive to separate them -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
