hogepodge commented on a change in pull request #7642:
URL: https://github.com/apache/tvm/pull/7642#discussion_r600034455



##########
File path: tutorials/get_started/tensor_expr_get_started.py
##########
@@ -302,18 +385,452 @@
     fadd_cl(a, b, c)
     tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
-######################################################################
-# Summary
-# -------
-# This tutorial provides a walk through of TVM workflow using
-# a vector add example. The general workflow is
+################################################################################
+# .. note:: Code Specialization
+#
+#   As you may have noticed, the declarations of A, B and C all take the same
+#   shape argument, n. TVM will take advantage of this to pass only a single
+#   shape argument to the kernel, as you will find in the printed device code.
+#   This is one form of specialization.
+#
+#   On the host side, TVM automatically generates code that checks the
+#   constraints on the parameters, so passing arrays with different shapes
+#   into fadd will raise an error.
+#
+#   We can apply further specialization. For example, we can write :code:`n =
+#   tvm.runtime.convert(1024)` instead of :code:`n = te.var("n")` in the
+#   computation declaration. The generated function will then only accept
+#   vectors of length 1024.
+
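+################################################################################
+# Below is a minimal sketch of that fixed-length specialization, shown only for
+# illustration: the ``_fixed`` names are ours, and this variant is not built
+# for a device or benchmarked elsewhere in the tutorial.
+
+n_fixed = tvm.runtime.convert(1024)
+A_fixed = te.placeholder((n_fixed,), name="A")
+B_fixed = te.placeholder((n_fixed,), name="B")
+C_fixed = te.compute(A_fixed.shape, lambda i: A_fixed[i] + B_fixed[i], name="C")
+s_fixed = te.create_schedule(C_fixed.op)
+# The resulting function only accepts vectors of length 1024.
+fadd_fixed = tvm.build(s_fixed, [A_fixed, B_fixed, C_fixed], "llvm", name="fadd_fixed")
+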
+################################################################################
+# .. note:: TE Scheduling Primitives
+#
+#   TVM includes a number of different scheduling primitives:
+#
+#   - split: splits a specified axis into two axes by the defined factor.
+#   - tile: splits a computation across two axes by the defined factors.
+#   - fuse: fuses two consecutive axes of one computation.
+#   - reorder: can reorder the axes of a computation into a defined order.
+#   - bind: can bind a computation to a specific thread, useful in GPU
+#     programming.
+#   - compute_at: by default, TVM will compute tensors at the outermost level
+#     of the function, or the root. compute_at specifies that one tensor
+#     should be computed at the first axis of computation for another
+#     operator.
+#   - compute_inline: when marked inline, a computation will be expanded and
+#     then inserted at the location where the tensor is required.
+#   - compute_root: moves a computation to the outermost layer, or root, of the
+#     function. This means that stage of the computation will be fully computed
+#     before the computation moves on to the next stage.
+#
+#   A complete description of these primitives can be found in the
+#   `Schedule Primitives <https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html>`_
+#   docs page. A brief sketch applying a few of them follows this note.
+
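+################################################################################
+# Here is a rough, illustrative sketch of a few of these primitives applied to
+# a toy elementwise operation. The ``_demo`` names are ours, and this block is
+# not reused by the examples that follow.
+
+A_demo = te.placeholder((1024, 1024), name="A_demo")
+B_demo = te.compute((1024, 1024), lambda i, j: A_demo[i, j] * 2.0, name="B_demo")
+s_demo = te.create_schedule(B_demo.op)
+
+# split the row axis into an outer and an inner loop, with an inner extent of 32
+i_outer, i_inner = s_demo[B_demo].split(B_demo.op.axis[0], factor=32)
+# reorder so that the column axis sits between the two split loops
+s_demo[B_demo].reorder(i_outer, B_demo.op.axis[1], i_inner)
+# fuse the two outermost loops into a single loop
+fused = s_demo[B_demo].fuse(i_outer, B_demo.op.axis[1])
+print(tvm.lower(s_demo, [A_demo, B_demo], simple_mode=True))
+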
+################################################################################
+# Example 2: Manually Optimizing Matrix Multiplication with TE
+# ------------------------------------------------------------
+#
+# Now we will consider a second, more advanced example, demonstrating how with
+# just 18 lines of Python code TVM speeds up a common matrix multiplication
+# operation by 18x.
+#
+# **Matrix multiplication is a compute intensive operation. There are two
+# important optimizations for good CPU performance:**
+#
+# 1. Increase the cache hit rate of memory access. Both complex numerical
+#    computation and hot-spot memory access can be accelerated by a high cache
+#    hit rate. This requires us to transform the original memory access pattern
+#    to a pattern that fits the cache policy.
+# 2. SIMD (single instruction, multiple data), also known as the vector
+#    processing unit. Instead of processing a single value on each cycle, SIMD
+#    can process a small batch of data. This requires us to transform the data
+#    access pattern in the loop body into a uniform pattern so that the LLVM
+#    backend can lower it to SIMD.
+#
+# The techniques used in this tutorial are a subset of the tricks mentioned in
+# this `repository <https://github.com/flame/how-to-optimize-gemm>`_. Some of
+# them are applied automatically by the TVM abstraction, but some cannot be
+# applied automatically due to TVM constraints.
+#
+# All of the experiment results mentioned below were obtained on a 2015 15"
+# MacBook with an Intel i7-4770HQ CPU. The cache line size should be 64 bytes
+# for all x86 CPUs.
+
+################################################################################
+# Preparation and Performance Baseline
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We begin by collecting performance data on the `numpy` implementation of
+# matrix multiplication.
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy
+import timeit
+
+# The size of the matrix
+# (M, K) x (K, N)
+# You are free to try out different shapes; sometimes TVM optimization can
+# outperform numpy with MKL.
+M = 1024
+K = 1024
+N = 1024
+
+# The default tensor data type in tvm
+dtype = "float32"
+
+# To get the best performance from SIMD instructions such as Intel AVX2
+# (Advanced Vector Extensions), change the following line to
+# ``llvm -mcpu=core-avx2``, or to the specific type of CPU you use. If you're
+# using llvm, you can get the CPU type from the command ``llc --version``, and
+# you can check ``/proc/cpuinfo`` for additional extensions that your
+# processor might support.
+target = "llvm"
+ctx = tvm.context(target, 0)
+
+# Randomly generated tensors for testing
+a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), ctx)
+b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), ctx)
+
+# Repeatedly perform a matrix multiplication to get a performance baseline
+# for the default numpy implementation
+np_repeat = 100
+np_running_time = timeit.timeit(
+    setup="import numpy\n"
+    "M = " + str(M) + "\n"
+    "K = " + str(K) + "\n"
+    "N = " + str(N) + "\n"
+    'dtype = "float32"\n'
+    "a = numpy.random.rand(M, K).astype(dtype)\n"
+    "b = numpy.random.rand(K, N).astype(dtype)\n",
+    stmt="answer = numpy.dot(a, b)",
+    number=np_repeat,
+)
+print("Numpy running time: %f" % (np_runing_time / np_repeat))
+
+answer = numpy.dot(a.asnumpy(), b.asnumpy())
+
+################################################################################
+# Now we write a basic matrix multiplication using TVM TE and verify that it
+# produces the same results as the numpy implementation. We also write a
+# function that will help us measure the performance of the schedule
+# optimizations.
+
+# TVM Matrix Multiplication using TE
+k = te.reduce_axis((0, K), "k")
+A = te.placeholder((M, K), name="A")
+B = te.placeholder((K, N), name="B")
+C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+
+# Default schedule
+s = te.create_schedule(C.op)
+func = tvm.build(s, [A, B, C], target=target, name="mmult")
+
+c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
+func(a, b, c)
+tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
+
+
+def evaluate_operation(s, vars, target, name, optimization, log):
+    func = tvm.build(s, vars, target=target, name=name)
+    assert func
+
+    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)
+    func(a, b, c)
+    tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
+
+    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
+    mean_time = evaluator(a, b, c).mean
+    print("%s: %f" % (optimization, mean_time))
+    log.append((optimization, mean_time))
+
+
+log = []
+
+evaluate_operation(s, [A, B, C], target=target, name="mmult", optimization="none", log=log)
+
+################################################################################
+# Let's take a look at the intermediate representation of the operator and
+# default schedule using the TVM lower function. Note that it is essentially a
+# naive implementation of matrix multiplication, using three nested loops over
+# the indices of the A and B matrices.
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# Optimization 1: Blocking
+# ~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# An important trick to enhance the cache hit rate is blocking, where you
+# structure memory access such that the accesses inside a block form a small
+# neighborhood with high memory locality. In this tutorial, we pick a block
+# factor of 32, so that a block fills a 32 * 32 * sizeof(float), or 4 KB, area
+# of memory, comfortably within a typical 32 KB L1 cache.
+#
+# We begin by creating a default schedule for the ``C`` operation, then apply a
+# ``tile`` scheduling primitive to it with the specified block factor, with the
+# scheduling primitive returning the resulting loop order from outermost to
+# innermost, as a vector ``[x_outer, y_outer, x_inner, y_inner]``. We then get
+# the reduction axis for the output of the operation, and perform a split operation
+# on it using a factor of 4. This factor doesn't directly impact the blocking
+# optimization we're working on right now, but will be useful later when we
+# apply vectorization.
+#
+# Now that the operation has been blocked, we can reorder the computation to
+# put the reduction operation into the outermost loop of the computation,
+# helping to guarantee that the blocked data remains in cache. This completes
+# the schedule, and we can build and test the performance compared to the naive
+# schedule.
+
+bn = 32
+s = te.create_schedule(C.op)
+
+# Blocking by loop tiling
+xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(k,) = s[C].op.reduce_axis
+ko, ki = s[C].split(k, factor=4)
+
+# Hoist reduction domain outside the blocking loop
+s[C].reorder(xo, yo, ko, ki, xi, yi)
+
+evaluate_operation(s, [A, B, C], target=target, name="mmult", optimization="blocking", log=log)
+
+################################################################################
+# By reordering the computation to take advantage of caching, you should see a
+# significant improvement in the performance of the computation. Now, print the
+# intermediate representation and compare it to the original:
+
+print(tvm.lower(s, [A, B, C], simple_mode=True))
+
+################################################################################
+# Optimization 2: Vectorization
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Another important optimization trick is vectorization. When the memory access
+# pattern is uniform, the compiler can detect this pattern and pass the
+# contiguous memory to the SIMD vector processor. In TVM, we can use the
+# ``vectorize`` interface to hint this pattern to the compiler, taking
+# advantage of this hardware feature.
+#
+# In this tutorial, we choose to vectorize the inner loop row data since it is
+# already cache friendly from our previous optimizations.
+
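+################################################################################
+# As a quick sketch, the hint itself is a single scheduling call on the
+# innermost axis; here it reuses the ``yi`` axis produced by the ``tile`` call
+# in the blocking schedule above, before the schedule is recreated from scratch
+# below.
+
+s[C].vectorize(yi)
+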
+# Begin by applying the previous optimizations again
+s = te.create_schedule(C.op)

Review comment:
       I think in some instances we need to recreate the schedule because the 
operations are reordered. I can update to not create a new schedule every time, 
and add a comment about incremental improvements.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to