masahi commented on code in PR #12318:
URL: https://github.com/apache/tvm/pull/12318#discussion_r956933885


##########
gallery/how_to/work_with_pytorch/using_as_torch.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Wrap Your TVMScript with PyTorch Module
+========================================
+**Author**: 
+`Yaoda Zhou <https://github.com/juda>`_,
+`Masahiro Masuda <https://github.com/masahi>`_
+
+This article is an introductory tutorial on wrapping TVMScript code with the PyTorch module.
+With the decorator `as_torch`, users can wrap TVMScript code into a PyTorch nn.Module naturally.
+"""
+
+# sphinx_gallery_start_ignore
+from tvm import testing
+
+testing.utils.install_request_hook(depth=3)
+# sphinx_gallery_end_ignore
+
+# Import PyTorch, as well as necessary libraries
+import torch
+import torch.nn.functional as F
+import torch.utils.benchmark as benchmark
+
+import tvm
+from tvm.contrib.torch import as_torch
+from tvm.script import tir as T
+
+######################################################################
+# Write your own PyTorch operator in TVMScript
+# ----------------------------------------------
+# PyTorch is a very popular machine learning framework that contains
+# optimized implementations of the most commonly used operators.
+# Nevertheless, sometimes you might want to write your own operators in PyTorch,
+# and the performance of such custom operators may not be satisfactory for your needs.
+#
+# One such example is a 1-d depthwise convolution operator.
+# Assume `in_channel` and `out_channel` are both 70,
+# the width is 80, and the kernel size is 20;
+# then the 1-d depthwise conv can be written in PyTorch in one line:
+
+in_channel = 70
+out_channel = 70
+width = 80
+kernel_size = 20
+
+
+def torch_depthwise(inputs, filters):
+    return F.conv1d(inputs, filters.view(out_channel, 1, kernel_size), groups=out_channel)
+
+
+# We can run this function as:
+
+inputs = torch.randn(in_channel, width)
+filters = torch.randn(out_channel, kernel_size)
+ret_torch = torch_depthwise(inputs, filters)
+
+# The `torch_depthwise` function, written in plain Python, would be:
+
+
+def vanilla_depthwise(input, weight):
+    ret = torch.zeros(out_channel, width - kernel_size + 1)
+    for j in range(out_channel):
+        for i in range(width - kernel_size + 1):
+            for k in range(kernel_size):
+                ret[j, i] += weight[j, k] * input[j, i + k]
+    return ret
+
+
+# Then, we plan to optimize the `depthwise` function by leveraging the power of TVM.
+# The TVM community has proposed an embedded domain-specific language in Python called TVMScript,
+# which serves as an abstraction of programs on various hardware backends.
+
+# As a concrete example, we can write such a TVMScript version of the 1-d depthwise conv code as below.
+# The computation procedure of `tvm_depthwise` corresponds to the code snippet of `vanilla_depthwise`.
+
+# In our `tvm_depthwise` function, both inputs and outputs are declared as function parameters
+# held in multi-dimensional buffers. For each buffer, the shape and data type information are required.
+# In the function body, there is a piece of syntactic sugar, `T.grid`, for writing multiple nested iterators.
+# In the body of the loop, each computation is wrapped in an additional construct named `T.block`.
+# A block is a basic unit of computation. Inside the block, we need to provide a few more pieces of information about the block axes.
+# Here, 2 spatial and 1 reduction block iterators are created and bound to the loop iterators i, j and k.
+# The computations and machine learning compilation analysis will be defined around them.
+# The last 3 lines are computation statements, including the initialization of `C[vj, vi]` and the summation along axis k.
+# Finally, we place the 2 decorators `T.prim_func` and `as_torch` above the function definition,
+# which convert the Python AST to a TVMScript AST and then to PyTorch's `nn.Module`.
+
+
+@as_torch
+@T.prim_func
+def tvm_depthwise(
+    A: T.Buffer((70, 80), "float32"),

Review Comment:
   I think the `Buffer` syntax sugar can be extended for dynamic shapes. But currently we cannot tune over dynamic shapes, so the performance will probably be slower than PT.
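   For reference, a minimal sketch of what a dynamic-shape variant could look like today with the existing `T.match_buffer` form rather than the `Buffer` sugar. The function name, the symbolic `width` variable, and the exact syntax below are illustrative assumptions on my part, not code from this PR:

   ```python
   from tvm.script import tir as T


   @T.prim_func
   def dyn_depthwise(a: T.handle, b: T.handle, c: T.handle) -> None:
       # Hypothetical sketch: the width is symbolic instead of the
       # fixed 80 used in the tutorial above.
       width = T.var("int32")
       A = T.match_buffer(a, (70, width), "float32")
       B = T.match_buffer(b, (70, 20), "float32")
       # Output width = width - kernel_size + 1 = width - 19.
       C = T.match_buffer(c, (70, width - 19), "float32")
       for j, i, k in T.grid(70, width - 19, 20):
           with T.block("depthwise"):
               vj, vi, vk = T.axis.remap("SSR", [j, i, k])
               with T.init():
                   C[vj, vi] = T.float32(0)
               C[vj, vi] += B[vj, vk] * A[vj, vi + vk]
   ```

   As noted above, without the ability to tune over the symbolic `width`, a schedule like this would likely trail PyTorch's hand-optimized conv1d.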



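   A related usage note for readers following the tutorial text above: since a TIR function takes its outputs as parameters, I would expect the wrapped module to be called in destination-passing style, roughly like the sketch below (the exact call shape is my assumption, not quoted from the PR):

   ```python
   import torch

   # Hypothetical usage sketch: allocate the output tensor,
   # then pass it in as the last argument.
   inputs = torch.randn(70, 80)
   filters = torch.randn(70, 20)
   ret_tvm = torch.zeros(70, 61)  # 61 = 80 - 20 + 1
   tvm_depthwise(inputs, filters, ret_tvm)
   ```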