AndrewZhaoLuo commented on code in PR #13003:
URL: https://github.com/apache/tvm/pull/13003#discussion_r995179709


##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -161,6 +165,51 @@ def peak_flops_tensorcore_tir(
     return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
 
 
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+) -> Tuple[float, float, str]:
+    """Estimate the peak FLOP/s of a cuda device.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak flops for. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    features : Dict[str, np.ndarry]
+        Features extracted from `func`. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    flops : float
+        Estimated number of flops used by `func`.
+    peak_flops : float
+        Approximate sustained FLOP/s of this target/device combo. Addition and
+        multiplications are each counted as separate FLOPs.
+    name : str
+        Dtype/intrinsic used by `func` to achieve peak flops.
+    """
+    assert nvcc.have_tensorcore(
+        dev.compute_version
+    ), "CUDA roofline only works with devices that have tensorcores"
+    flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])

Review Comment:
   Does a multiply-add (mad) count as 2 FLOPs?



##########
python/tvm/utils/roofline/x86.py:
##########
@@ -155,12 +134,72 @@ def estimate_peak_fma_flops(
         random_fill = get_global_func("tvm.contrib.random.random_fill")
     assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
 
-    a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
+    a = nd.empty((nthreads, num_vector_registers, vec_width), dtype=dtype, device=dev)
     random_fill(a)
     times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
     flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
-    flop_s = flops / times.min
-    return flop_s
+    return flops / times.min
+
+
+@registry.estimate_peak_flops.register("cpu")
+def estimate_peak_fma_flops(
+    func: PrimFunc,
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> Tuple[float, float, str]:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak flops for. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    features : Dict[str, np.ndarry]
+        Features extracted from `func`. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    flops : float
+        Estimated number of flops used by `func`.
+    peak_flops : float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized FMA instructions. Each FMA operation counts as two FLOPs.
+    name : str
+        Dtype/intrinsic used by `func` to achieve peak flops.
+    """
+    # assume that the first argument's dtype is the one we want
+    dtype = list(func.buffer_map.values())[0].dtype
+    if "int" in dtype:
+        flops = np.sum(features["int_addsub"] + features["int_mul"] + features["int_mad"])

Review Comment:
   Should we count mad as 2 FLOPs here? We seem to count FMA as 2 FLOPs above.



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -161,6 +165,51 @@ def peak_flops_tensorcore_tir(
     return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
 
 
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+) -> Tuple[float, float, str]:
+    """Estimate the peak FLOP/s of a cuda device.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak flops for. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    features : Dict[str, np.ndarry]
+        Features extracted from `func`. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    flops : float
+        Estimated number of flops used by `func`.
+    peak_flops : float
+        Approximate sustained FLOP/s of this target/device combo. Addition and
+        multiplications are each counted as separate FLOPs.
+    name : str
+        Dtype/intrinsic used by `func` to achieve peak flops.
+    """
+    assert nvcc.have_tensorcore(
+        dev.compute_version
+    ), "CUDA roofline only works with devices that have tensorcores"
+    flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])

Review Comment:
   Isn't division usually included in FLOP calculations too?



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -234,3 +259,63 @@ def estimate_peak_bandwidth(
     b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
     times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
     return a.numpy().size * 4 / times.min  # 4 bytes per float32
+
+
+@registry.estimate_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(

Review Comment:
   The naming under the new scheme is a little confusing: it is hard to tell
`estimate_peak_bandwidth_global` and `estimate_peak_bandwidth` (and the other
new functions) apart. Can you choose clearer names?



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -161,6 +165,51 @@ def peak_flops_tensorcore_tir(
     return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
 
 
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+) -> Tuple[float, float, str]:
+    """Estimate the peak FLOP/s of a cuda device.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak flops for. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    features : Dict[str, np.ndarry]
+        Features extracted from `func`. Used to check if a specific kind
+        intrinsic or dtype could be used with this function.
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    flops : float
+        Estimated number of flops used by `func`.
+    peak_flops : float
+        Approximate sustained FLOP/s of this target/device combo. Addition and
+        multiplications are each counted as separate FLOPs.
+    name : str
+        Dtype/intrinsic used by `func` to achieve peak flops.
+    """
+    assert nvcc.have_tensorcore(
+        dev.compute_version
+    ), "CUDA roofline only works with devices that have tensorcores"
+    flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+    peak_flops = estimate_peak_flops_tensorcore(target, dev, remote)

Review Comment:
   Is this ready to hook up INT8 peak FLOP calculations too? (That might be
useful for quantized models on tensor cores.)



##########
python/tvm/utils/roofline/cuda.py:
##########
@@ -234,3 +259,63 @@ def estimate_peak_bandwidth(
     b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
     times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
     return a.numpy().size * 4 / times.min  # 4 bytes per float32
+
+
+@registry.estimate_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> Tuple[float, float, str]:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak bandwidth for. Used to check if a specific
+        kind of memory could be used with this function.
+    features : Dict[str, np.ndarry]
+        Features extracted from `func`. Used to check if a specific kind of
+        memory could be used with this function.
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+      Remote session used to upload artifacts for runtime evaluation. Must be
+      the same session used to create `dev`.
+
+    Returns
+    -------
+    loaded_bytes : float
+        Estimated bytes loaded by `func`.
+    peak_bandwidth : float
+        Peak memory bandwidth in bytes/seconds.
+    name : str
+        Name of the memory being used.
+    """
+    loaded_bytes = 0.0
+    # assume no more than 100 buffers
+    for i in range(100):

Review Comment:
   Can we just iterate through all features and pick out the ones that match
the pattern, instead of assuming a fixed maximum number of buffers?



##########
python/tvm/utils/roofline/registry.py:
##########
@@ -15,18 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 """Definition of generic functions for estimating peak flops and bandwidth"""
-from typing import Optional
-from ...target import Target, generic_func
-from ...runtime import Device
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+
 from ...rpc.client import RPCSession
+from ...runtime import Device
+from ...target import Target, generic_func
+from ...tir import PrimFunc
 
 
 @generic_func
 def estimate_peak_bandwidth(
+    func: PrimFunc,

Review Comment:
   Hmm, what necessitates this interface change? It seems to me that `func`
and `features` (the added inputs) are always used to calculate the added float
and str in the return type.
   
   Why not keep the old approach and split this into two calls: one to get the
peak statistic, and one to get the actual statistics?
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to