tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860314700


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: Is the operator memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound. i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimation of the number of byte loaded not
+        counting multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module>
+
+    params : Dict[str, nd.NDArray]
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could uses loaded bytes, but that accounts for for L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   You've convinced me, I'll switch it over.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to