tkonolige commented on code in PR #12205:
URL: https://github.com/apache/tvm/pull/12205#discussion_r933410603


##########
tests/python/unittest/test_roofline.py:
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+from io import StringIO
+import csv
+import os
+import json
+import platform
+
+import tvm.testing
+import tvm.utils
+from tvm.runtime import profiler_vm
+from tvm import relay
+from tvm.relay.testing import mlp
+from tvm.contrib.debugger import debug_executor
+from tvm import rpc
+from tvm.contrib import utils
+from tvm.runtime.profiling import Report
+from tvm.script import tir as T
+
+
[email protected]_targets("llvm", "cuda")
+def test_estimate_peak_flops(target, dev):
+    server = rpc.Server(key="roofline_flops")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops")
+    dev = remote.device(target)
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    target = tvm.target.Target(target)
+    with target:
+        flops = tvm.utils.roofline.registry.estimate_peak_flops(target, dev, remote)
+    if str(target.kind) == "llvm":
+        # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per 
cycle on a 1GHz cpu.
+        assert (
+            flops > 10**9 and flops < 10**14
+        ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
+    elif str(target.kind) == "cuda":
+        # should be able to hit a TFLOP/s with tensor cores
+        assert (
+            flops > 10**12 and flops < 10**14
+        ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}"
+    else:
+        raise RuntimeError("Unsupported target " + str(target))
+
+
[email protected]_if_32bit(reason="Cannot allocate enough memory on i386")
[email protected]_targets("llvm", "cuda")
+def test_estimate_peak_bandwidth(target, dev):
+    server = rpc.Server(key="roofline_bandwidth")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth")
+    dev = remote.device(target)
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    target = tvm.target.Target(target)
+    with target:
+        bandwidth = 
tvm.utils.roofline.registry.estimate_peak_bandwidth(target, dev, remote)
+    if str(target.kind) == "llvm":
+        # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
+        # GB/s, so this should leave enough wiggle room.
+        assert (
+            bandwidth > 10**9 and bandwidth < 10**12
+        ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+    elif str(target.kind) == "cuda":
+        # should be able to hit a 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
+        # it is really old.
+        assert (
+            bandwidth > 10**11 and bandwidth < 10**13
+        ), f"Bandwidth should be between 10^11 and 10^13, but it is {bandwidth}"
+    else:
+        raise RuntimeError("Unsupported target " + str(target))
+
+
[email protected]_if_32bit(reason="Cannot allocate enough memory on i386")
[email protected]_targets("llvm -mattr=+fma+avx2", "cuda")
+def test_roofline_analysis(target, dev):
+    a = relay.var("a", relay.TensorType((512, 512), "float32"))
+    b = relay.var("b", relay.TensorType((512, 512), "float32"))
+    c = relay.nn.dense(a, b)
+    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+    params = {}
+
+    server = rpc.Server(key="roofline")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline")
+    dev = remote.device(target)
+
+    report = tvm.utils.roofline_analysis(mod, params, target, dev, 
remote=remote)
+    print(report)

Review Comment:
   This print has been pretty helpful in debugging testing issues. I'd like to 
leave it in.



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimation of peak flops and memory bandwidth for cuda devices"""
+from typing import Optional
+from ...script import tir as T
+from ... import nd, build, transform
+from ...runtime import Device
+from ...target import Target
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from . import registry
+from ...contrib import utils
+
+
[email protected]_peak_flops.register("cuda")
+def estimate_peak_flops_tensorcore(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    mat_dtype: str = "float16",
+    acc_dtype: str = "float32",
+) -> float:
+    """Estimate the peak FLOP/s of a cuda device with tensorcores.
+
+    This estimate should only be used to compare with operators that can use
+    dense tensorcore mma instructions.
+
+    References
+    ----------
+    Wei Sun, Ang Li, Tong Geng, Sander Stuijk, Henk Corporaal: "Dissecting
+    Tensor Cores via Microbenchmarks: Latency, Throughput and Numerical
+    Behaviors", 2022; http://arxiv.org/abs/2206.02874
+    
https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+    mat_dtype : str
+        Dtype of matrices passed to mma instructions.
+    acc_dtype : str
+        Dtype of accumulator to use with mma instructions. Should be compatible
+        with `mat_dtype`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        mma instructions. Addition and multiplications are each counted as
+        separate FLOPs.
+    """
+    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
+
+    @T.prim_func
+    def peak_flops_tensorcore_tir(
+        inp: T.Buffer((16, 16), mat_dtype),
+        out: T.Buffer((16, 16), acc_dtype),
+        n: T.int32,
+        sms: T.int32,
+    ):
+        # pylint: disable=invalid-name, missing-function-docstring
+        A = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_a")
+        B = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_b")
+        C = T.alloc_buffer((16, 16), dtype=acc_dtype, scope="wmma.accumulator")
+        for _ in T.thread_binding(sms, thread="blockIdx.x"):
+            for _ in T.thread_binding(
+                8, thread="threadIdx.y"
+            ):  # need 8 warps to get enough in-SM parallelism
+                for _ in T.thread_binding(32, thread="threadIdx.x"):
+                    T.evaluate(
+                        T.tvm_load_matrix_sync(
+                            A.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=mat_dtype),
+                                inp.data,
+                                0,
+                                16,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    for _ in range(n):
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C.data, 0, A.data, 0, B.data, 0, C.data, 0, 
dtype="handle"
+                            )
+                        )
+                    T.evaluate(
+                        T.tvm_store_matrix_sync(
+                            C.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=acc_dtype),
+                                out.data,
+                                0,
+                                16,
+                                2,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+
+    n = 100000
+    sms = dev.multi_processor_count
+    specialized = peak_flops_tensorcore_tir.specialize(
+        {peak_flops_tensorcore_tir.params[2]: n, 
peak_flops_tensorcore_tir.params[3]: sms}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("A RPCSession must be provided when using a 
remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+
+    x = nd.empty((16, 16), dtype=mat_dtype, device=dev)
+    y = nd.empty((16, 16), dtype=acc_dtype, device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(x, y)
+    # each mma operation computes 16 x 16 x 16 FLOPs
+    return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
+
+
[email protected]_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: 
T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
+    B = T.match_buffer(b, [blocks, warp_size, 4], "float32")
+    for i in T.thread_binding(blocks, "blockIdx.x"):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.thread_binding(warp_size, "threadIdx.x"):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, j, l] += A[i, k, l, j]
+
+
[email protected]_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    blocks = 1024

Review Comment:
   Pretty much. For blocks we just need enough to make sure that enough threads 
are fetching memory at the same time.



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimation of peak flops and memory bandwidth for cuda devices"""
+from typing import Optional
+from ...script import tir as T
+from ... import nd, build, transform
+from ...runtime import Device
+from ...target import Target
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from . import registry
+from ...contrib import utils
+
+
[email protected]_peak_flops.register("cuda")
+def estimate_peak_flops_tensorcore(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    mat_dtype: str = "float16",
+    acc_dtype: str = "float32",
+) -> float:
+    """Estimate the peak FLOP/s of a cuda device with tensorcores.
+
+    This estimate should only be used to compare with operators that can use
+    dense tensorcore mma instructions.
+
+    References
+    ----------
+    Wei Sun, Ang Li, Tong Geng, Sander Stuijk, Henk Corporaal: "Dissecting
+    Tensor Cores via Microbenchmarks: Latency, Throughput and Numerical
+    Behaviors", 2022; http://arxiv.org/abs/2206.02874
+    
https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+    mat_dtype : str
+        Dtype of matrices passed to mma instructions.
+    acc_dtype : str
+        Dtype of accumulator to use with mma instructions. Should be compatible
+        with `mat_dtype`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        mma instructions. Addition and multiplications are each counted as
+        separate FLOPs.
+    """
+    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
+
+    @T.prim_func
+    def peak_flops_tensorcore_tir(
+        inp: T.Buffer((16, 16), mat_dtype),
+        out: T.Buffer((16, 16), acc_dtype),
+        n: T.int32,
+        sms: T.int32,
+    ):
+        # pylint: disable=invalid-name, missing-function-docstring
+        A = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_a")
+        B = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_b")
+        C = T.alloc_buffer((16, 16), dtype=acc_dtype, scope="wmma.accumulator")
+        for _ in T.thread_binding(sms, thread="blockIdx.x"):
+            for _ in T.thread_binding(
+                8, thread="threadIdx.y"
+            ):  # need 8 warps to get enough in-SM parallelism
+                for _ in T.thread_binding(32, thread="threadIdx.x"):
+                    T.evaluate(
+                        T.tvm_load_matrix_sync(
+                            A.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=mat_dtype),
+                                inp.data,
+                                0,
+                                16,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    for _ in range(n):
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C.data, 0, A.data, 0, B.data, 0, C.data, 0, 
dtype="handle"
+                            )
+                        )
+                    T.evaluate(
+                        T.tvm_store_matrix_sync(
+                            C.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=acc_dtype),
+                                out.data,
+                                0,
+                                16,
+                                2,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+
+    n = 100000
+    sms = dev.multi_processor_count
+    specialized = peak_flops_tensorcore_tir.specialize(
+        {peak_flops_tensorcore_tir.params[2]: n, 
peak_flops_tensorcore_tir.params[3]: sms}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("A RPCSession must be provided when using a 
remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+
+    x = nd.empty((16, 16), dtype=mat_dtype, device=dev)
+    y = nd.empty((16, 16), dtype=acc_dtype, device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(x, y)
+    # each mma operation computes 16 x 16 x 16 FLOPs
+    return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
+
+
[email protected]_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: 
T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
+    B = T.match_buffer(b, [blocks, warp_size, 4], "float32")
+    for i in T.thread_binding(blocks, "blockIdx.x"):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.thread_binding(warp_size, "threadIdx.x"):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, j, l] += A[i, k, l, j]
+
+
[email protected]_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    blocks = 1024

Review Comment:
   yep!



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -0,0 +1,232 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimation of peak flops and memory bandwidth for cuda devices"""
+from typing import Optional
+from ...script import tir as T
+from ... import nd, build, transform
+from ...runtime import Device
+from ...target import Target
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from . import registry
+from ...contrib import utils
+
+
[email protected]_peak_flops.register("cuda")
+def estimate_peak_flops_tensorcore(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    mat_dtype: str = "float16",
+    acc_dtype: str = "float32",
+) -> float:
+    """Estimate the peak FLOP/s of a cuda device with tensorcores.
+
+    This estimate should only be used to compare with operators that can use
+    dense tensorcore mma instructions.
+
+    References
+    ----------
+    Wei Sun, Ang Li, Tong Geng, Sander Stuijk, Henk Corporaal: "Dissecting
+    Tensor Cores via Microbenchmarks: Latency, Throughput and Numerical
+    Behaviors", 2022; http://arxiv.org/abs/2206.02874
+    
https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+    mat_dtype : str
+        Dtype of matrices passed to mma instructions.
+    acc_dtype : str
+        Dtype of accumulator to use with mma instructions. Should be compatible
+        with `mat_dtype`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        mma instructions. Addition and multiplications are each counted as
+        separate FLOPs.
+    """
+    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
+
+    @T.prim_func
+    def peak_flops_tensorcore_tir(
+        inp: T.Buffer((16, 16), mat_dtype),
+        out: T.Buffer((16, 16), acc_dtype),
+        n: T.int32,
+        sms: T.int32,
+    ):
+        # pylint: disable=invalid-name, missing-function-docstring
+        A = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_a")
+        B = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_b")
+        C = T.alloc_buffer((16, 16), dtype=acc_dtype, scope="wmma.accumulator")
+        for _ in T.thread_binding(sms, thread="blockIdx.x"):
+            for _ in T.thread_binding(
+                8, thread="threadIdx.y"
+            ):  # need 8 warps to get enough in-SM parallelism
+                for _ in T.thread_binding(32, thread="threadIdx.x"):
+                    T.evaluate(
+                        T.tvm_load_matrix_sync(
+                            A.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=mat_dtype),
+                                inp.data,
+                                0,
+                                16,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, 0, 0, 
dtype="handle"))
+                    for _ in range(n):
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C.data, 0, A.data, 0, B.data, 0, C.data, 0, 
dtype="handle"
+                            )
+                        )
+                    T.evaluate(
+                        T.tvm_store_matrix_sync(
+                            C.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=acc_dtype),
+                                out.data,
+                                0,
+                                16,
+                                2,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+
+    n = 100000
+    sms = dev.multi_processor_count
+    specialized = peak_flops_tensorcore_tir.specialize(
+        {peak_flops_tensorcore_tir.params[2]: n, 
peak_flops_tensorcore_tir.params[3]: sms}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("A RPCSession must be provided when using a 
remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+
+    x = nd.empty((16, 16), dtype=mat_dtype, device=dev)
+    y = nd.empty((16, 16), dtype=acc_dtype, device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(x, y)
+    # each mma operation computes 16 x 16 x 16 FLOPs
+    return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
+
+
[email protected]_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: 
T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
+    B = T.match_buffer(b, [blocks, warp_size, 4], "float32")
+    for i in T.thread_binding(blocks, "blockIdx.x"):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.thread_binding(warp_size, "threadIdx.x"):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, j, l] += A[i, k, l, j]

Review Comment:
   `B[i, l, j] += A[i, k, l, j]` is definitely better. I've switched to that. 
The funny thing is the compiler is smart enough to see through `B[i, j, l] += 
A[i, k, l, j]` and changes it to be just as good.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to