vvchernov commented on code in PR #13999:
URL: https://github.com/apache/tvm/pull/13999#discussion_r1114091599
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -4877,6 +4877,122 @@ def _impl_v1(cls, inputs, attr, params):
return mm_out
+class DFT(OnnxOpConverter):
+ """Operator converter for discrete Fourier transform (DFT)."""
+
+ @classmethod
+ def _impl_v17(cls, inputs, attr, params):
+ # ************************* Read attrs *************************
+ axis = attr.get("axis")
+ inverse = attr.get("inverse")
+ onesided = attr.get("onesided")
+
+ # ************************* Read inputs ************************
+ input_tensor = inputs[0]
+ dft_length = inputs[1]
+
+ # ************************* Parse inputs ***********************
+ t1 = ["float16", "float32", "float64"]
+ t2 = ["int32", "int64"]
+
+ # input
+ assert infer_type(input_tensor).checked_type.dtype in t1
+ input_shape = infer_shape(input_tensor)
+ assert len(input_shape) >= 3
+ if axis < 0:
+ axis = len(input_shape) - axis
Review Comment:
This should be `axis = len(input_shape) + axis`. It needs **+** because `axis` is already negative at this point.
It would probably also be worth adding a test that checks a negative axis.
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -4877,6 +4877,122 @@ def _impl_v1(cls, inputs, attr, params):
return mm_out
+class DFT(OnnxOpConverter):
+ """Operator converter for discrete Fourier transform (DFT)."""
+
+ @classmethod
+ def _impl_v17(cls, inputs, attr, params):
+ # ************************* Read attrs *************************
+ axis = attr.get("axis")
+ inverse = attr.get("inverse")
+ onesided = attr.get("onesided")
+
+ # ************************* Read inputs ************************
+ input_tensor = inputs[0]
+ dft_length = inputs[1]
+
+ # ************************* Parse inputs ***********************
+ t1 = ["float16", "float32", "float64"]
+ t2 = ["int32", "int64"]
+
+ # input
+ assert infer_type(input_tensor).checked_type.dtype in t1
+ input_shape = infer_shape(input_tensor)
+ assert len(input_shape) >= 3
+ if axis < 0:
+ axis = len(input_shape) - axis
+ assert 1 <= axis <= len(input_shape) - 1
Review Comment:
Add a message to this assertion, e.g. "Axis is out of bounds".
##########
python/tvm/topi/cuda/stft.py:
##########
@@ -133,3 +133,98 @@ def gen_ir(
name="stft_cuda",
tag="stft_cuda",
)
+
+
+def dft(
+ re_data: te.Tensor,
+ im_data: te.Tensor,
+ inverse: tir.IntImm,
+):
+ """
+ Computes the discrete Fourier transform of input (calculation along the
last axis).
+ This gives frequency components of the signal as they change over time.
+
+ Parameters
+ ----------
+ re_data : relay.Expr
+ N-D tensor, real part of the input signal.
+
+ im_data : relay.Expr
+ N-D tensor, imaginary part of the input signal.
+ If the signal is real, then the values of this tensor are zeros.
+
+ inverse : bool
+ Whether to perform the inverse discrete fourier transform.
+
+ Returns
+ -------
+ re_output : relay.Expr
+ The Fourier Transform of the input (Real part).
+ im_output : relay.Expr
+ The Fourier Transform of the input (Imaginary part).
+ """
+
+ def gen_ir(
+ re_data_buf,
+ im_data_buf,
+ re_output_buf,
+ im_output_buf,
+ ):
+ ib = tir.ir_builder.create()
+ re_data_ptr = ib.buffer_ptr(re_data_buf)
+ im_data_ptr = ib.buffer_ptr(im_data_buf)
+ re_output_ptr = ib.buffer_ptr(re_output_buf)
+ im_output_ptr = ib.buffer_ptr(im_output_buf)
+
+ shape = re_data.shape
+ n_fft = shape[len(shape) - 1]
+ base_range = 1
+ for i in range(len(shape) - 1):
+ base_range *= shape[i]
+
+ sign = -1 if inverse else 1
+ factor = 1.0 / n_fft if inverse else 1.0
+
+ max_threads = _get_max_threads(base_range)
+ with ib.new_scope():
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(base_range, max_threads)
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+ tid = bx * max_threads + tx
+ with ib.if_scope(tid < base_range):
+ base_idx = tid * n_fft
+ with ib.for_range(0, n_fft) as n:
+ n_idx = base_idx + n
+ re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0)
+ im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0)
+ with ib.for_range(0, n_fft) as k:
+ k_idx = base_idx + k
+ w = sign * -2 * pi * k * n / n_fft
Review Comment:
The same applies to the generic implementation.
##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -4877,6 +4877,122 @@ def _impl_v1(cls, inputs, attr, params):
return mm_out
+class DFT(OnnxOpConverter):
+ """Operator converter for discrete Fourier transform (DFT)."""
+
+ @classmethod
+ def _impl_v17(cls, inputs, attr, params):
+ # ************************* Read attrs *************************
+ axis = attr.get("axis")
+ inverse = attr.get("inverse")
+ onesided = attr.get("onesided")
+
+ # ************************* Read inputs ************************
+ input_tensor = inputs[0]
+ dft_length = inputs[1]
+
+ # ************************* Parse inputs ***********************
+ t1 = ["float16", "float32", "float64"]
+ t2 = ["int32", "int64"]
+
+ # input
+ assert infer_type(input_tensor).checked_type.dtype in t1
+ input_shape = infer_shape(input_tensor)
+ assert len(input_shape) >= 3
+ if axis < 0:
+ axis = len(input_shape) - axis
+ assert 1 <= axis <= len(input_shape) - 1
+
+ # dft_length
+ if dft_length is None:
+ dft_length = input_shape[axis]
+ else:
+ dft_length_dtype = infer_type(dft_length).checked_type.dtype
+ assert dft_length_dtype in t2
+ dft_length = int(infer_value(dft_length, params).numpy())
+
+ # ************************
+ input_tensor = cls._maybe_crop_or_pad(input_tensor, axis, dft_length)
+
+ swap_axis = -1
+ re_input_tensor, im_input_tensor =
cls._split_real_and_imag_parts(input_tensor)
+
+ re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis)
+ im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis)
+
+ re_input_tensor, im_input_tensor = _op.dft(re_input_tensor,
im_input_tensor, inverse)
+
+ re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis)
+ im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis)
+
+ if onesided:
+ re_input_tensor = cls._crop_onesided(re_input_tensor, axis)
+ im_input_tensor = cls._crop_onesided(im_input_tensor, axis)
+
+ output = cls._merge_real_and_imag_parts(re_input_tensor,
im_input_tensor)
+
+ return output
Review Comment:
The `output` variable is not necessary here; the result of the call can be returned directly.
##########
python/tvm/topi/cuda/stft.py:
##########
@@ -133,3 +133,98 @@ def gen_ir(
name="stft_cuda",
tag="stft_cuda",
)
+
+
+def dft(
+ re_data: te.Tensor,
+ im_data: te.Tensor,
+ inverse: tir.IntImm,
+):
+ """
+ Computes the discrete Fourier transform of input (calculation along the
last axis).
+ This gives frequency components of the signal as they change over time.
+
+ Parameters
+ ----------
+ re_data : relay.Expr
+ N-D tensor, real part of the input signal.
+
+ im_data : relay.Expr
+ N-D tensor, imaginary part of the input signal.
+ If the signal is real, then the values of this tensor are zeros.
+
+ inverse : bool
+ Whether to perform the inverse discrete fourier transform.
+
+ Returns
+ -------
+ re_output : relay.Expr
+ The Fourier Transform of the input (Real part).
+ im_output : relay.Expr
+ The Fourier Transform of the input (Imaginary part).
+ """
+
+ def gen_ir(
+ re_data_buf,
+ im_data_buf,
+ re_output_buf,
+ im_output_buf,
+ ):
+ ib = tir.ir_builder.create()
+ re_data_ptr = ib.buffer_ptr(re_data_buf)
+ im_data_ptr = ib.buffer_ptr(im_data_buf)
+ re_output_ptr = ib.buffer_ptr(re_output_buf)
+ im_output_ptr = ib.buffer_ptr(im_output_buf)
+
+ shape = re_data.shape
+ n_fft = shape[len(shape) - 1]
+ base_range = 1
+ for i in range(len(shape) - 1):
+ base_range *= shape[i]
+
+ sign = -1 if inverse else 1
+ factor = 1.0 / n_fft if inverse else 1.0
+
+ max_threads = _get_max_threads(base_range)
+ with ib.new_scope():
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(base_range, max_threads)
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+ tid = bx * max_threads + tx
+ with ib.if_scope(tid < base_range):
+ base_idx = tid * n_fft
+ with ib.for_range(0, n_fft) as n:
+ n_idx = base_idx + n
+ re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0)
+ im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0)
+ with ib.for_range(0, n_fft) as k:
+ k_idx = base_idx + k
+ w = sign * -2 * pi * k * n / n_fft
Review Comment:
It looks like the ratio `sign * -2 * pi / n_fft` could be calculated once,
before the loops.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]