tkonolige commented on code in PR #11066: URL: https://github.com/apache/tvm/pull/11066#discussion_r860034447
########## python/tvm/utils/__init__.py: ########## @@ -0,0 +1,303 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities operating at a graph/model or other "high" level""" +import csv +import subprocess +from typing import Dict, Union, Optional +import numpy as np + +from .. 
import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform +from ..target import Target +from ..runtime import profiler_vm, profiling, Device, num_threads +from ..script import tir as T + + +def _create_args(mod, dev, func_name="main"): + args = [] + for arg in mod[func_name].params: + args.append( + nd.array( + np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype), + device=dev, + ) + ) + return args + + +def _estimated_features(mod, params, target): + comp = relay.vm.VMCompiler() + mod, params = comp.optimize(mod, params=params, target=target) + return { + prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim)) + for name, prim in mod.functions.items() + if isinstance(prim, tir.PrimFunc) + } + + +def _vec_width_registers(target, vec_width, num_vector_registers): + if vec_width is None: + if target.device_name == "": # indicates x86 + with target: + vec_width = topi.x86.utils.get_simd_32bit_lanes() # in number of float32s + else: + raise RuntimeError(f"Cannot determine vector width for target {target}") + if num_vector_registers is None: + if target.device_name == "": # indicates x86 + with target: + num_vector_registers = ( + 16 # Assuming for all platforms, probably wrong on older ones + ) + else: + raise RuntimeError(f"Cannot determine number of vector registers for target {target}") + return vec_width, num_vector_registers + + [email protected]_func +def peakflops_fma_tir( + a: T.handle, + vec_width: T.int32, + iters: T.int32, + num_vector_registers: T.int32, + threads: T.int32, +) -> None: + # pylint: disable=invalid-name, missing-function-docstring + N = T.var("int32") + A = T.match_buffer(a, [N], "float32") + assert ( + N >= threads * num_vector_registers * vec_width + ), "Input vectors must be >= num_vector_registers*vec_width" + for t in T.parallel(threads): + for _j in range(iters): + for l in T.unroll(num_vector_registers): + # We want to use as few registers as possible, so 
we perform + # all operations on the same element + for k in T.vectorized(vec_width): + A[t * vec_width * num_vector_registers + vec_width * l + k] = ( + A[t * vec_width * num_vector_registers + vec_width * l + k] + * A[t * vec_width * num_vector_registers + vec_width * l + k] + + A[t * vec_width * num_vector_registers + vec_width * l + k] + ) + + +def estimate_peak_fma_flops( + target: Target, + dev: Device, + vec_width: Optional[int] = None, + num_vector_registers: Optional[int] = None, +) -> float: + """ + Estimate the maximum number of FLOP/s this target/device combo is capable + of reaching by running a test program. This assumes vectorized f32 FMA + (fused-multiply-add) instructions. + + + Parameters + ---------- + target : Target + Target to run on. This should be as specific to the actual hardware as + possible to make sure that LLVM generates the best vector code. + dev : Device + Device to run on. + vec_width : Optional[int] + Vector width of SIMD units on the underlying hardware. Will try to + infer if no value is provided. + num_vector_registers : Optional[int] + Number of vector registers on the underlying hardware. Will try to + infer if no value is provided. + + Returns + ------- + float + Approximate sustained FLOP/s of this target/device combo assuming + vectorized f32 FMA instructions. 
+ """ + assert str(target.kind) == "llvm", "Only llvm targets are supported" + vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers) + iters = 100000 + nthreads = num_threads() + specialized = peakflops_fma_tir.specialize( + { + peakflops_fma_tir.params[1]: vec_width, + peakflops_fma_tir.params[2]: iters, + peakflops_fma_tir.params[3]: num_vector_registers, + peakflops_fma_tir.params[4]: nthreads, + } + ) + with transform.PassContext(opt_level=3): + f = build(specialized, target=target) + a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev) + times = f.time_evaluator(f.entry_name, dev, repeat=100)(a) + flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops + flop_s = flops / times.min + return flop_s + + [email protected]_func +def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None: + # pylint: disable=invalid-name, missing-function-docstring + N = T.var("int32") + A = T.match_buffer(a, [N], "float32") + B = T.match_buffer(b, [nt * vec_width * 4], "float32") + # assert N % (nt * 4 * vec_width) == 0, "div" + # Parallelism is necessary to hit all cores/nodes + for i in T.parallel(nt): + for k in T.serial(N // nt // 4 // vec_width): + for l in T.unroll(4): + # vectorized load is necessary to hit peak bandwidth + for j in T.vectorized(vec_width): + B[i * vec_width * 4 + l * vec_width + j] += A[ + i * (N // nt) + k * vec_width * 4 + l * vec_width + j + ] + + +def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float: + """Estimate peak memory bandwidth of a target/device combo. + + Peak bandwidth is estimated by running a small experiment on the underlying + hardware. The peak bandwidth measurement assumes that vector instructions + are being used to load the data. + + Parameters + ---------- + target : Target + Target to use for measurement. 
This target should be as specific to the + underlying hardware as possible. + dev : Device + Device to measure peak bandwidth on. + vec_width : Optional[int] + Vector unit width, determined from target if not supplied. + + Returns + ------- + float + Peak memory bandwidth in bytes/seconds. + """ + # Ideally we'd be able to use this code to measure peak bandwidth of the + # different cache levels. If we could just generate load commands, then we + # could use those in a tight loop. Instead we need some code that is + # limited on the cache bandwidth. With the L1 cache we need an operation + # that has a very low arithmetic intensity and we haven't come up with one + # yet. + vec_width, _ = _vec_width_registers(target, vec_width, 1) + specialized = peak_bandwidth_tir.specialize( + { + peak_bandwidth_tir.params[3]: vec_width, + } + ) + with transform.PassContext(opt_level=3): + f = build(specialized, target=target) + # Data size needs to be larger than last level of cache. We don't have a + # way of getting cache sizes, so this number should give us a large enough + # size. + size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width) Review Comment: > I'm just ramping into the discussion here but given cache size variability I would argue for the iterative approach for increasing the array size until performance plateaus. It's easy to measure and less likely to lead to false assumptions being made. I would like to have a more robust approach, but the iterative approach also suffers from needing some sort of configuration. How do we decide that we have plateaued at memory bandwidth? What if we plateaued at LLC bandwidth? We'd have to run past a certain size to make sure it wasn't LLC, so we just end up running something around the size I've set here. The best solution is to have some way of getting LLC size and then doing a multiple of it, but we don't have a way to do that right now. 
Do you have a proposal for how I might determine if the bandwidth has plateaued or not?

> It would be neat to do all the STREAM benches for the memory bandwidth measurement and combine them.

I'm not sure the STREAM benchmarks are a good fit here because they involve writing back to memory, which may not be a reflection of what is happening in a regular ML kernel. And how would we combine them? Just average? I've looked at what other tools do to estimate bandwidth/flops and most have handwritten assembly or optimized kernels (like here) that just do load or just do flops.

> Would also be interesting to apply the specific stream benchmark that applies to a given workload pattern based on the read and writes defined in a block. Probably diminishing returns given that they are usually comparable measurements of the bandwidth.

I've thought about this idea, especially with regard to cache size/type of flops, but I haven't come up with a good way to determine cache sizes (for access patterns) or type of compute (FMA or not). It is definitely something I want to do, but I was trying to get a small working PR in first and then we can make the analysis more specific.

> Personally I think it would be nice to expose an interface to change this (you can set it default 10 ** 8) in the worst case for what csullivan says above

I guess we can expose it, but how would the user know what to set it to?

--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
