This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f6f90569bc Add tensorflow Einsum op converter (#12064)
f6f90569bc is described below

commit f6f90569bc893204c39f8a32a42612972a7d138f
Author: Donglin Zhuang <[email protected]>
AuthorDate: Thu Jul 14 19:15:41 2022 +0800

    Add tensorflow Einsum op converter (#12064)
    
    * Add tensorflow Einsum op converter
    
    * fix lint
    
    * fix lint
---
 python/tvm/relay/frontend/tensorflow_ops.py      | 10 +++++++
 tests/python/frontend/tensorflow/test_forward.py | 33 ++++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 9b36d712e9..c94a4ef2e6 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -2480,6 +2480,15 @@ def _range():
     return _impl
 
 
+def _einsum():
+    def _impl(inputs, attr, params, mod):
+        # Decode the bytes-valued `equation`; build a new dict so `attr` is not mutated.
+        converted = dict(attr, equation=attr["equation"].decode("utf-8"))
+        return AttrCvt(op_name="einsum", ignores=["N"])([inputs], converted)
+
+    return _impl
+
+
 def _elu():
     def _impl(inputs, attr, params, mod):
         dtype = attr["T"].name
@@ -2907,6 +2916,7 @@ _convert_map = {
     "DepthToSpace": _depth_to_space(),
     "DepthwiseConv2dNative": _conv("depthwise"),
     "Dilation2D": _dilation2d(),
+    "Einsum": _einsum(),
     "Elu": _elu(),
     "Equal": _broadcast("equal"),
     "Erf": AttrCvt("erf"),
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 4988f57c24..70a137479f 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3671,6 +3671,39 @@ def test_forward_range():
         compare_tf_with_tvm([], [], "range:0")
 
 
+#######################################################################
+# Einsum
+# ------
+
+
+def _test_einsum(equation, dtype, *shape_of_input_tensors):
+    """Compare TF and TVM on tf.einsum(equation, ...) with random-normal inputs."""
+
+    with tf.Graph().as_default():
+        inputs_placeholders = []
+        input_data = []
+        for idx, shape in enumerate(shape_of_input_tensors):
+            input_name = f"input_{idx}"
+            inputs_placeholders.append(tf.placeholder(shape=shape, dtype=dtype, name=input_name))
+            input_data.append(np.random.normal(size=shape).astype(dtype))
+
+        result = tf.einsum(equation, *inputs_placeholders)
+
+        compare_tf_with_tvm(input_data, [ph.name for ph in inputs_placeholders], result.name)
+
+
+def test_forward_einsum():
+    for dtype in ["float32"]:
+        _test_einsum("ij,jk->ik", dtype, [2, 3], [3, 5])  # Matmul
+        _test_einsum("ij,jk", dtype, [2, 3], [3, 5])  # Matmul, implicit output subscripts
+        _test_einsum("i,i->", dtype, [2], [2])  # Dot product
+        _test_einsum("i,j->ij", dtype, [3], [5])  # Outer product
+        _test_einsum("ij->ji", dtype, [2, 3])  # Transpose
+        _test_einsum("ii->i", dtype, [3, 3])  # Diagonal
+        _test_einsum("ii", dtype, [3, 3])  # Trace of a square matrix
+        _test_einsum("bij,bjk->bik", dtype, [7, 5, 3], [7, 3, 2])  # Batch matmul
+
+
 #######################################################################
 # Pad
 # ---

Reply via email to