jcf94 commented on a change in pull request #8234:
URL: https://github.com/apache/tvm/pull/8234#discussion_r649638468



##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,7 +1471,46 @@ def bias_add(data, bias, axis=1):
     return _make.bias_add(data, bias, axis)
 
 
-def dense(data, weight, units=None, out_dtype=""):
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+    """Matmul operator.
+    Applies a linear transformation, where both X and W can be transposed.
+
+    .. math::
+
+        Y = X * W
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+    weight : tvm.relay.Expr
+        The weight expressions, 2-D matrix,
+        of shape `(units_in, units)` or `(units, units_in)`.
+
+    units : int, optional
+        Number of hidden units of the dense transformation.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision matmul;
+        the result has shape `(d_1, d_2, ..., d_n, units)`.
+
+    data_transposed : bool, optional
+        Whether the data tensor is in transposed format.
+
+    weight_transposed : bool, optional
+        Whether the weight tensor is in transposed format.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.matmul(data, weight, units, out_dtype, data_transposed, weight_transposed)
+
+
+def dense(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=True):

Review comment:
       Some TVM APIs invoke operators as `relay_op(xxx, **attrs)`; such Python calls fail with a `TypeError` if these two parameters are missing from the function signature.
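The failure mode described above can be sketched in plain Python (the function bodies and attribute values here are illustrative, not TVM's real implementation): forwarding a recorded attrs dict as keyword arguments only works if the target function accepts every key.

```python
# Sketch: why `relay_op(inputs, **attrs)`-style dispatch requires `dense`
# to keep the two transpose parameters in its signature.
# These stand-in functions are hypothetical; they only mirror the signatures.

def dense_without_flags(data, weight, units=None, out_dtype=""):
    return ("dense", data, weight, units, out_dtype)

def dense_with_flags(data, weight, units=None, out_dtype="",
                     data_transposed=False, weight_transposed=True):
    return ("dense", data, weight, units, out_dtype,
            data_transposed, weight_transposed)

# Generic dispatch: forward every recorded attribute as a keyword argument.
attrs = {"units": 16, "out_dtype": "float32",
         "data_transposed": False, "weight_transposed": True}

try:
    dense_without_flags("x", "w", **attrs)
except TypeError as e:
    print("fails:", e)  # unexpected keyword argument 'data_transposed'

result = dense_with_flags("x", "w", **attrs)  # works with the full signature
```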

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1230,6 +1239,9 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     params : dict of str to tvm.nd.NDArray
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
+    global _USE_DENSE_INSTEAD_OF_MATMUL

Review comment:
       Yeah, I've also tried several approaches, but I don't see a better solution. A Python module can be treated as a constant singleton, so this should be safe as long as `from_tensorflow` is the only entry point.

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -52,6 +52,32 @@
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
[email protected]_legalize("nn.matmul")
+def legalize_matmul(attrs, inputs, types):
+    """Legalize matmul op.

Review comment:
       Actually nothing for now; this is just a blank placeholder, the same as the legalize functions of the other ops.

##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -698,6 +698,26 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+    """matmul cuda strategy"""
+    strategy = _op.OpStrategy()
+    if target.kind.name == "cuda" and "cublas" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.cuda.matmul_cublas),
+            wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
+            name="matmul_cublas.cuda",
+            plevel=25,
+        )
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.cuda",
+        )

Review comment:
       In the current implementation, every `matmul` op with the `NT` layout is rewritten to `dense` at the beginning.
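The `NT` rewrite mentioned above rests on a simple identity: a matmul whose data is non-transposed (`N`) and whose weight is stored transposed (`T`) computes exactly what `dense` computes. A short NumPy sketch (shapes chosen for illustration) confirms the equivalence:

```python
# NT-layout matmul == dense: data is (batch, units_in), weight is stored
# as (units, units_in), so the product contracts over units_in.
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((4, 8))      # (batch, units_in)
weight = rng.standard_normal((16, 8))   # (units, units_in), i.e. transposed

dense_out = data @ weight.T             # classic dense semantics
matmul_nt = np.matmul(data, weight.transpose())  # matmul with NT layout

assert dense_out.shape == (4, 16)
assert np.allclose(dense_out, matmul_nt)
```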

##########
File path: tests/python/relay/test_op_level1.py
##########
@@ -426,7 +486,7 @@ def test_dense():
     for dtype in ["float16", "float32"]:
         # Dense accuracy for float16 is poor
         if dtype == "float16":
-            return
+            continue

Review comment:
       I think this is a bug: with `return`, the function exits on the first dtype and no check is ever run in this UT, while `continue` only skips the `float16` case.
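A minimal reproduction of the bug described above (the loop body is a stand-in for the real test, not its actual contents): because `"float16"` comes first in the dtype list, `return` aborts the whole function before any dtype is checked, whereas `continue` skips only that one iteration.

```python
# `return` vs `continue` inside the dtype loop of a test like test_dense.

def checked_dtypes_with(stmt):
    checked = []
    for dtype in ["float16", "float32"]:
        if dtype == "float16":
            if stmt == "return":
                return checked  # exits the function: nothing gets checked
            continue            # skips only the float16 case
        checked.append(dtype)   # stand-in for the real accuracy check
    return checked

print(checked_dtypes_with("return"))    # []
print(checked_dtypes_with("continue"))  # ['float32']
```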

##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1471,7 +1471,46 @@ def bias_add(data, bias, axis=1):
     return _make.bias_add(data, bias, axis)
 
 
-def dense(data, weight, units=None, out_dtype=""):
+def matmul(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=False):
+    """Matmul operator.
+    Applies a linear transformation, where both X and W can be transposed.
+
+    .. math::
+
+        Y = X * W
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
+
+    weight : tvm.relay.Expr
+        The weight expressions, 2-D matrix,
+        of shape `(units_in, units)` or `(units, units_in)`.
+
+    units : int, optional
+        Number of hidden units of the dense transformation.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision matmul;
+        the result has shape `(d_1, d_2, ..., d_n, units)`.
+
+    data_transposed : bool, optional
+        Whether the data tensor is in transposed format.
+
+    weight_transposed : bool, optional
+        Whether the weight tensor is in transposed format.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.matmul(data, weight, units, out_dtype, data_transposed, weight_transposed)
+
+
+def dense(data, weight, units=None, out_dtype="", data_transposed=False, weight_transposed=True):

Review comment:
       Some TVM APIs invoke operators as `relay_op(xxx, **attrs)`; such Python calls cannot work when the two parameters are missing, unless we still keep a separate `DenseAttrs`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

