echuraev commented on code in PR #13867:
URL: https://github.com/apache/tvm/pull/13867#discussion_r1095394301


##########
gallery/how_to/deploy_models/deploy_model_on_adreno.py:
##########
@@ -146,85 +207,24 @@
 img = np.expand_dims(img, 0)
 
 #################################################################
-# Load pretrained Pytorch model
-# -----------------------------
-# Create a Relay graph from a Pytorch ResNet-18 model
-import os
-import torch
-import torchvision
-import tvm
-from tvm import te
-from tvm import relay, rpc
-from tvm.contrib import utils, ndk
-from tvm.contrib import graph_executor
-
-model_name = "resnet18"
-model = getattr(torchvision.models, model_name)(pretrained=True)
-model = model.eval()
-
-# We grab the TorchScripted model via tracing
-input_shape = [1, 3, 224, 224]
-input_data = torch.randn(input_shape)
-scripted_model = torch.jit.trace(model, input_data).eval()
-
+# Convert PyTorch model to Relay module
+# -------------------------------------
+# TVM has frontend APIs for various frameworks under relay.frontend. For
+# PyTorch model import, we use the relay.frontend.from_pytorch API.
 # Input name can be arbitrary
 input_name = "input0"
 shape_list = [(input_name, img.shape)]
+
 mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
 
 #################################################################
 # Precisions
 # ----------
-# Since TVM support Mixed Precision, we need to register mixed_precision_conversion:
-from tvm.relay.op import register_mixed_precision_conversion
-
-conv2d_acc = "float32"
-
-
-@register_mixed_precision_conversion("nn.conv2d", level=11)
-def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
-    global conv2d_acc
-    return [
-        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
-        conv2d_acc,
-        mixed_precision_type,
-    ]
-
-
-@register_mixed_precision_conversion("nn.dense", level=11)
-def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
-    global conv2d_acc
-    return [
-        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
-        conv2d_acc,
-        mixed_precision_type,
-    ]
+from tvm.relay.op.contrib import adreno
 
+adreno.convert_to_dtype(mod["main"], dtype)

Review Comment:
   Could you please add a comment explaining this function before this call?



##########
gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py:
##########
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+.. _tutorial-deploy-model-on-adreno-tvmc:
+
+Deploy the Pretrained Model on Adreno™ with tvmc Interface
+==========================================================
+**Author**: Siva Rama Krishna
+
+This article is a step-by-step tutorial to deploy a pretrained Keras ResNet-50 model on Adreno™.
+
+Besides that, you should have TVM built for Android.
+See the following instructions on how to build it and set up the RPC environment.
+
+`Deploy to Adreno GPU <https://tvm.apache.org/docs/how_to/deploy/adreno.html>`_
+
+"""
+
+import os
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCPackage
+from tvm.contrib import utils
+
+#################################################################
+# Configuration
+# -------------
+# Specify the Adreno target before compiling to generate kernels that
+# leverage textures and get their full benefits.
+# Note: This generated example runs on our x86 server for demonstration.
+# If running it on the Android device, we need to
+# specify its instruction set. Set :code:`local_demo` to False if you want
+# to run this tutorial with a real device over RPC.
+local_demo = True
+
+# By default, the CPU target is used.
+# Select 'llvm', 'opencl', or 'opencl -device=adreno'.
+target = "llvm"
+
+# Change target configuration.
+# Run `adb shell cat /proc/cpuinfo` to find the arch.
+arch = "arm64"
+target_host = "llvm -mtriple=%s-linux-android" % arch
+
+# Auto-tuning is a compute-intensive and time-consuming task, hence it is
+# disabled for the default run. Please enable it if required.
+is_tuning = False
+tune_log = "adreno-resnet50.log"
+
+# Set to True to enable the OpenCLML accelerated operator library.
+enable_clml = False
+cross_compiler = "/opt/android-sdk-linux/ndk/21.3.6528147/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang"

Review Comment:
   I would suggest using an environment variable instead of an absolute path.
   ```suggestion
   cross_compiler = os.environ["ANDROID_NDK_HOME"] + "/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang"
   ```
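
Here is a minimal sketch of a slightly more defensive variant of this suggestion. It assumes the NDK keeps the same `toolchains/llvm/prebuilt/linux-x86_64` layout as the absolute path above. `get_cross_compiler` is a hypothetical helper, not part of TVM. It joins paths with `os.path.join`, so a trailing slash in `ANDROID_NDK_HOME` is harmless. When the variable is unset, it raises a readable error instead of a bare `KeyError`:

```python
import os

# Location of the aarch64 / API-28 clang wrapper inside the NDK, taken from the
# absolute path in the tutorial (assumed to match the installed NDK layout).
CLANG_REL_PATH = os.path.join(
    "toolchains", "llvm", "prebuilt", "linux-x86_64", "bin",
    "aarch64-linux-android28-clang",
)


def get_cross_compiler(env=None):
    """Return the NDK clang path derived from ANDROID_NDK_HOME (hypothetical helper).

    env defaults to os.environ; passing a dict makes the function easy to test.
    """
    env = os.environ if env is None else env
    ndk_home = env.get("ANDROID_NDK_HOME")
    if not ndk_home:
        # Fail early with an actionable message rather than a KeyError.
        raise RuntimeError(
            "ANDROID_NDK_HOME is not set; point it at your Android NDK installation"
        )
    return os.path.join(ndk_home, CLANG_REL_PATH)
```

In the tutorial this would then read `cross_compiler = get_cross_compiler()`.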



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general, we specify the target as ``target="opencl"`` for a regular OpenCL-based target, which generates kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
+   // body..
+
+The above OpenCL kernel definition has ``__global float*`` pointers, which are essentially OpenCL ``buffer`` objects.
+
+When texture-based enhancements are enabled by changing the target definition to ``target="opencl -device=adreno"``, the generated
+kernels use texture-backed OpenCL image objects, as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL type that represents a two-dimensional image object and provides several additional functions.
+When we use *image2d_t*, we read *4 elements at a time*, which helps utilize the hardware more efficiently.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is an SDK released by Qualcomm that provides accelerated deep learning operators.
+These operators are exposed as an extension, "cl_qcom_ml_ops", to the standard OpenCL specification.
+Please refer to `Accelerate your models with our OpenCL ML SDK <https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_ for more details.
 
-Since there are many required packages for Android, you can use the official Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on Android <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC <https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_ solution.
+OpenCLML operators can use the same context and can be enqueued on the same command queue as native OpenCL.
+We take advantage of this to avoid context-switching overheads when falling back to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section gives instructions on various ways of building and deploying a model
+to an Adreno™ target. Adreno™ is a remote target that is connected to the host via an ADB connection.
+Deploying the compiled model here requires some tools on the host as well as on the target.
+
+TVM has simplified, user-friendly command-line tools as well as a
+developer-centric Python API for various steps such as auto-tuning, building, and deploying.
+
+The TVM compilation process for remote devices has multiple stages, listed below.
+
+**Model import:**
+At this stage, we import a model from well-known frameworks such as TensorFlow, PyTorch, and ONNX.
+This stage converts the given model into TVM's Relay module format. Alternatively, one can build a Relay module manually
+by using TVM's operator inventory. The TVM module generated here is a target-independent representation of the graph.
+
+**Auto Tuning:**
+At this stage, we tune the TVM-generated kernels for a specific target. The auto-tuning process requires
+the target device to be available; for a remote target such as Adreno™ on an Android device, we use an RPC setup for communication.
+Later sections in this guide detail the RPC setup for Android devices. Auto-tuning is not a required step for
+compiling a model, but it is necessary for achieving the best performance from TVM-generated kernels.
+
+**Compilation:**
+At this stage, we compile the model for a specific target. If the module was auto-tuned in the previous stage,
+TVM compilation uses the tuning log to generate the best-performing kernels. The TVM compilation process produces artifacts
+containing the kernel shared library, the graph definition in JSON format, and a parameters binary file in a TVM-specific format.
+
+**Deploy (or test run) on Target:**
+At this stage, we run the TVM compilation output on the target. Deployment is possible from a Python
+environment using the RPC setup, and also using TVM's native tool, which is a native binary cross-compiled for Android.
+At this stage, we can run the compiled model on the Android target and test the output for correctness and performance.
+
+**Application Integration:**
+This stage is about integrating the TVM-compiled model into applications. Here we discuss
+interfacing with the TVM runtime from Android (a native C++ environment or JNI) for setting inputs and getting outputs.
+
+**Advanced Usage:**
+This section covers advanced topics such as viewing generated source code, altering the precision of the module, etc.
+
+
+This tutorial covers all the above aspects in the following sections.
+
+- :ref:`Development environment<development_environment>`
+- :ref:`RPC Setup<rpc_setup>`
+- :ref:`Commandline tools<commandline_interface>`
+- :ref:`Python interface<python_interface>`
+- :ref:`Application Integration<application_integration>`
+- :ref:`Advanced Usage<advanced_usage>`
+
+.. _development_environment:
+
+
+Development Environment Setup : Automatic
+-----------------------------------------
+TVM ships a predefined Docker container environment with all prerequisites to get started quickly.
+You may also refer to :ref:`Manual Environment Setup<manual_setup>` for more control over the dependencies.
+
+For the Docker setup, the only prerequisite is having Docker available on the host.
+
+The commands below build a Docker image for Adreno.
+
+::
 
-**Prerequisites**: Android NDK and Android Debug Bridge must
-be installed, the desired device must have OpenCL support and Android part of 
TVM must be built:
+   ./docker/build.sh ci_adreno
+   docker tag tvm.ci_adreno ci_adreno
+
+
+Now we can build both the host and target utilities with the command below.
+
+::
+
+   ./tests/scripts/ci.py adreno -i
+
+To build TVM with the OpenCLML SDK, we need to export the OpenCLML SDK path as shown below while building.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   ./tests/scripts/ci.py adreno -i
+
+On successful compilation, this leaves us in a Docker shell. The build produces two folders:
+
+* build-adreno:  The host side TVM compiler build.
+* build-adreno-target : Contains the android target components
+
+    * libtvm_runtime.so : TVM runtime library
+    * tvm_rpc : The rpc runtime environment tool
+    * rtvm : A native standalone tool
+
+While using the Docker environment, the Android device is shared with the host. Hence, it is required
+to have adb version "1.0.41" on the host, as the Docker image uses the same version.
+
+We can also check the availability of adb devices inside the Docker environment.
+
+::
+
+   user@ci-adreno-fpeqs:~$ adb devices
+   List of devices attached
+   aaaabbbb    device
+   ccccdddd    device
+
+.. _manual_setup:
+
+Development Environment Setup : Manual
+--------------------------------------
+
+The manual build process requires building both host and target components.
+
+The commands below configure the build of the host compiler.
+
+::
+
+   mkdir -p build
+   cd build
+   cp ../cmake/config.cmake .
+
+   echo set\(USE_OPENCL ON\) >> config.cmake

Review Comment:
   In fact, it is not a required part of host compilation. Although I usually use this flag for host compilation too, I believe that the host part can be built w/o `USE_OPENCL`.
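
If that holds, the host configuration could treat OpenCL as optional. The sketch below is hypothetical: `host_config_lines` is not part of TVM. It reproduces, as strings, the `set(...)` lines that the `echo ... >> config.cmake` commands in the diff append. `USE_OPENCL` is off by default, following the comment above, which is still an unverified assumption:

```python
def cmake_set(name, value):
    """Render one CMake setting, e.g. set(USE_RPC ON), as the echo commands write it."""
    return "set({} {})".format(name, value)


def host_config_lines(use_opencl=False, clml_sdk=None):
    """Return the config.cmake lines for the host (compiler) build.

    use_opencl: emit USE_OPENCL ON (assumed optional for the host build).
    clml_sdk: optional OpenCLML SDK path, emitted as USE_CLML.
    """
    options = []
    if use_opencl:
        options.append(("USE_OPENCL", "ON"))
    # Flags taken verbatim from the host configuration in the diff.
    options += [
        ("USE_RPC", "ON"),
        ("USE_GRAPH_EXECUTOR", "ON"),
        ("USE_LIBBACKTRACE", "AUTO"),
        ("USE_LLVM", "ON"),
    ]
    if clml_sdk:
        options.append(("USE_CLML", clml_sdk))
    return [cmake_set(name, value) for name, value in options]


if __name__ == "__main__":
    # Print the default host configuration, one entry per line.
    print("\n".join(host_config_lines()))
```

Appending these lines to `build/config.cmake` would be equivalent to the sequence of `echo` commands.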



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_LLVM ON\) >> config.cmake
+
+Additionally, we can add the config entry below to compile with OpenCLML support.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake
+
+Now we can build as shown below:
+
+::
+
+   cmake ..
+   make
+
+Finally, we can export the Python path:
+
+::
+
+   export PYTHONPATH=$PWD/../python:${PYTHONPATH}
+   python3 -c "import tvm" # Verify tvm python package
+
+
+Now we can configure and build the target components with the configuration below.
+The target build requires the Android NDK to be installed.
 
 - Read documentation about *Android NDK installation* here: https://developer.android.com/ndk
 - To get access to adb tools you can see *Android Debug Bridge installation* here: https://developer.android.com/studio/command-line/adb
 
-You can also build the android part of TVM locally. From the root
-folder of TVM:
 
 ::
 
-   mkdir build_android
-   cd build_android
-   cmake .. -DUSE_OPENCL=ON -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=android-28 -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON -DANDROID_STL=c++_static -DUSE_CPP_RPC=ON
-   make -jN tvm_runtime tvm_rpc
+   mkdir -p build-adreno
+   cd build-adreno
+   cp ../cmake/config.cmake .
+   echo set\(USE_MICRO OFF\) >> config.cmake

Review Comment:
   I believe that after #13503 this flag is redundant.



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance 
boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general we specify target as ``target="opencl"`` for a regular OpenCL based 
target which generates the kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* 
restrict p0, __global double* restrict p1, __global float* restrict 
conv2d_nhwc) {
+   // body..
+
+Above OpenCL kernel definition has ``__global float*`` poniters which are 
essestially OpenCL ``buffer``  objects.
+
+When enabled texture based enhancements by modifying target definition as 
``target="opencl -device=adreno"`` we can see the generated
+kernels using texture backed OpenCL image objects as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t 
pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL types that represents two-dimensional image 
object and provides several additional functions.
+When we use *image2d_t* we read *4 elements at one time*, and it helps to 
utilize hardware in a more efficient way.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about 
generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is a SDK released by Qualcomm that provides accelerated deep learning 
operators.
+These operators are exposed as an extension "cl_qcom_ml_ops" to standard 
OpenCL specification.
+Please refer `Accelerate your models with our OpenCL ML SDK 
<https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_
 for more details.
 
-Since there are many required packages for Android, you can use the official 
Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on 
Android 
<https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC 
<https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_
 solution.
+OpenCLML operators can use same context and can be enqueued on same command 
queue as used in native OpenCL.
+We took advantage of this to avoid any context switching over heads while 
fallback to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section gives instructions about various ways of building and deploying 
model
+to Adreno™ target. Adreno™ is a remote target which is connected to the host 
via ADB connection.
+Deploying the compiled model here require use some tools on host as well as on 
target.
+
+TVM has simplified user friendly command line based tools as well as
+developer centric python API interface for various steps like auto tuning, 
building and deploying.
+
+TVM compilation process for remote devices has multiple stages listed below.
+
+**Model import:**
+At this stage we import a model from well known frameworks like Tensorflow, 
PyTorch, ONNX ...etc.
+This stage converts the given model into TVM's relay module format. 
Alternatively one can build a relay module manually
+by using TVM's operator inventory too. TVM module generated here is a target 
independent representation of the graph.
+
+**Auto Tuning:**
+At this stage we tune the TVM generated kernels specific to a target. Auto 
tuning process requires
+target device availability and in case of a remote target like Adreno™ on 
Android device we use RPC Setup for communication.
+Later sections in this guide will detail about RPC Setup for Android device. 
Auto tuning is not a necessary step for
+compilation of a model. It is necessary for acheiving best performance out of 
TVM generated kernels.
+
+**Compilation:**
+At this stage we compile the model for specific target. Given we auto tuned 
the module in previous stage,
+TVM compilation make use of the tuning log for genetrating best performing 
kernels. TVM compilation process produces artifacts
+containing kernel shared lib, graph definition in json format and parameters 
binary file in TVM specific format.
+
+**Deploy (or test run) on Target:**
+At this stage we run the TVM compilation output on the target. Deployment is 
possible from python
+environment using RPC Setup and also using TVM's native tool which is native 
binary cross compiled for Android.
+At this stage we can run the compiled model on Android target and unit test 
output correctness and performance aspects.
+
+**Aplication Integration:**
+This stage is all about integrating TVM compiled model in applications. Here 
we discuss about
+interfacing tvm runtime from Android (cpp native environment or from JNI) for 
setting input and getting output.
+
+**Advanced Usage:**
+This section advanced user interests like viewing generated source code, 
altering precision of the module ...etc.
+
+
+This tutorial covers all the above aspects as part of below sections.
+
+- :ref:`Development environment<development_environment>`
+- :ref:`RPC Setup<rpc_setup>`
+- :ref:`Commandline tools<commandline_interface>`
+- :ref:`Python interface<python_interface>`
+- :ref:`Application Integration<application_integration>`
+- :ref:`Advanced Usage<advanced_usage>`
+
+.. _development_environment:
+
+
+Development Environment Setup : Automatic
+-----------------------------------------
+TVM ships a predefined docker container environment with all prerequisites to 
get started quickly.
+You may also refer to :ref:`Manual Environment Setup<manual_setup>` for more 
control on the dependencies.
+
+For docker setup the pre requisite is just docker tool availabilty on host.
+
+Below commands can build a docker image for adreno.
+
+::
 
-**Prerequisites**: Android NDK and Android Debug Bridge must
-be installed, the desired device must have OpenCL support and Android part of 
TVM must be built:
+   ./docker/build.sh ci_adreno
+   docker tag tvm.ci_adreno ci_adreno
+
+
+Now we can build both host and target utils with below command.
+
+::
+
+   ./tests/scripts/ci.py adreno -i
+
+To build TVM with OpenCLML SDK we need export the OpenCLML SDK as shown below 
while building
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   ./tests/scripts/ci.py adreno -i
+
+On successful compilation this leaves us into a docker shell. The build leaves 
two folders
+
+* build-adreno:  The host side TVM compiler build.
+* build-adreno-target : Contains the android target components
+
+    * libtvm_runtime.so : TVM runtime library
+    * tvm_rpc : The rpc runtime environment tool
+    * rtvm : A native stand alone tool
+
+While using docker environment the android device is shared with host. Hence, 
it is required
+to have adb version "1.0.41" on the host as the docker used the same version.
+
+We can check adb devices availability inside docker environment too.
+
+::
+
+   user@ci-adreno-fpeqs:~$ adb devices
+   List of devices attached
+   aaaabbbb    device
+   ccccdddd    device
+
+.. _manual_setup:
+
+Development Environment Setup : Manual
+--------------------------------------
+
+Manual build process require building of host and target components.
+
+Below command will configure the build the host compiler
+
+::
+
+   mkdir -p build
+   cd build
+   cp ../cmake/config.cmake .
+
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_LLVM ON\) >> config.cmake
+
+Additionally we can push below config entry to compile with OpenCLML support.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake
+
+now we can build as shown below
+
+::
+
+   cmake ..
+   make
+
+Finally we can export python path as
+
+::
+
+   export PYTHONPATH=$PWD:/python
+   python3 -c "import tvm" # Verify tvm python package
+
+
+Now, we can configure and build the target components with below configuration
+Target build require Android NDK to be installed.
 
 - Read documentation about *Android NDK installation* here: 
https://developer.android.com/ndk
 - To get access to adb tools you can see *Android Debug Bridge installation* 
here: https://developer.android.com/studio/command-line/adb
 
-You can also build the android part of TVM locally. From the root
-folder of TVM:
 
 ::
 
-   mkdir build_android
-   cd build_android
-   cmake .. -DUSE_OPENCL=ON -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=android-28 -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON -DANDROID_STL=c++_static -DUSE_CPP_RPC=ON
-   make -jN tvm_runtime tvm_rpc
+   mkdir -p build-adreno
+   cd build-adreno
+   cp ../cmake/config.cmake .
+   echo set\(USE_MICRO OFF\) >> config.cmake
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RTVM ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake
 
-where **N** is the number of cores available on your *CPU*.
+   echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake
+   echo set\(ANDROID_PLATFORM android-28\) >> config.cmake
+   echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake
 
-At this stage you have built TVM for Adreno.
+Additionally, we can add the config below to compile with OpenCLML support.
 
-.. _build_and_deploy_model_for_adreno:
+::
 
-Build and deploy model for Adreno
----------------------------------
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
+   echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake
 
-In this section we will focus on target, needed to compile and deploy models for Adreno, demonstrate
-the differences in generated kernels with and without textures and, in addition, the
-possibility of choosing a different precision for model compilation will
-be considered.
+The Android target build depends on ANDROID_NDK_HOME, so it must be set as an environment variable.
+The commands below build the Adreno™ target components.
 
-For the complete step-py-step process of compiling and deploying models on
-Adreno, including selection of precision, running the inference of the
-model, getting the predictions, and measuring the performance please refer to this tutorial: `How To Deploy model on Adreno <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_adreno.html>`_
+::
 
-|Android deployment pipeline|
+   cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \
+      -DANDROID_ABI=arm64-v8a \
+      -DANDROID_PLATFORM=android-28 \
+      -DCMAKE_SYSTEM_VERSION=1 \
+      -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \
+      -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+      -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+      -DCMAKE_CXX_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang++" \
+      -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \
+      -DMACHINE_NAME="aarch64-linux-gnu" ..
 
-*Fig.2 Deployment pipeline on Adreno devices*
+   make tvm_runtime tvm_rpc rtvm
 
-The figure above demonstrates a generalized pipeline for deploying and running neural network models on android devices.
-As can be seen from the figure, the compiled model has a set_input() and a run() methods,
-which *prepare the inputs* for inference and *execute the inference* on the remote device using the Graph Executor runtime module.
 
-Adreno target
-~~~~~~~~~~~~~
+.. _rpc_setup:
 
-Normally, when compiling models for Android using OpenCL, the
-corresponding target is used
+RPC Setup
+---------
 
-.. code:: python
+RPC Setup allows remote target access over a TCP/IP network interface. RPC Setup is essential for the auto tuning stage, as tuning
+involves running auto-generated kernels on the real device and optimizing them with a machine learning approach. Please refer to
+`Auto-Tune with Templates and AutoTVM <https://tvm.apache.org/docs/how_to/tune_with_autotvm/index.html>`_ for more details about AutoTVM.
 
-   target="opencl"
+RPC Setup is also useful for deploying the compiled model to a remote device from the Python interface or the ``tvmc`` tool on the host device.
 
-Using Adreno, we want to get all the benefits of textures, so we have to
-use the following target to generate texture leveraging kernels
+RPC Setup has multiple components as listed below.
 
-.. code:: python
+**TVM Tracker:**
+The TVM tracker is a host-side daemon that manages remote devices and serves them to host-side applications. Applications
+can connect to this tracker and acquire a remote device handle to communicate.
 
-   target="opencl -device=adreno"
+**TVM RPC:**
+TVM RPC is a native application that runs on the remote device (Android in our case) and registers itself with the TVM Tracker
+running on the host.
 
-Let's write a simple model with one convolutional (conv2d) layer and take a look at generated kernels for these
-two targets
 
-.. code:: python
+Hence, in an RPC-based setup the above components run on the host and the target device. The sections below explain how to set them up
+manually and also inside docker using automated tools.
 
-   import tvm
-   from tvm import relay
-   import numpy as np
+**Automated RPC Setup:**
+Here we explain how to set up RPC in the docker environment.
 
-   input_shape=(1, 56, 56, 32)
-   filter_shape=(3, 3, 32, 64)
-   filter = np.random.rand(*filter_shape)
+The command below launches the tracker in the docker environment, listening on port 9120.
 
-   dtype="float32"
-   input = tvm.relay.var("input", shape=input_shape, dtype=dtype)
-   weight = tvm.relay.var("weight", shape=filter_shape, dtype=dtype)
-   D = relay.nn.conv2d(input, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="HWIO", out_dtype=dtype)
+::
 
-   mod = relay.Function([input, weight], D)
-   params = {
-      "weight": tvm.nd.array(filter)
-   }
+   ./tests/scripts/ci.py adreno -i # Launch a new shell on the adreno docker
+   source  tests/scripts/setup-adreno-env.sh -e tracker -p 9120
 
-Now compile our model with the classic OpenCL target and print its modules:
+Now, the command below runs TVM RPC on the remote Android device with id "abcdefgh".
 
-.. code:: python
 
-   target="opencl"
+::
 
-   with tvm.transform.PassContext(opt_level=3):
-      graph, lib, params = relay.build_module.build(mod, target, params=params)
-   print(lib.imported_modules[0].get_source())
+   ./tests/scripts/ci.py adreno -i # Launch a new shell on adreno docker.
+   source  tests/scripts/setup-adreno-env.sh -e device -p 9120 -d abcdefgh
 
-Notice that the generated convolution kernel has pointers in
-the initialization of the function. The kernels generated with the above target are buffer-based.
 
-.. code:: c
+**Manual RPC Setup:**
 
-   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
-   // body..
+In the manual setup, the command below starts the tracker on port 9120.
+
+::
+
+   python3 -m tvm.exec.rpc_tracker --host "0.0.0.0" --port "9120"
+
+Launching TVM RPC on an Android device requires some environment setup, because the device is connected via the ADB interface and
+TCP/IP communication must be re-routed over ADB. The commands below do the necessary setup and run tvm_rpc on the remote device.
+
+::
+
+    # Set android device to use
+    export ANDROID_SERIAL=abcdefgh
+    # Create a temporary folder on remote device.
+    adb shell "mkdir -p /data/local/tmp/tvm_test"
+    # Copy tvm_rpc and its dependency to remote device
+    adb push build-adreno-target/tvm_rpc /data/local/tmp/tvm_test/tvm_rpc
+    adb push build-adreno-target/libtvm_runtime.so /data/local/tmp/tvm_test
+    # Forward port 9120 from target to host
+    adb reverse tcp:9120 tcp:9120
+    # tvm_rpc by default listens on ports starting from 5000 for incoming connections.
+    # Hence, reroute connections to these ports on host to remote device.
+    adb forward tcp:5000 tcp:5000
+    adb forward tcp:5001 tcp:5001
+    adb forward tcp:5002 tcp:5002
+    # Finally launch tvm_rpc on remote device with identity key as "android"
+    adb shell "cd /data/local/tmp/tvm_test; killall -9 tvm_rpc; sleep 2; LD_LIBRARY_PATH=/data/local/tmp/tvm_test/ ./tvm_rpc server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:9120 --key=android"
+
+Once tvm_rpc is running, the remote device is available on the tracker, which can be queried as shown below.
+
+::
+
+   python3 -m tvm.exec.query_rpc_tracker --port 9120
+   Tracker address 127.0.0.1:9120
+   Server List
+   ------------------------------
+   server-address           key
+   ------------------------------
+       127.0.0.1:5000    server:android
+   ------------------------------
+
+   Queue Status
+   -------------------------------
+   key       total  free  pending
+   -------------------------------
+   android   1      1     0
+   -------------------------------
+
+This concludes the RPC Setup; the rpc-tracker is now available on host 127.0.0.1 (rpc-tracker) and port 9120 (rpc-port).
+
+
+.. _commandline_interface:
+
+Commandline Tools

Review Comment:
   Probably we can move this part to the `deploy_model_on_adreno_tvmc.py` and just keep here a brief description and link to the `deploy_model_on_adreno_tvmc.py`?
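The manual RPC steps in the hunk above route two kinds of traffic over ADB. First, `adb reverse` lets `tvm_rpc` on the device reach the host tracker at 127.0.0.1. Second, one `adb forward` per RPC port lets the host reach the server on the device. The following is a minimal sketch that generates those commands, assuming the tracker is on port 9120 and `tvm_rpc` binds ports starting at 5000, as in the hunk. `rpc_adb_forward_cmds` is a hypothetical helper and is not part of TVM.

```python
def rpc_adb_forward_cmds(tracker_port=9120, rpc_port_start=5000, rpc_port_end=5002):
    """Build the adb commands that tunnel TVM RPC traffic (hypothetical helper).

    - `adb reverse`: connections to tracker_port on the device are sent to the
      same port on the host, where the rpc_tracker is listening.
    - `adb forward`: connections to each RPC port on the host are sent to the
      same port on the device, where tvm_rpc is listening.
    """
    cmds = [["adb", "reverse", f"tcp:{tracker_port}", f"tcp:{tracker_port}"]]
    for port in range(rpc_port_start, rpc_port_end + 1):
        cmds.append(["adb", "forward", f"tcp:{port}", f"tcp:{port}"])
    return cmds


if __name__ == "__main__":
    # Print the commands; each list could also be passed to subprocess.run(cmd, check=True).
    for cmd in rpc_adb_forward_cmds():
        print(" ".join(cmd))
```

Keeping the tracker port in one parameter helps avoid mismatches between the `adb reverse` ports and the `--tracker` argument passed to `tvm_rpc`.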



##########
gallery/how_to/deploy_models/deploy_model_on_adreno.py:
##########
@@ -233,46 +233,96 @@ def convert_to_dtype(mod, dtype):
 # You can also use "float16" or "float32" precisions as other dtype options.
 
 #################################################################
-# Compile the model with relay
-# ----------------------------
-# Specify Adreno target before compiling to generate texture
-# leveraging kernels and get all the benefits of textures
-# Note: This generated example running on our x86 server for demonstration.
-# If running it on the Android device, we need to
-# specify its instruction set. Set :code:`local_demo` to False if you want
-# to run this tutorial with a real device.
+# Prepare TVM Target
+# ------------------
 
-local_demo = True
+if local_demo:
+    target = tvm.target.Target("llvm")
+elif "opencl" in test_target:
+    target = tvm.target.Target(test_target, host=target)
 
-# by default on CPU target will execute.
-# select 'cpu', 'opencl' and 'vulkan'
-test_target = "cpu"
+##################################################################
+# AutoTuning
+# ----------
+# The instructions below auto tune the relay module, using XGBoost as the tuner algorithm.
 
-# Change target configuration.
-# Run `adb shell cat /proc/cpuinfo` to find the arch.
-arch = "arm64"
-target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch)
+# The auto tuning process involves extracting the tasks, defining the tuning configuration and
+# tuning each task for best performing kernel configuration.
 
-if local_demo:
-    target = tvm.target.Target("llvm")
-elif test_target == "opencl":
-    target = tvm.target.Target("opencl", host=target)
-elif test_target == "vulkan":
-    target = tvm.target.Target("vulkan", host=target)
+# Get RPC related settings.
+rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
+rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+key = "android"
+
+if is_tuning:
+    # Auto Tuning Stage 1: Extract tunable tasks
+    tasks = autotvm.task.extract_from_program(
+        mod, target=test_target, target_host=target, params=params
+    )
+
+    # Auto Tuning Stage 2: Define tuning configuration
+    tmp_log_file = tune_log + ".tmp"
+    measure_option = autotvm.measure_option(
+        builder=autotvm.LocalBuilder(
+            build_func=ndk.create_shared, timeout=15
+        ),  # Build the test kernel locally
+        runner=autotvm.RPCRunner(  # The runner would be on a remote device.
+            key,  # RPC Key
+            host=rpc_tracker_host,  # Tracker host
+            port=int(rpc_tracker_port),  # Tracker port
+            number=3,  # Number of runs before averaging
+            timeout=600,  # RPC Timeout
+        ),
+    )
+    n_trial = 1024  # Number of tuning trials before choosing the best kernel config
+    early_stopping = False  # Whether to stop tuning early when results stop improving
+
+    # Iterate through each task and call the tuner
+    from tvm.autotvm.tuner import XGBTuner
+
+    for i, tsk in enumerate(reversed(tasks[:3])):
+        print("Task:", tsk)
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+        tuner_obj = XGBTuner(tsk, loss_type="rank")
+
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
+    # Pick the best performing kernel configurations from the overall log.

Review Comment:
   Here I marked stage `N` because I think there is at least one more stage before stage 2 and this step. Could you also please add this prefix to other comments and mark AutoTVM steps by this prefix?
   ```suggestion
       # Auto Tuning Stage N: Pick the best performing configurations from the overall log.
   ```
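The **Prepare TVM Target** branch in the hunk above can be checked in isolation with plain strings. The following is a minimal sketch that does not import TVM. `select_target` is a hypothetical stand-in for the `tvm.target.Target` selection. The default host string comes from the previous version of the tutorial (`llvm -mtriple=arm64-linux-android`). The sketch assumes that the non-OpenCL case keeps that Android CPU target, as the earlier code did.

```python
DEFAULT_ANDROID_HOST = "llvm -mtriple=arm64-linux-android"


def select_target(test_target, local_demo, host=DEFAULT_ANDROID_HOST):
    """Return (target, target_host) strings for a given configuration.

    Hypothetical stand-in for the tutorial's tvm.target.Target selection.
    """
    if local_demo:
        # Local demo: compile for the x86 machine running the script.
        return "llvm", None
    if "opencl" in test_target:
        # OpenCL kernels (optionally "-device=adreno") with an Android CPU host.
        return test_target, host
    # Otherwise run on the Android CPU only.
    return host, None


if __name__ == "__main__":
    print(select_target("opencl -device=adreno", local_demo=False))
```

The sketch uses a membership test rather than `str.find`. `"opencl -device=adreno".find("opencl")` returns `0`, which is falsy, while a missing substring returns `-1`, which is truthy. A condition based on `find` therefore inverts the intended check.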



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general, we specify the target as ``target="opencl"`` for a regular OpenCL-based target, which generates kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
+   // body..
+
+The OpenCL kernel definition above has ``__global float*`` pointers, which are essentially OpenCL ``buffer`` objects.
+
+When texture-based enhancements are enabled by modifying the target definition to ``target="opencl -device=adreno"``, we can see the generated
+kernels using texture-backed OpenCL image objects, as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL type that represents a two-dimensional image object and provides several additional functions.
+When we use *image2d_t*, we read *4 elements at a time*, which helps utilize the hardware more efficiently.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is an SDK released by Qualcomm that provides accelerated deep learning operators.
+These operators are exposed as an extension, "cl_qcom_ml_ops", to the standard OpenCL specification.
+Please refer to `Accelerate your models with our OpenCL ML SDK <https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_ for more details.
 
-Since there are many required packages for Android, you can use the official Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on Android <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC <https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_ solution.
+OpenCLML operators can use the same context and can be enqueued on the same command queue as native OpenCL.
+We took advantage of this to avoid any context-switching overhead when falling back to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section describes various ways of building and deploying a model
+to the Adreno™ target. Adreno™ is a remote target that is connected to the host via an ADB connection.
+Deploying the compiled model here requires some tools on the host as well as on the target.
+
+TVM has simplified, user-friendly command-line tools as well as a
+developer-centric Python API for various steps such as auto tuning, building and deploying.
+
+The TVM compilation process for remote devices has multiple stages, listed below.
+
+**Model import:**
+At this stage we import a model from well-known frameworks like TensorFlow, PyTorch, ONNX, etc.
+This stage converts the given model into TVM's relay module format. Alternatively, one can build a relay module manually
+by using TVM's operator inventory. The TVM module generated here is a target-independent representation of the graph.
+
+**Auto Tuning:**
+At this stage we tune the TVM-generated kernels for a specific target. The auto tuning process requires
+the target device to be available, and for a remote target like Adreno™ on an Android device we use the RPC Setup for communication.
+Later sections in this guide detail the RPC Setup for Android devices. Auto tuning is not required to
+compile a model, but it is necessary to achieve the best performance from TVM-generated kernels.
+
+**Compilation:**
+At this stage we compile the model for a specific target. If we auto tuned the module in the previous stage,
+TVM compilation makes use of the tuning log to generate the best performing kernels. The TVM compilation process produces artifacts
+containing the kernel shared library, the graph definition in JSON format and a parameters binary file in a TVM-specific format.
+
+**Deploy (or test run) on Target:**
+At this stage we run the TVM compilation output on the target. Deployment is possible from the Python
+environment using the RPC Setup, and also using TVM's native tool, which is a native binary cross-compiled for Android.
+At this stage we can run the compiled model on the Android target and unit test its output correctness and performance.
+
+**Aplication Integration:**

Review Comment:
   ```suggestion
   **Application Integration:**
   ```



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general, we specify the target as ``target="opencl"`` for a regular OpenCL-based target, which generates kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
+   // body..
+
+The OpenCL kernel definition above has ``__global float*`` pointers, which are essentially OpenCL ``buffer`` objects.
+
+When texture-based enhancements are enabled by modifying the target definition to ``target="opencl -device=adreno"``, we can see the generated
+kernels using texture-backed OpenCL image objects, as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL type that represents a two-dimensional image object and provides several additional functions.
+When we use *image2d_t*, we read *4 elements at a time*, which helps utilize the hardware more efficiently.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is an SDK released by Qualcomm that provides accelerated deep learning operators.
+These operators are exposed as an extension, "cl_qcom_ml_ops", to the standard OpenCL specification.
+Please refer to `Accelerate your models with our OpenCL ML SDK <https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_ for more details.
 
-Since there are many required packages for Android, you can use the official Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on Android <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC <https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_ solution.
+OpenCLML operators can use the same context and can be enqueued on the same command queue as native OpenCL.
+We took advantage of this to avoid any context-switching overhead when falling back to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section describes various ways of building and deploying a model
+to the Adreno™ target. Adreno™ is a remote target that is connected to the host via an ADB connection.
+Deploying the compiled model here requires some tools on the host as well as on the target.
+
+TVM has simplified, user-friendly command-line tools as well as a
+developer-centric Python API for various steps such as auto tuning, building and deploying.
+
+The TVM compilation process for remote devices has multiple stages, listed below.
+
+**Model import:**
+At this stage we import a model from well-known frameworks like TensorFlow, PyTorch, ONNX, etc.
+This stage converts the given model into TVM's relay module format. Alternatively, one can build a relay module manually
+by using TVM's operator inventory. The TVM module generated here is a target-independent representation of the graph.
+
+**Auto Tuning:**
+At this stage we tune the TVM-generated kernels for a specific target. The auto tuning process requires
+the target device to be available, and for a remote target like Adreno™ on an Android device we use the RPC Setup for communication.
+Later sections in this guide detail the RPC Setup for Android devices. Auto tuning is not required to
+compile a model, but it is necessary to achieve the best performance from TVM-generated kernels.
+
+**Compilation:**
+At this stage we compile the model for a specific target. If we auto tuned the module in the previous stage,
+TVM compilation makes use of the tuning log to generate the best performing kernels. The TVM compilation process produces artifacts
+containing the kernel shared library, the graph definition in JSON format and a parameters binary file in a TVM-specific format.
+
+**Deploy (or test run) on Target:**
+At this stage we run the TVM compilation output on the target. Deployment is possible from the Python
+environment using the RPC Setup, and also using TVM's native tool, which is a native binary cross-compiled for Android.
+At this stage we can run the compiled model on the Android target and unit test its output correctness and performance.
+
+**Aplication Integration:**
+This stage is about integrating the TVM-compiled model into applications. Here we discuss
+interfacing the TVM runtime from Android (native C++ environment or JNI) for setting inputs and getting outputs.
+
+**Advanced Usage:**
+This section covers advanced topics such as viewing generated source code, altering the precision of the module, etc.
+
+
+This tutorial covers all the above aspects in the sections below.
+
+- :ref:`Development environment<development_environment>`
+- :ref:`RPC Setup<rpc_setup>`
+- :ref:`Commandline tools<commandline_interface>`
+- :ref:`Python interface<python_interface>`
+- :ref:`Application Integration<application_integration>`
+- :ref:`Advanced Usage<advanced_usage>`
+
+.. _development_environment:
+
+
+Development Environment Setup : Automatic
+-----------------------------------------
+TVM ships a predefined docker container environment with all prerequisites to get started quickly.
+You may also refer to :ref:`Manual Environment Setup<manual_setup>` for more control over the dependencies.
+
+For the docker setup, the only prerequisite is the docker tool being available on the host.
+
+The commands below build a docker image for Adreno.
+
+::
 
-**Prerequisites**: Android NDK and Android Debug Bridge must
-be installed, the desired device must have OpenCL support and Android part of TVM must be built:
+   ./docker/build.sh ci_adreno
+   docker tag tvm.ci_adreno ci_adreno
+
+
+Now we can build both the host and target utilities with the command below.
+
+::
+
+   ./tests/scripts/ci.py adreno -i
+
+To build TVM with the OpenCLML SDK, we need to export the OpenCLML SDK path as shown below while building.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   ./tests/scripts/ci.py adreno -i
+
+On successful compilation, this leaves us in a docker shell. The build produces two folders:
+
+* build-adreno:  The host side TVM compiler build.
+* build-adreno-target : Contains the android target components
+
+    * libtvm_runtime.so : TVM runtime library
+    * tvm_rpc : The rpc runtime environment tool
+    * rtvm : A native stand alone tool
+
+When using the docker environment, the Android device is shared with the host. Hence, it is required
+to have adb version "1.0.41" on the host, as the docker image uses the same version.
+
+We can also check the availability of adb devices inside the docker environment.
+
+::
+
+   user@ci-adreno-fpeqs:~$ adb devices
+   List of devices attached
+   aaaabbbb    device
+   ccccdddd    device
+
+.. _manual_setup:
+
+Development Environment Setup : Manual
+--------------------------------------
+
+The manual build process requires building both the host and the target components.
+
+The commands below configure the build of the host compiler.
+
+::
+
+   mkdir -p build
+   cd build
+   cp ../cmake/config.cmake .
+
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_LLVM ON\) >> config.cmake
+
+Additionally, we can add the config entry below to compile with OpenCLML support.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake

Review Comment:
   Is it required to compile host part with `OpenCLML` support?



##########
gallery/how_to/deploy_models/deploy_model_on_adreno.py:
##########
@@ -115,6 +115,67 @@
 #    android      1      1     0
 #    ----------------------------------
 
+#################################################################
+# Configuration
+# -------------
+
+import os
+import torch
+import torchvision
+import tvm
+from tvm import te
+from tvm import relay, rpc
+from tvm.contrib import utils, ndk
+from tvm.contrib import graph_executor
+from tvm.relay.op.contrib import clml
+from tvm import autotvm
+
+# Adreno devices are efficient with float16 compared to float32

Review Comment:
   Probably it is better to define all these variables and their descriptions just before their usage. This is probably worse from the point of view of code structure and code style, but it is better for reading: it isn't necessary to return to the top of the file and revise information about these variables. IMHO, the previous structure was closer to other documentation of this type in TVM. What do you think?



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general, we specify the target as ``target="opencl"`` for a regular OpenCL-based target, which generates kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
+   // body..
+
+The OpenCL kernel definition above has ``__global float*`` pointers, which are essentially OpenCL ``buffer`` objects.
+
+When texture-based enhancements are enabled by modifying the target definition to ``target="opencl -device=adreno"``, we can see the generated
+kernels using texture-backed OpenCL image objects, as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL type that represents a two-dimensional image object and provides several additional functions.
+When we use *image2d_t*, we read *4 elements at a time*, which helps utilize the hardware more efficiently.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is an SDK released by Qualcomm that provides accelerated deep learning operators.
+These operators are exposed as an extension, "cl_qcom_ml_ops", to the standard OpenCL specification.
+Please refer to `Accelerate your models with our OpenCL ML SDK <https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_ for more details.
 
-Since there are many required packages for Android, you can use the official Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on Android <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC <https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_ solution.
+OpenCLML operators can use the same context and can be enqueued on the same command queue as native OpenCL.
+We took advantage of this to avoid any context-switching overhead when falling back to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section describes various ways of building and deploying a model
+to the Adreno™ target. Adreno™ is a remote target that is connected to the host via an ADB connection.
+Deploying the compiled model here requires some tools on the host as well as on the target.
+
+TVM has simplified, user-friendly command-line tools as well as a
+developer-centric Python API for various steps such as auto tuning, building and deploying.
+
+The TVM compilation process for remote devices has multiple stages, listed below.
+
+**Model import:**
+At this stage we import a model from well-known frameworks like TensorFlow, PyTorch, ONNX, etc.
+This stage converts the given model into TVM's relay module format. Alternatively, one can build a relay module manually
+by using TVM's operator inventory. The TVM module generated here is a target-independent representation of the graph.
+
+**Auto Tuning:**
+At this stage we tune the TVM-generated kernels for a specific target. The auto tuning process requires
+the target device to be available, and for a remote target like Adreno™ on an Android device we use the RPC Setup for communication.
+Later sections in this guide detail the RPC Setup for Android devices. Auto tuning is not required to
+compile a model, but it is necessary to achieve the best performance from TVM-generated kernels.
+
+**Compilation:**
+At this stage we compile the model for a specific target. If we auto-tuned the module in the previous stage,
+TVM compilation uses the tuning log to generate the best-performing kernels. The TVM compilation process produces artifacts
+containing the kernel shared library, the graph definition in JSON format, and a parameters binary file in a TVM-specific format.
+
+**Deploy (or test run) on Target:**
+At this stage we run the TVM compilation output on the target. Deployment is possible from a Python
+environment using the RPC setup, and also using TVM's native tool, which is a native binary cross-compiled for Android.
+At this stage we can run the compiled model on the Android target and unit-test its output correctness and performance.
+
+**Application Integration:**
+This stage is about integrating the TVM-compiled model into applications. Here we discuss
+interfacing with the TVM runtime from Android (a native C++ environment or JNI) for setting inputs and getting outputs.
+
+**Advanced Usage:**
+This section covers topics of interest to advanced users, such as viewing the generated source code, altering the precision of the module, etc.
+
+
+This tutorial covers all of the above aspects in the sections below.
+
+- :ref:`Development environment<development_environment>`
+- :ref:`RPC Setup<rpc_setup>`
+- :ref:`Commandline tools<commandline_interface>`
+- :ref:`Python interface<python_interface>`
+- :ref:`Application Integration<application_integration>`
+- :ref:`Advanced Usage<advanced_usage>`
+
+.. _development_environment:
+
+
+Development Environment Setup : Automatic
+-----------------------------------------
+TVM ships a predefined Docker container environment with all prerequisites to get started quickly.
+You may also refer to :ref:`Manual Environment Setup<manual_setup>` for more control over the dependencies.
+
+For the Docker setup, the only prerequisite is having Docker available on the host.
+
+The commands below build a Docker image for Adreno.
+
+::
 
-**Prerequisites**: Android NDK and Android Debug Bridge must
-be installed, the desired device must have OpenCL support and Android part of TVM must be built:
+   ./docker/build.sh ci_adreno
+   docker tag tvm.ci_adreno ci_adreno
+
+
+Now we can build both the host and target utilities with the command below.
+
+::
+
+   ./tests/scripts/ci.py adreno -i
+
+To build TVM with the OpenCLML SDK, we need to export the path to the OpenCLML SDK as shown below while building:
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   ./tests/scripts/ci.py adreno -i
+
+On successful compilation, this leaves us in a Docker shell. The build produces two folders:
+
+* build-adreno: The host-side TVM compiler build.
+* build-adreno-target: Contains the Android target components
+
+    * libtvm_runtime.so: TVM runtime library
+    * tvm_rpc: The RPC runtime environment tool
+    * rtvm: A native standalone tool
+
+When using the Docker environment, the Android device is shared with the host. Hence, it is required
+to have adb version "1.0.41" on the host, as the Docker image uses the same version.
+
+We can also check adb device availability inside the Docker environment.
+
+::
+
+   user@ci-adreno-fpeqs:~$ adb devices
+   List of devices attached
+   aaaabbbb    device
+   ccccdddd    device
+
+.. _manual_setup:
+
+Development Environment Setup : Manual
+--------------------------------------
+
+The manual build process requires building the host and target components.
+
+The commands below configure the build of the host compiler:
+
+::
+
+   mkdir -p build
+   cd build
+   cp ../cmake/config.cmake .
+
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_LLVM ON\) >> config.cmake
+
+Additionally, we can append the config entry below to compile with OpenCLML support.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake
+
+Now we can build as shown below:
+
+::
+
+   cmake ..
+   make
+
+Finally, we can export the Python path as:
+
+::
+
+   export PYTHONPATH=$PWD/../python:${PYTHONPATH}
+   python3 -c "import tvm" # Verify tvm python package
+
+
+Now we can configure and build the target components with the configuration below.
+The target build requires the Android NDK to be installed.
 
 - Read documentation about *Android NDK installation* here: https://developer.android.com/ndk
 - To get access to adb tools you can see *Android Debug Bridge installation* here: https://developer.android.com/studio/command-line/adb
 
-You can also build the android part of TVM locally. From the root
-folder of TVM:
 
 ::
 
-   mkdir build_android
-   cd build_android
-   cmake .. -DUSE_OPENCL=ON -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=android-28 -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON -DANDROID_STL=c++_static -DUSE_CPP_RPC=ON
-   make -jN tvm_runtime tvm_rpc
+   mkdir -p build-adreno
+   cd build-adreno
+   cp ../cmake/config.cmake .
+   echo set\(USE_MICRO OFF\) >> config.cmake
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RTVM ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake
 
-where **N** is the number of cores available on your *CPU*.
+   echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake
+   echo set\(ANDROID_PLATFORM android-28\) >> config.cmake
+   echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake
 
-At this stage you have built TVM for Adreno.
+Additionally, we can append the config below to compile with OpenCLML support.
 
-.. _build_and_deploy_model_for_adreno:
+::
 
-Build and deploy model for Adreno
----------------------------------
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
+   echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake
 
-In this section we will focus on target, needed to compile and deploy models for Adreno, demonstrate
-the differences in generated kernels with and without textures and, in addition, the
-possibility of choosing a different precision for model compilation will
-be considered.
+For the Android target build, ANDROID_NDK_HOME is a dependency and must be set as an environment variable.
+The commands below build the Adreno™ target components:
 
-For the complete step-py-step process of compiling and deploying models on
-Adreno, including selection of precision, running the inference of the
-model, getting the predictions, and measuring the performance please refer to this tutorial: `How To Deploy model on Adreno <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_adreno.html>`_
+::
 
-|Android deployment pipeline|
+   cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \
+      -DANDROID_ABI=arm64-v8a \
+      -DANDROID_PLATFORM=android-28 \
+      -DCMAKE_SYSTEM_VERSION=1 \
+      -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \
+      -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+      -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+      -DCMAKE_CXX_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang++" \
+      -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \
+      -DMACHINE_NAME="aarch64-linux-gnu" ..
 
-*Fig.2 Deployment pipeline on Adreno devices*
+   make tvm_runtime tvm_rpc rtvm
 
-The figure above demonstrates a generalized pipeline for deploying and running neural network models on android devices.
-As can be seen from the figure, the compiled model has a set_input() and a run() methods,

Review Comment:
   Let's keep this pipeline in the documentation. I mean the image with 
deployment pipeline.
   If you want to modify something in the picture, I can ask Daniil, probably 
he has original file which can be easily edited and he can share it.



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general, we specify the target as ``target="opencl"`` for a regular OpenCL-based target, which generates kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* restrict p0, __global double* restrict p1, __global float* restrict conv2d_nhwc) {
+   // body..
+
+The OpenCL kernel definition above has ``__global float*`` pointers, which are essentially OpenCL ``buffer`` objects.
+
+When texture-based enhancements are enabled by changing the target definition to ``target="opencl -device=adreno"``, the generated
+kernels use texture-backed OpenCL image objects, as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL type that represents a two-dimensional image object and provides several additional functions.
+When we use *image2d_t*, we read *4 elements at a time*, which helps to utilize the hardware more efficiently.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about the generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is an SDK released by Qualcomm that provides accelerated deep learning operators.
+Finally, we can export the Python path as:
+
+::
+
+   export PYTHONPATH=$PWD:/python

Review Comment:
   Usually, I export `PYTHONPATH` in this way: `export 
PYTHONPATH=$TVM_HOME/python:${PYTHONPATH}`
   I suppose that there is a typo here, because not sure that some modules 
exists in `/python`.
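For reference, here is a minimal shell sketch of the form suggested above. It assumes `TVM_HOME` points at the root of the TVM checkout, the directory that contains `python/`. The `$HOME/tvm` fallback is a hypothetical location. The `${PYTHONPATH:+...}` expansion appends the previous value only when it is non-empty, so no trailing empty entry is left in the path:

```shell
#!/usr/bin/env bash
# Sketch: TVM_HOME is assumed to be the root of the TVM checkout
# (the directory that contains python/). The fallback path is hypothetical.
TVM_HOME="${TVM_HOME:-$HOME/tvm}"

# Prepend the TVM Python package directory; keep any existing PYTHONPATH
# after it, adding the ":" separator only when there was a previous value.
export PYTHONPATH="${TVM_HOME}/python${PYTHONPATH:+:${PYTHONPATH}}"

echo "PYTHONPATH=${PYTHONPATH}"
# Then verify that the package resolves, as in the reviewed snippet:
#   python3 -c "import tvm"
```

Running it twice in the same shell prepends the directory twice. This is harmless for imports, but it is worth knowing in long-lived shells.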



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
+   mkdir -p build-adreno
+   cd build-adreno
+   cp ../cmake/config.cmake .
+   echo set\(USE_MICRO OFF\) >> config.cmake
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RTVM ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake

Review Comment:
   I have never used this flag before. What is the reason for using this flag here?
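As context for this question: each `echo set\(...\) >> config.cmake` line in the reviewed block appends one literal CMake `set()` call after the defaults copied from `cmake/config.cmake`. Assuming the build includes `config.cmake` as an ordinary CMake script, it is evaluated top to bottom, so the appended value is the one that takes effect. The minimal sketch below reproduces the pattern in a throwaway file so the resulting entries can be inspected before configuring a real build. The scratch directory and the `OFF` default line are illustrative only:

```shell
#!/usr/bin/env bash
# Sketch: write into a scratch file instead of a real build directory.
workdir="$(mktemp -d)"
cfg="${workdir}/config.cmake"

# A default that a copied config.cmake might already contain (illustrative).
echo 'set(USE_OPENCL OFF)' > "${cfg}"

# The backslashes keep the shell from treating the parentheses as syntax,
# so each echo appends a literal CMake set() call.
echo set\(USE_OPENCL ON\) >> "${cfg}"
echo set\(USE_KALLOC_ALIGNMENT 32\) >> "${cfg}"

# Show every set() entry with its line number; for plain set() calls at the
# same scope, the entry with the highest line number is the one CMake keeps.
grep -n '^set(' "${cfg}"
```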



##########
docs/how_to/deploy/adreno.rst:
##########
@@ -65,134 +78,483 @@ Reasons of using textures:
 Overall, with textures, it is possible to achieve a significant performance 
boost
 compared to OpenCL buffer based solutions.
 
-.. _building_tvm_for_adreno:
+In general we specify target as ``target="opencl"`` for a regular OpenCL based 
target which generates the kernels as shown below.
 
-Building TVM for Adreno
------------------------
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__global float* 
restrict p0, __global double* restrict p1, __global float* restrict 
conv2d_nhwc) {
+   // body..
+
+Above OpenCL kernel definition has ``__global float*`` poniters which are 
essestially OpenCL ``buffer``  objects.
+
+When enabled texture based enhancements by modifying target definition as 
``target="opencl -device=adreno"`` we can see the generated
+kernels using texture backed OpenCL image objects as shown below.
+
+.. code:: c
+
+   __kernel void tvmgen_default_fused_nn_conv2d_kernel0(__write_only image2d_t 
pad_temp_global_texture, __read_only image2d_t p0) {
+   // body..
+
+*image2d_t* is a built-in OpenCL types that represents two-dimensional image 
object and provides several additional functions.
+When we use *image2d_t* we read *4 elements at one time*, and it helps to 
utilize hardware in a more efficient way.
+
+Please refer to :ref:`Advanced Usage<advanced_usage>` for more details about 
generation and inspection of kernel sources.
+
+
+.. _about_openclml:
 
-This section gives instructions on how to build the Android part of TVM
-with OpenCL and TVM RPC Server in order to deploy models on Adreno.
+About OpenCLML
+--------------
 
-Since the process of building TVM for Adreno is exactly the same as the
-process of building TVM for Android, please refer to these instructions:
-`TVM RPC
-Server <https://github.com/apache/tvm/tree/main/apps/cpp_rpc>`_.
+OpenCLML is a SDK released by Qualcomm that provides accelerated deep learning 
operators.
+These operators are exposed as an extension "cl_qcom_ml_ops" to standard 
OpenCL specification.
+Please refer `Accelerate your models with our OpenCL ML SDK 
<https://developer.qualcomm.com/blog/accelerate-your-models-our-opencl-ml-sdk>`_
 for more details.
 
-Since there are many required packages for Android, you can use the official 
Docker Image to build TVM.
-For more information refer to this guide: `Deploy the Pretrained Model on 
Android 
<https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_android.html>`_.
+OpenCLML is integrated into TVM as a `BYOC 
<https://tvm.apache.org/docs/dev/how_to/relay_bring_your_own_codegen.html?highlight=bring%20your%20own>`_
 solution.
+OpenCLML operators can use same context and can be enqueued on same command 
queue as used in native OpenCL.
+We took advantage of this to avoid any context switching over heads while 
fallback to native OpenCL.
+
+
+.. _build_deploy:
+
+TVM for Adreno™
+---------------
+
+This section gives instructions about various ways of building and deploying 
model
+to Adreno™ target. Adreno™ is a remote target which is connected to the host 
via ADB connection.
+Deploying the compiled model here require use some tools on host as well as on 
target.
+
+TVM has simplified user friendly command line based tools as well as
+developer centric python API interface for various steps like auto tuning, 
building and deploying.
+
+TVM compilation process for remote devices has multiple stages listed below.
+
+**Model import:**
+At this stage we import a model from well known frameworks like Tensorflow, 
PyTorch, ONNX ...etc.
+This stage converts the given model into TVM's relay module format. 
Alternatively one can build a relay module manually
+by using TVM's operator inventory too. TVM module generated here is a target 
independent representation of the graph.
+
+**Auto Tuning:**
+At this stage we tune the TVM generated kernels specific to a target. Auto 
tuning process requires
+target device availability and in case of a remote target like Adreno™ on 
Android device we use RPC Setup for communication.
+Later sections in this guide will detail about RPC Setup for Android device. 
Auto tuning is not a necessary step for
+compilation of a model. It is necessary for acheiving best performance out of 
TVM generated kernels.
+
+**Compilation:**
+At this stage we compile the model for specific target. Given we auto tuned 
the module in previous stage,
+TVM compilation make use of the tuning log for genetrating best performing 
kernels. TVM compilation process produces artifacts
+containing kernel shared lib, graph definition in json format and parameters 
binary file in TVM specific format.
+
+**Deploy (or test run) on Target:**
+At this stage we run the TVM compilation output on the target. Deployment is possible from a Python
+environment using the RPC setup, and also using TVM's native tool, which is a native binary cross-compiled for Android.
+At this stage we can run the compiled model on the Android target and unit test the output for correctness and performance.
+
+**Application Integration:**
+This stage is all about integrating the TVM-compiled model into applications. Here we discuss
+interfacing with the TVM runtime from Android (native C++ environment or JNI) for setting inputs and getting outputs.
+
+**Advanced Usage:**
+This section covers advanced user interests like viewing the generated source code, altering the precision of the module, etc.
+
+
+This tutorial covers all the above aspects in the sections below.
+
+- :ref:`Development environment<development_environment>`
+- :ref:`RPC Setup<rpc_setup>`
+- :ref:`Commandline tools<commandline_interface>`
+- :ref:`Python interface<python_interface>`
+- :ref:`Application Integration<application_integration>`
+- :ref:`Advanced Usage<advanced_usage>`
+
+.. _development_environment:
+
+
+Development Environment Setup : Automatic
+-----------------------------------------
+TVM ships a predefined Docker container environment with all prerequisites to get started quickly.
+You may also refer to :ref:`Manual Environment Setup<manual_setup>` for more control over the dependencies.
+
+For the Docker setup, the only prerequisite is that Docker is available on the host.
+
+The commands below build a Docker image for Adreno.
+
+::
 
-**Prerequisites**: Android NDK and Android Debug Bridge must
-be installed, the desired device must have OpenCL support and Android part of TVM must be built:
+   ./docker/build.sh ci_adreno
+   docker tag tvm.ci_adreno ci_adreno
+
+
+Now we can build both the host and target utilities with the command below.
+
+::
+
+   ./tests/scripts/ci.py adreno -i
+
+To build TVM with the OpenCLML SDK, export the OpenCLML SDK path as shown below while building:
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   ./tests/scripts/ci.py adreno -i
+
+On successful compilation this leaves us in a Docker shell. The build produces two folders:
+
+* build-adreno: The host-side TVM compiler build.
+* build-adreno-target: Contains the Android target components
+
+    * libtvm_runtime.so: TVM runtime library
+    * tvm_rpc: The RPC runtime environment tool
+    * rtvm: A native stand-alone tool
+
+When using the Docker environment, the Android device is shared with the host. Hence, it is required
+to have adb version "1.0.41" on the host, as the Docker image uses the same version.
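As a minimal sketch (not part of TVM), the host adb version can be compared against the "1.0.41" stated above by parsing `adb version`, assuming its first line has the usual form `Android Debug Bridge version X.Y.Z`. The helper names are hypothetical.

```python
import re
import subprocess
from typing import Optional

# Version the text says the ci_adreno Docker image uses.
REQUIRED_ADB_VERSION = "1.0.41"


def parse_adb_version(output: str) -> Optional[str]:
    """Return the X.Y.Z version from `adb version` output, or None if absent."""
    match = re.search(r"Android Debug Bridge version (\d+\.\d+\.\d+)", output)
    return match.group(1) if match else None


def host_adb_matches(required: str = REQUIRED_ADB_VERSION) -> bool:
    """Run `adb version` on the host (adb must be on PATH) and compare versions."""
    result = subprocess.run(
        ["adb", "version"], capture_output=True, text=True, check=True
    )
    return parse_adb_version(result.stdout) == required
```

Calling `host_adb_matches()` on the host before entering the container gives an early hint of a client/server version mismatch.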
+
+We can also check device availability with adb inside the Docker environment.
+
+::
+
+   user@ci-adreno-fpeqs:~$ adb devices
+   List of devices attached
+   aaaabbbb    device
+   ccccdddd    device
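A small illustrative sketch (hypothetical helper, not a TVM utility) that extracts the ready device serials from `adb devices` output like the listing above. It assumes rows of the form `<serial> <state>` and keeps only those in the `device` state.

```python
from typing import List


def parse_adb_devices(output: str) -> List[str]:
    """Return serials of devices whose state is `device` (i.e. ready to use)."""
    serials = []
    for line in output.splitlines():
        fields = line.split()
        # Device rows are "<serial>\t<state>"; the header line, blank lines and
        # "* daemon ..." status lines do not have exactly two fields.
        if len(fields) == 2 and fields[1] == "device":
            serials.append(fields[0])
    return serials
```

For example, `parse_adb_devices(subprocess.run(["adb", "devices"], capture_output=True, text=True).stdout)` would yield the serials to pass to `adb -s <serial>` when more than one device is attached.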
+
+.. _manual_setup:
+
+Development Environment Setup : Manual
+--------------------------------------
+
+The manual build process requires building both the host and target components.
+
+The commands below configure the build of the host compiler:
+
+::
+
+   mkdir -p build
+   cd build
+   cp ../cmake/config.cmake .
+
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_LLVM ON\) >> config.cmake
+
+Additionally, we can append the config entry below to compile with OpenCLML support.
+
+::
+
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake
+
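The host configuration above can be sketched as data instead of a series of `echo` commands. This is a hypothetical helper, assuming the same five options listed above plus the optional `USE_CLML` entry; it only produces the `set(...)` lines and does not call CMake.

```python
from typing import Dict, Optional

# Host-compiler options the text appends to build/config.cmake with `echo`.
HOST_OPTIONS: Dict[str, str] = {
    "USE_OPENCL": "ON",
    "USE_RPC": "ON",
    "USE_GRAPH_EXECUTOR": "ON",
    "USE_LIBBACKTRACE": "AUTO",
    "USE_LLVM": "ON",
}


def host_config_text(clml_sdk: Optional[str] = None) -> str:
    """Build the set(...) lines for the host build.

    If clml_sdk (the ADRENO_OPENCL path) is given, USE_CLML is appended,
    mirroring the optional OpenCLML step.
    """
    options = dict(HOST_OPTIONS)
    if clml_sdk:
        options["USE_CLML"] = clml_sdk
    return "".join(f"set({key} {value})\n" for key, value in options.items())
```

Appending `host_config_text(os.environ.get("ADRENO_OPENCL"))` to `build/config.cmake` yields the same entries as the `echo` commands, while keeping the option set in one place and avoiding the escaped parentheses (`set\(`) the shell form needs.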
+Now we can build as shown below:
+
+::
+
+   cmake ..
+   make
+
+Finally, we can export the Python path as:
+
+::
+
+   export PYTHONPATH=$PWD/../python:${PYTHONPATH}  # $PWD is the build directory here
+   python3 -c "import tvm" # Verify tvm python package
+
+
+Now we can configure and build the target components with the configuration below.
+The target build requires the Android NDK to be installed.
 
 - Read documentation about *Android NDK installation* here: https://developer.android.com/ndk
 - To get access to adb tools you can see *Android Debug Bridge installation* here: https://developer.android.com/studio/command-line/adb
 
-You can also build the android part of TVM locally. From the root
-folder of TVM:
 
 ::
 
-   mkdir build_android
-   cd build_android
-   cmake .. -DUSE_OPENCL=ON -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=android-28 -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON -DANDROID_STL=c++_static -DUSE_CPP_RPC=ON
-   make -jN tvm_runtime tvm_rpc
+   mkdir -p build-adreno
+   cd build-adreno
+   cp ../cmake/config.cmake .
+   echo set\(USE_MICRO OFF\) >> config.cmake
+   echo set\(USE_OPENCL ON\) >> config.cmake
+   echo set\(USE_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RPC ON\) >> config.cmake
+   echo set\(USE_CPP_RTVM ON\) >> config.cmake
+   echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
+   echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
+   echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake
 
-where **N** is the number of cores available on your *CPU*.
+   echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake
+   echo set\(ANDROID_PLATFORM android-28\) >> config.cmake
+   echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake
 
-At this stage you have built TVM for Adreno.
+Additionally, we can append the config below to compile with OpenCLML support.
 
-.. _build_and_deploy_model_for_adreno:
+::
 
-Build and deploy model for Adreno
----------------------------------
+   export ADRENO_OPENCL=<Path to OpenCLML SDK>
+   echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
+   echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake
 
-In this section we will focus on target, needed to compile and deploy models for Adreno, demonstrate
-the differences in generated kernels with and without textures and, in addition, the
-possibility of choosing a different precision for model compilation will
-be considered.
+For the Android target build, ANDROID_NDK_HOME is a dependency and must be set as an environment variable.
+The commands below build the Adreno™ target components:
 
-For the complete step-py-step process of compiling and deploying models on
-Adreno, including selection of precision, running the inference of the
-model, getting the predictions, and measuring the performance please refer to this tutorial: `How To Deploy model on Adreno <https://tvm.apache.org/docs/how_to/deploy_models/deploy_model_on_adreno.html>`_
+::
 
-|Android deployment pipeline|
+   cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \
+      -DANDROID_ABI=arm64-v8a \
+      -DANDROID_PLATFORM=android-28 \
+      -DCMAKE_SYSTEM_VERSION=1 \
+      -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \
+      -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
+      -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
+      -DCMAKE_CXX_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang++" \
+      -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \
+      -DMACHINE_NAME="aarch64-linux-gnu" ..
 
-*Fig.2 Deployment pipeline on Adreno devices*
+   make tvm_runtime tvm_rpc rtvm
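A hedged sketch of the same target configure step as an argument list (hypothetical helper, not a TVM API). It assumes a Linux x86_64 host, as in the NDK paths above, and derives both `ANDROID_PLATFORM` and the clang wrapper names from a single API level so the two cannot drift apart.

```python
from typing import List

# Prebuilt NDK host toolchain directory used in the text (Linux x86_64 host assumed).
HOST_TAG = "linux-x86_64"


def android_cmake_args(ndk_home: str, clml_sdk: str, api_level: int = 28) -> List[str]:
    """Return the target cmake command line shown above as an argv list."""
    clang_dir = f"{ndk_home}/toolchains/llvm/prebuilt/{HOST_TAG}/bin"
    return [
        "cmake",
        f"-DCMAKE_TOOLCHAIN_FILE={ndk_home}/build/cmake/android.toolchain.cmake",
        "-DANDROID_ABI=arm64-v8a",
        f"-DANDROID_PLATFORM=android-{api_level}",
        "-DCMAKE_SYSTEM_VERSION=1",
        f"-DCMAKE_FIND_ROOT_PATH={clml_sdk}",
        "-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER",
        "-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY",
        f"-DCMAKE_CXX_COMPILER={clang_dir}/aarch64-linux-android{api_level}-clang++",
        f"-DCMAKE_C_COMPILER={clang_dir}/aarch64-linux-android{api_level}-clang",
        "-DMACHINE_NAME=aarch64-linux-gnu",
        "..",
    ]
```

It could be run as `subprocess.run(android_cmake_args(os.environ["ANDROID_NDK_HOME"], os.environ.get("ADRENO_OPENCL", "")), cwd="build-adreno", check=True)` followed by `make tvm_runtime tvm_rpc rtvm`. Passing an argv list instead of a shell string sidesteps the quoting the shell version needs.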
 
-The figure above demonstrates a generalized pipeline for deploying and running neural network models on android devices.
-As can be seen from the figure, the compiled model has a set_input() and a run() methods,
-which *prepare the inputs* for inference and *execute the inference* on the remote device using the Graph Executor runtime module.
 
-Adreno target
-~~~~~~~~~~~~~
+.. _rpc_setup:
 
-Normally, when compiling models for Android using OpenCL, the
-corresponding target is used
+RPC Setup

Review Comment:
   I think that `RPC Setup` is already described in `deploy_model_on_adreno.py`. Perhaps we can remove this section from here and move the new parts to `deploy_model_on_adreno.py`? That would help keep this document a brief overview without many details.


