tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859096015


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
[email protected]_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
[email protected]_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[

Review Comment:
   `+=` is required to create a data dependence between all the loads; otherwise LLVM could rewrite this to just load the last element in the loop (`k = N // nt // 4 // vec_width - 1`). This compute is much less than the maximum arithmetic intensity of processors, so we will be bandwidth limited (which is what we want). If you look at this comment https://github.com/apache/tvm/pull/11066/files/0e60f8d2a3fd2c234886ea4cb2df57e646fc17ef#diff-a127e78ae9dc951f53d45d15e2a905cb559a0bee347d7fc7232334f9e69a862fR189-R194, I explain in what cases this compute does matter.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to