This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 99c4f9d  [TENSORFLOW]reduce ops updated (#5180)
99c4f9d is described below

commit 99c4f9d52a12a3fe83c5caa03cd7f1d827f5b5f0
Author: Samuel <siju.sam...@huawei.com>
AuthorDate: Fri Apr 10 07:26:47 2020 +0530

    [TENSORFLOW]reduce ops updated (#5180)
---
 python/tvm/relay/frontend/tensorflow.py          |  45 +++++----
 tests/python/frontend/tensorflow/test_forward.py | 123 ++++++-----------------
 2 files changed, 59 insertions(+), 109 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 84318c3..77dbcb5 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1094,6 +1094,14 @@ def _reduce(op):
             ignores=['name', 'Tidx'])([inputs[0]], attr)
     return _impl
 
+def _euclidean_norm():
+    def _impl(inputs, attr, params, mod):
+        axis = tuple(_get_list_param(params, inputs[1]))
+        keep_dims = bool(attr.get('keep_dims', False))
+        return _op.sqrt(_op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]),
+                                                axis, keep_dims), "float32"))
+    return _impl
+
 def _square():
     def _impl(inputs, attr, params, mod):
         return _op.multiply(inputs[0], inputs[0])
@@ -1686,8 +1694,8 @@ _freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable',
 _convert_map = {
     'Abs'                               : AttrCvt('abs'),
     'Add'                               : _elemwise('add'),
-    'AddV2'                             : _elemwise('add'),
     'AddN'                              : _add_n(),
+    'AddV2'                             : _elemwise('add'),
     'All'                               : _reduce('all'),
     'Any'                               : _reduce('any'),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
@@ -1710,16 +1718,18 @@ _convert_map = {
     'Concat'                            : _concat(),
     'ConcatV2'                          : _concatV2(),
     'Conv2D'                            : _conv('conv'),
-    'Conv3D'                            : _conv3d('conv'),
     'Conv2DBackpropInput'               : _conv('conv_transpose'),
+    'Conv3D'                            : _conv3d('conv'),
+    'Cos'                               : AttrCvt('cos'),
     'CropAndResize'                     : _crop_and_resize(),
     'DecodeJpeg'                        : _decode_image(),
-    'DepthwiseConv2dNative'             : _conv('depthwise'),
     'DepthToSpace'                      : _depth_to_space(),
+    'DepthwiseConv2dNative'             : _conv('depthwise'),
     'Dilation2D'                        : _dilation2d(),
-    'Equal'                             : _broadcast('equal'),
     'Elu'                               : _elu(),
+    'Equal'                             : _broadcast('equal'),
     'Erf'                               : AttrCvt('erf'),
+    'EuclideanNorm'                     : _euclidean_norm(),
     'Exp'                               : AttrCvt('exp'),
     'ExpandDims'                        : _expand_dims(),
     'Fill'                              : _fill(),
@@ -1743,19 +1753,16 @@ _convert_map = {
     'LessEqual'                         : _broadcast('less_equal'),
     'Log'                               : AttrCvt('log'),
     'Log1p'                             : _log1p(),
-    'Tan'                               : AttrCvt('tan'),
-    'Cos'                               : AttrCvt('cos'),
-    'Sin'                               : AttrCvt('sin'),
     'LogicalAnd'                        : _logical('logical_and'),
-    'LogicalOr'                         : _logical('logical_or'),
     'LogicalNot'                        : _logical('logical_not'),
+    'LogicalOr'                         : _logical('logical_or'),
     'LogSoftmax'                        : AttrCvt('log_softmax'),
     'LRN'                               : _lrn(),
     'MatMul'                            : _matmul(),
     'Max'                               : _reduce('max'),
+    'Maximum'                           : _elemwise('maximum'),
     'MaxPool'                           : _pooling('max_pool'),
     'MaxPool3D'                         : _pool3d('max_pool3d'),
-    'Maximum'                           : _elemwise('maximum'),
     'Mean'                              : _mean(),
     'Min'                               : _reduce('min'),
     'Minimum'                           : _elemwise('minimum'),
@@ -1767,14 +1774,6 @@ _convert_map = {
     'NotEqual'                          : _broadcast('not_equal'),
     'OneHot'                            : _one_hot(),
     'Pack'                              : _pack(),
-    'TensorArrayV3'                     : _tensor_array(),
-    'TensorArrayScatterV3'              : _tensor_array_scatter(),
-    'TensorArrayGatherV3'               : _tensor_array_gather(),
-    'TensorArraySizeV3'                 : _tensor_array_size(),
-    'TensorArrayWriteV3'                : _tensor_array_write(),
-    'TensorArrayReadV3'                 : _tensor_array_read(),
-    'TensorArraySplitV3'                : _tensor_array_split(),
-    'TensorArrayConcatV3'               : _tensor_array_concat(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
     'Pow'                               : _elemwise('power'),
@@ -1785,8 +1784,8 @@ _convert_map = {
     'Relu'                              : AttrCvt('relu'),
     'Relu6'                             : _relu6(),
     'Reshape'                           : _reshape(),
-    'ResizeBilinear'                    : _resize('bilinear'),
     'ResizeBicubic'                     : _resize('bilinear'),
+    'ResizeBilinear'                    : _resize('bilinear'),
     'ResizeNearestNeighbor'             : _resize('nearest_neighbor'),
     'ReverseV2'                         : _reverse_v2(),
     'RightShift'                        : AttrCvt('right_shift'),
@@ -1797,6 +1796,7 @@ _convert_map = {
     'Shape'                             : _shape(),
     'Sigmoid'                           : AttrCvt('sigmoid'),
     'Sign'                              : AttrCvt('sign'),
+    'Sin'                               : AttrCvt('sin'),
     'Size'                              : _size(),
     'Slice'                             : _slice(),
     'Softmax'                           : _softmax(),
@@ -1813,7 +1813,16 @@ _convert_map = {
     'StridedSlice'                      : _stridedSlice(),
     'Sub'                               : _elemwise('subtract'),
     'Sum'                               : _sum(),
+    'Tan'                               : AttrCvt('tan'),
     'Tanh'                              : AttrCvt('tanh'),
+    'TensorArrayConcatV3'               : _tensor_array_concat(),
+    'TensorArrayGatherV3'               : _tensor_array_gather(),
+    'TensorArrayReadV3'                 : _tensor_array_read(),
+    'TensorArrayScatterV3'              : _tensor_array_scatter(),
+    'TensorArraySizeV3'                 : _tensor_array_size(),
+    'TensorArraySplitV3'                : _tensor_array_split(),
+    'TensorArrayV3'                     : _tensor_array(),
+    'TensorArrayWriteV3'                : _tensor_array_write(),
     'Tile'                              : _tile(),
     'TopKV2'                            : _topk(),
     'Transpose'                         : _transpose(),
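
For reference, the new _euclidean_norm converter above lowers TensorFlow's
EuclideanNorm op to sqrt(sum(x * x)) over the requested axes. A minimal NumPy
sketch of the same computation (illustrative only, not part of this patch):

    import numpy as np

    def euclidean_norm_ref(x, axis, keepdims=False):
        # Mirrors what the converter emits: square the input, reduce-sum
        # over the given axes, then take the square root.
        return np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims))

    x = np.random.uniform(size=(2, 3, 4)).astype("float32")
    print(euclidean_norm_ref(x, axis=(1, 2), keepdims=True).shape)  # (2, 1, 1)
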
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 5c0391f..35a3466 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1029,28 +1029,6 @@ def test_forward_argminmax():
         _test_argx(tf.argmax, data=data, axis=axis)
         _test_argx(tf.argmin, data=data, axis=axis)
 
-#######################################################################
-# Reduce
-# ------
-
-
-def _test_reduce(func, data, **kwargs):
-    """ One iteration of a reduce operation"""
-
-    with tf.Graph().as_default():
-        inp = array_ops.placeholder(
-            shape=data.shape, dtype=data.dtype, name="c0")
-        func(inp, name="reducex0", **kwargs)
-
-        compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
-
-
-def test_forward_reduce():
-    data = np.random.uniform(size=(8, 4, 9)).astype('float32')
-    _test_reduce(tf.reduce_sum, data=data)
-    _test_reduce(tf.reduce_sum, data=data, axis=0)
-    _test_reduce(tf.reduce_sum, data=data, axis=(0, 1))
-
 
 #######################################################################
 # Variable
@@ -2845,55 +2823,42 @@ def test_forward_size():
     check_size((10,))
 
 #######################################################################
-# All, Any, Max, Min
-# ------------------
-
-def test_forward_reduce_all():
-    """Test the All operator."""
-    np_data = np.random.choice([True, False], size=(5, 7, 11))
-    tf.reset_default_graph()
-    with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
-        tf.reduce_all(in_data, name="all")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
-
-def test_forward_reduce_any():
-    """Test the Any operator."""
-    np_data = np.random.choice([True, False], size=(5, 7, 11))
-    tf.reset_default_graph()
-    with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
-        tf.reduce_any(in_data, name="any")
-        compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
+# All, Any, Max, Min, Prod, variance, std, logsumexp, euclidean_norm
+# ------------------------------------------------------------------
 
-def test_forward_reduce_max():
-    def check_max(ishape, axis, keepdims, dtype):
-        tf.reset_default_graph()
-        np_data = np.random.uniform(size=ishape).astype(dtype)
-        with tf.Graph().as_default():
-            in_data = tf.placeholder(dtype, name="in_data")
-            tf.math.reduce_max(in_data, axis=axis,
-                               keepdims=keepdims, name="reduce_max")
-            compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
-
-    check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
-    check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
-    check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32')
-
-
-def test_forward_reduce_min():
-    def check_min(ishape, axis, keepdims, dtype):
+def test_forward_reduce():
+    def _check_op(tf_op, ishape, axis, keepdims, dtype="float32"):
         tf.reset_default_graph()
-        np_data = np.random.uniform(size=ishape).astype(dtype)
+        if dtype == "bool":
+            np_data = np.random.choice([True, False], size=ishape)
+        else:
+            np_data = np.random.uniform(size=ishape).astype(dtype)
+        if tf_op == tf.math.reduce_prod:
+            axis = 1
+            np_data = np_data.reshape(1, -1)
         with tf.Graph().as_default():
             in_data = tf.placeholder(dtype, name="in_data")
-            tf.math.reduce_min(in_data, axis=axis,
-                               keepdims=keepdims, name="reduce_max")
-            compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
-
-    check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
-    check_min((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
-    check_min((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32')
+            reduce_op = tf_op(in_data, axis=axis,
+                               keepdims=keepdims, name="reduce_std")
+            compare_tf_with_tvm([np_data], ['in_data:0'], reduce_op.name)
+
+    def _test_math_op(op, dtypes=["int32", "float32"]):
+        for dtype in dtypes:
+            _check_op(op, (3, 10), axis=(-1), keepdims=False, dtype=dtype)
+            _check_op(op, (8, 16, 32), axis=(-1), keepdims=False, dtype=dtype)
+            _check_op(op, (1, 8, 8, 3), axis=(2, 3), keepdims=True, dtype=dtype)
+            _check_op(op, (2, 3, 10, 10), axis=(1, 2), keepdims=True, dtype=dtype)
+
+    _test_math_op(tf.math.reduce_all, dtypes=["bool"])
+    _test_math_op(tf.math.reduce_any, dtypes=["bool"])
+    _test_math_op(tf.math.reduce_max)
+    _test_math_op(tf.math.reduce_min)
+    _test_math_op(tf.math.reduce_prod)
+    _test_math_op(tf.math.reduce_variance)
+    _test_math_op(tf.math.reduce_std, dtypes=["float32"])
+    _test_math_op(tf.math.reduce_logsumexp, dtypes=["float32"])
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
+        _test_math_op(tf.math.reduce_euclidean_norm)
 
 #######################################################################
 # Relational operators
@@ -2943,26 +2908,6 @@ def test_forward_expand_dims():
 
 
 #######################################################################
-# Prod
-# ----
-def _test_forward_reduce_prod(shape, axis, keepdims):
-    inp_array1 = np.random.uniform(-5, 5, size=shape).astype(np.float32)
-    with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
-        out = tf.math.reduce_prod(in1, axis, keepdims)
-        compare_tf_with_tvm(inp_array1, in1.name, out.name)
-
-
-def test_forward_reduce_prod():
-    _test_forward_reduce_prod((5,), 0, False)
-    _test_forward_reduce_prod((5, 5), 0, False)
-    _test_forward_reduce_prod((5, 5), 1, False)
-    _test_forward_reduce_prod((5,), 0, True)
-    _test_forward_reduce_prod((5, 5), 0, True)
-    _test_forward_reduce_prod((5, 5), 1, True)
-
-
-#######################################################################
 # Maximum, Minimum
 # ----------------
 def test_forward_maximum():
@@ -3295,10 +3240,6 @@ if __name__ == '__main__':
     test_forward_argminmax()
     test_forward_reduce()
     test_forward_mean()
-    test_forward_reduce_prod()
-    test_forward_reduce_all()
-    test_forward_reduce_any()
-    test_forward_reduce_min()
 
     # General
     test_forward_multi_input()
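
Once this commit is in place, a graph that uses tf.math.reduce_euclidean_norm
(TF >= 1.15) should import through the Relay frontend. A rough, untested sketch
of that flow; the tensor names ("in_data", "norm") and shapes are illustrative:

    import tensorflow as tf
    from tvm import relay

    with tf.Graph().as_default() as graph:
        in_data = tf.placeholder(tf.float32, shape=(1, 8, 8, 3), name="in_data")
        tf.math.reduce_euclidean_norm(in_data, axis=(2, 3), keepdims=True,
                                      name="norm")
        graph_def = graph.as_graph_def()

    # from_tensorflow returns a Relay module plus the extracted params;
    # the EuclideanNorm node is handled by the converter added above.
    mod, params = relay.frontend.from_tensorflow(
        graph_def, shape={"in_data": (1, 8, 8, 3)}, outputs=["norm"])
    print(mod)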
