leandron commented on a change in pull request #8277:
URL: https://github.com/apache/tvm/pull/8277#discussion_r654270621



##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -1717,58 +1717,58 @@ def test_forward_variable():
     _test_variable(np.random.uniform(size=(32, 100)).astype("float32"))
 
 
-@tvm.testing.parametrize_targets("llvm", "cuda")
-def test_read_variable_op(target, dev):
-    """Read Variable op test"""
-
-    tf.reset_default_graph()
-    data = np.random.uniform(size=(32, 100)).astype("float32")
-    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-
-    size = input_tensor.shape.dims[1]
-    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
-    input_var = tf.Variable(var_data, name="var1", use_resource=True)
-    math_ops.matmul(input_tensor, input_var)
-
-    out_name = ["MatMul:0"]
-    out_node = ["MatMul"]
-    in_name = ["Placeholder:0"]
-    in_node = ["Placeholder"]
-    in_data = [data]
-
-    with tf.Session() as sess:
-        sess.run(variables.global_variables_initializer())
-
-        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
-        tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-
-        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
-        with pytest.raises(Exception) as execinfo:
-            mod, params = relay.frontend.from_tensorflow(
-                final_graph_def, layout=None, shape=shape_dict, outputs=None
-            )
-
-        assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph")
-
-        # Now convert the variables to constant and run inference on the converted graph
-        final_graph_def = tf.graph_util.convert_variables_to_constants(
-            sess,
-            sess.graph.as_graph_def(add_shapes=True),
-            out_node,
-        )
-
-        tvm_output = run_tvm_graph(
-            final_graph_def,
-            in_data,
-            in_node,
-            target=target,
-            out_names=out_name,
-            num_output=len(out_name),
-        )
-        for i in range(len(tf_output)):
-            tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
-
-        sess.close()
+# @tvm.testing.parametrize_targets("llvm", "cuda")

Review comment:
       Is this test commented out intentionally?

##########
File path: python/tvm/topi/transform.py
##########
@@ -941,3 +942,33 @@ def adv_index(data, indices):
         Output tensor
     """
     return cpp.adv_index(data, indices)
+
+
+@hybrid.script
+def invert_permutation(data):
+    """Computes the inverse permutation of data.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        Output tensor
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [3, 4, 0, 2, 1]
+
+        topi.invert_permutation(data) = [2, 4, 3, 0, 1]

Review comment:
       ```suggestion
       .. code-block:: python
           data = [3, 4, 0, 2, 1]
           topi.invert_permutation(data) = [2, 4, 3, 0, 1]
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to