masahi commented on code in PR #11190:
URL: https://github.com/apache/tvm/pull/11190#discussion_r863470284


##########
python/tvm/relay/op/transform.py:
##########
@@ -1823,3 +1823,40 @@ def invert_permutation(data):
         relay.invert_permutation(data) = [2, 4, 3, 0, 1]
     """
     return _make.invert_permutation(data)
+
+
+def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.
+    Parameters
+    ----------
+    data : relay.Expr
+        Either a 1-D tensor or a 2-D batch tensor.
+    n_fft : int
+        The size of Fourier transform
+    hop_length : int
+        The distance between neighboring sliding window frames
+    win_length : int
+        The size of window frame and STFT filter
+    window : relay.Expr
+        A 1-D tensor window frame
+    normalized : bool
+        Whether to return the normalized STFT results
+    onesided : bool
+        Whether to return onesided result or fill with conjugate symmetry
+    Returns
+    -------
+    output : relay.Expr
+        Tensor containing the STFT result

Review Comment:
   Document the output shape. I had to read the type rel to see what the output 
shape looks like.



##########
python/tvm/relay/op/transform.py:
##########
@@ -1823,3 +1823,40 @@ def invert_permutation(data):
         relay.invert_permutation(data) = [2, 4, 3, 0, 1]
     """
     return _make.invert_permutation(data)
+
+
+def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.
+    Parameters
+    ----------
+    data : relay.Expr
+        Either a 1-D tensor or a 2-D batch tensor.
+    n_fft : int
+        The size of Fourier transform
+    hop_length : int
+        The distance between neighboring sliding window frames
+    win_length : int
+        The size of window frame and STFT filter
+    window : relay.Expr
+        A 1-D tensor window frame

Review Comment:
   In PyTorch, the `window` argument is optional, so shouldn't we support that too?



##########
tests/python/relay/test_op_level3.py:
##########
@@ -1784,23 +1784,6 @@ def test_segment_sum(
         )
 
 
-def verify_func(target, dev, func, data, ref_res):
-    assert isinstance(data, list)
-    for kind in ["vm"]:
-        mod = tvm.ir.IRModule.from_expr(func)
-        op_res = relay.create_executor(kind, mod=mod, device=dev, 
target=target).evaluate()(*data)
-        if isinstance(op_res, tvm.runtime.container.ADT):
-            assert len(op_res) == len(
-                ref_res
-            ), "Outputs from TVM and Python implementation must be equal "
-
-            for op_result, ref_result in zip(op_res, ref_res):
-                tvm.testing.assert_allclose(op_result.numpy(), ref_result, 
rtol=1e-5)
-        else:
-            tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
-        relay.backend.te_compiler.get().clear()
-
-

Review Comment:
   Remove this diff



##########
python/tvm/topi/stft.py:
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""STFT operator"""
+from math import pi
+from tvm import te, tir
+
+
+def stft(
+    data,
+    n_fft,
+    hop_length,
+    win_length,
+    window,
+    normalized,
+    onesided,  # pylint: disable=unused-argument
+    output_shape,
+):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.
+    Parameters
+    ----------
+    data : relay.Expr
+        Either a 1-D tensor or a 2-D batch tensor.
+    n_fft : int
+        The size of Fourier transform
+    hop_length : int
+        The distance between neighboring sliding window frames
+    win_length : int
+        The size of window frame and STFT filter
+    window : relay.Expr
+        A 1-D tensor window frame
+    normalized : bool
+        Whether to return the normalized STFT results
+    onesided : bool
+        Whether to return onesided result or fill with conjugate symmetry
+    Returns
+    -------
+    output : relay.Expr
+        Tensor containing the STFT result
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [1, 2, 3, 4, 5, 6]
+        window = [4, 3, 2]
+        [n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, 
False, True]
+        relay.stft(data, n_fft, hop_length, win_length, window, normalized, 
onesided)
+        -> [[[15.0000,  0.0000], [34.0000,  0.0000]], [[ 4.5000,  0.8660], [ 
1.0000, -1.7321]]]
+    """
+
+    def gen_ir(
+        data_ptr,
+        n_fft,
+        hop_length,
+        win_length,
+        window_ptr,
+        normalized,
+        onesided,  # pylint: disable=unused-argument
+        output_ptr,
+        loop_kind,
+    ):
+        ib = tir.ir_builder.create()
+        data = ib.buffer_ptr(data_ptr)
+        window = ib.buffer_ptr(window_ptr)
+        output = ib.buffer_ptr(output_ptr)
+
+        with ib.for_range(0, output_ptr.shape[0]) as batch:
+            # 
https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
+            with ib.for_range(0, output_ptr.shape[1], kind="parallel") as row:

Review Comment:
   Fuse the two outer loops (`batch` and `row`) into one big parallel loop.



##########
tests/python/frontend/pytorch/test_forward.py:
##########
@@ -4113,6 +4125,36 @@ def test_fn(equation):
     verify_model(test_fn("ij,jk,km->im"), [x, y, z])
 
 
+def test_stft():
+    def test_fn(n_fft, hop_length, win_length, center, pad_mode, normalized, 
onesided):
+        return lambda input, window: torch.stft(
+            input=input,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            normalized=normalized,
+            onesided=onesided,
+        )
+
+    input = torch.rand([1, 12]).float()
+    window = torch.tensor([2, 3, 4], dtype=torch.int32)
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn(3, 3, 3, False, "constant", False, True), 
[input, window], targets)
+    verify_trace_model(test_fn(3, 3, 3, True, "constant", False, True), 
[input, window], targets)
+    verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), 
[input, window], targets)
+    verify_trace_model(test_fn(3, 3, 3, True, "reflect", False, True), [input, 
window], targets)
+    verify_trace_model(test_fn(3, 3, 3, True, "reflect", True, True), [input, 
window], targets)
+    verify_trace_model(test_fn(3, 3, 3, True, "reflect", False, False), 
[input, window], targets)
+    input = torch.rand([2, 12]).float()
+    window = torch.tensor([2, 3, 4], dtype=torch.int32)
+    verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), 
[input, window], targets)
+    window = torch.tensor([1, 3], dtype=torch.int32)
+    verify_trace_model(test_fn(2, 1, 2, False, "reflect", False, True), 
[input, window], targets)
+

Review Comment:
   Please add a test for the `window=None` case.



##########
python/tvm/topi/cuda/stft.py:
##########
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""STFT operator"""
+from math import pi
+import tvm
+from tvm import te, tir
+from ..utils import ceil_div
+
+
+def stft(
+    data,
+    n_fft,
+    hop_length,
+    win_length,
+    window,
+    normalized,
+    onesided,  # pylint: disable=unused-argument
+    output_shape,
+):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.
+    Parameters
+    ----------
+    data : relay.Expr
+        Either a 1-D tensor or a 2-D batch tensor.
+    n_fft : int
+        The size of Fourier transform
+    hop_length : int
+        The distance between neighboring sliding window frames
+    win_length : int
+        The size of window frame and STFT filter
+    window : relay.Expr
+        A 1-D tensor window frame
+    normalized : bool
+        Whether to return the normalized STFT results
+    onesided : bool
+        Whether to return onesided result or fill with conjugate symmetry
+    Returns
+    -------
+    output : relay.Expr
+        Tensor containing the STFT result
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [1, 2, 3, 4, 5, 6]
+        window = [4, 3, 2]
+        [n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, 
False, True]
+        relay.stft(data, n_fft, hop_length, win_length, window, normalized, 
onesided)
+        -> [[[15.0000,  0.0000], [34.0000,  0.0000]], [[ 4.5000,  0.8660], [ 
1.0000, -1.7321]]]
+    """
+
+    def gen_ir(
+        data_ptr,
+        n_fft,
+        hop_length,
+        win_length,
+        window_ptr,
+        normalized,
+        onesided,  # pylint: disable=unused-argument
+        output_ptr,
+    ):
+        ib = tir.ir_builder.create()
+        data = ib.buffer_ptr(data_ptr)
+        window = ib.buffer_ptr(window_ptr)
+        output = ib.buffer_ptr(output_ptr)
+
+        with ib.new_scope():
+            max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+            nthread_tx = max_threads
+            nthread_bx = ceil_div(output_ptr.shape[0], max_threads)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            batch = bx * max_threads + tx
+
+            with ib.for_range(0, output_ptr.shape[0]) as batch:
+                with ib.for_range(0, output_ptr.shape[1]) as row:
+                    with ib.for_range(0, output_ptr.shape[2]) as col:

Review Comment:
   This looks weird. You try to parallelize over the batch dim but nothing is 
parallelized.



##########
python/tvm/relay/op/transform.py:
##########
@@ -1823,3 +1823,40 @@ def invert_permutation(data):
         relay.invert_permutation(data) = [2, 4, 3, 0, 1]
     """
     return _make.invert_permutation(data)
+
+
+def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.

Review Comment:
   giving -> gives
   
   Fix the same typo in other files too.



##########
python/tvm/topi/stft.py:
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""STFT operator"""
+from math import pi
+from tvm import te, tir
+
+
+def stft(
+    data,
+    n_fft,
+    hop_length,
+    win_length,
+    window,
+    normalized,
+    onesided,  # pylint: disable=unused-argument
+    output_shape,
+):
+    """
+    The STFT computes the Fourier transform of short overlapping windows of 
the input.
+    This giving frequency components of the signal as they change over time.
+    Parameters
+    ----------
+    data : relay.Expr
+        Either a 1-D tensor or a 2-D batch tensor.
+    n_fft : int
+        The size of Fourier transform
+    hop_length : int
+        The distance between neighboring sliding window frames
+    win_length : int
+        The size of window frame and STFT filter
+    window : relay.Expr
+        A 1-D tensor window frame
+    normalized : bool
+        Whether to return the normalized STFT results
+    onesided : bool
+        Whether to return onesided result or fill with conjugate symmetry
+    Returns
+    -------
+    output : relay.Expr
+        Tensor containing the STFT result
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [1, 2, 3, 4, 5, 6]
+        window = [4, 3, 2]
+        [n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, 
False, True]
+        relay.stft(data, n_fft, hop_length, win_length, window, normalized, 
onesided)
+        -> [[[15.0000,  0.0000], [34.0000,  0.0000]], [[ 4.5000,  0.8660], [ 
1.0000, -1.7321]]]
+    """
+
+    def gen_ir(
+        data_ptr,
+        n_fft,
+        hop_length,
+        win_length,
+        window_ptr,
+        normalized,
+        onesided,  # pylint: disable=unused-argument
+        output_ptr,
+        loop_kind,
+    ):
+        ib = tir.ir_builder.create()
+        data = ib.buffer_ptr(data_ptr)
+        window = ib.buffer_ptr(window_ptr)
+        output = ib.buffer_ptr(output_ptr)
+
+        with ib.for_range(0, output_ptr.shape[0]) as batch:
+            # 
https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
+            with ib.for_range(0, output_ptr.shape[1], kind="parallel") as row:
+                with ib.for_range(0, output_ptr.shape[2], kind=loop_kind) as 
col:
+                    output[batch, row, col, 0] = tir.Cast(data_ptr.dtype, 0)
+                    output[batch, row, col, 1] = tir.Cast(data_ptr.dtype, 0)
+                    with ib.for_range(0, win_length) as wlen:
+                        output[batch, row, col, 0] += (
+                            window[wlen]
+                            * data[batch, col * hop_length + wlen]
+                            * tir.cos(2 * pi * row * wlen / win_length)
+                        )
+                        output[batch, row, col, 1] -= (
+                            window[wlen]
+                            * data[batch, col * hop_length + wlen]
+                            * tir.sin(2 * pi * row * wlen / win_length)
+                        )
+                    with ib.if_scope(normalized):
+                        output[batch, row, col, 0] /= 
tir.sqrt(tir.const(n_fft, "float32"))
+                        output[batch, row, col, 1] /= 
tir.sqrt(tir.const(n_fft, "float32"))
+
+        return ib.get()
+
+    output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf")
+    loop_kind = "vectorize"
+    if hasattr(output_shape[2], "name") and output_shape[2].name == "any_dim":

Review Comment:
   `if isinstance(output_shape[2], tir.expr.Any)`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to