This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dd25fad [CUDA] Improve local_response_norm schedule (#8946)
dd25fad is described below
commit dd25fad6cb9a9fb4d4b2800143ead53b947f4ba2
Author: masahi <[email protected]>
AuthorDate: Wed Sep 8 04:37:50 2021 +0900
[CUDA] Improve local_response_norm schedule (#8946)
* Improve cuda lrn schedule
* fuse reduction and the next elemwise kernel
* remove cpp schedule
* fix
* fixed unintended revert
Co-authored-by: masa <[email protected]>
---
include/tvm/topi/cuda/normalization.h | 75 -------------------------------
include/tvm/topi/nn/local_response_norm.h | 27 +++++++----
include/tvm/topi/rocm/normalization.h | 46 -------------------
python/tvm/relay/op/strategy/rocm.py | 7 ---
python/tvm/topi/cuda/nn.py | 21 ++++++++-
python/tvm/topi/rocm/__init__.py | 1 -
python/tvm/topi/rocm/nn.py | 24 ----------
src/topi/schedule.cc | 10 -----
8 files changed, 37 insertions(+), 174 deletions(-)
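In short, this commit drops the C++ TOPI schedules for LRN on CUDA and rocm in favor of a Python TE schedule that inlines the padding and computes the sqr_sum reduction at the thread axis of the fused elementwise output, so the reduction and the elementwise ops that follow it run as a single kernel. Below is a minimal, self-contained sketch of that fusion pattern; the shapes, names, and thread count are illustrative, not taken from this commit:

import tvm
from tvm import te

# Hypothetical sizes for illustration only.
n, m, c = 64, 64, 128
data = te.placeholder((n, m, c), name="data")
k = te.reduce_axis((0, c), name="k")
# A reduction followed by an elementwise op, mirroring LRN's
# sqr_sum -> pow -> divide chain.
red = te.compute((n, m), lambda i, j: te.sum(data[i, j, k] * data[i, j, k], axis=k), name="red")
out = te.compute((n, m), lambda i, j: red[i, j] + 1.0, name="out")

s = te.create_schedule(out.op)
fused = s[out].fuse(*s[out].op.axis)
bx, tx = s[out].split(fused, factor=64)
s[out].bind(bx, te.thread_axis("blockIdx.x"))
s[out].bind(tx, te.thread_axis("threadIdx.x"))
# compute_at fuses the reduction into the elementwise kernel: each
# thread owns one output element and runs its reduction in-register,
# so no separate reduction kernel (or cross-thread rfactor) is needed.
s[red].compute_at(s[out], tx)
print(tvm.lower(s, [data, out], simple_mode=True))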
diff --git a/include/tvm/topi/cuda/normalization.h b/include/tvm/topi/cuda/normalization.h
deleted file mode 100644
index 270b6af..0000000
--- a/include/tvm/topi/cuda/normalization.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cuda/normalization.h
- * \brief CUDA schedule for LRN and l2 normalization operations
- */
-#ifndef TVM_TOPI_CUDA_NORMALIZATION_H_
-#define TVM_TOPI_CUDA_NORMALIZATION_H_
-
-#include <tvm/target/generic_func.h>
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/topi/tags.h>
-
-namespace tvm {
-namespace topi {
-
-using namespace tvm::te;
-namespace cuda {
-/*!
- * \brief Create a CUDA schedule for LRN
- * \param outs The output tensors.
- * \return A schedule for the given ops.
- */
-inline Schedule schedule_lrn(const Array<Tensor>& outs) {
- Array<Operation> out_ops;
- for (auto t : outs) {
- out_ops.push_back(t->op);
- }
- Schedule s = create_schedule(out_ops);
- int num_thread = 64;
- IterVar block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
- IterVar thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
- Tensor lrn = outs[0];
- Tensor sqr_sum_up = lrn->op->InputTensors()[1];
- Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
- Tensor set_pad = sqr_sum->op->InputTensors()[0];
- s[set_pad].bind(set_pad->op.as<ComputeOpNode>()->axis[0], block_x);
- IterVar rxk = sqr_sum->op.as<ComputeOpNode>()->reduce_axis[0];
- IterVar xko, xki;
- s[sqr_sum].split(rxk, num_thread, &xko, &xki);
- Tensor srf = s.rfactor(sqr_sum, xki)[0];
- s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->axis[0], block_x);
- s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
- s[srf].compute_at(s[sqr_sum], s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0]);
- s[sqr_sum_up].bind(sqr_sum_up->op.as<ComputeOpNode>()->axis[0], block_x);
- IterVar xto, xti;
- s[lrn].split_by_nparts(lrn->op.as<ComputeOpNode>()->axis[1], num_thread, &xto, &xti);
- s[lrn].bind(lrn->op.as<ComputeOpNode>()->axis[0], block_x);
- s[lrn].bind(xto, thread_x);
-
- return s;
-}
-
-} // namespace cuda
-} // namespace topi
-} // namespace tvm
-#endif // TVM_TOPI_CUDA_NORMALIZATION_H_
diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h
index 717adb8..c826ec0 100644
--- a/include/tvm/topi/nn/local_response_norm.h
+++ b/include/tvm/topi/nn/local_response_norm.h
@@ -64,17 +64,26 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00
auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs");
Tensor sqr_sum;
if (axis == 1) {
- sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) {
- return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs});
- });
+ sqr_sum = tvm::te::compute(
+ input_shape,
+ [&](Var i, Var l, Var j, Var k) {
+ return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs});
+ },
+ "tensor", "sqr_sum");
} else if (axis == 3) {
- sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) {
- return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs});
- });
+ sqr_sum = tvm::te::compute(
+ input_shape,
+ [&](Var i, Var l, Var j, Var k) {
+ return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs});
+ },
+ "tensor", "sqr_sum");
}
- auto sqrt_sum_up = tvm::te::compute(input_shape, [&](Var i, Var j, Var k, Var l) {
- return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta);
- });
+ auto sqrt_sum_up = tvm::te::compute(
+ input_shape,
+ [&](Var i, Var j, Var k, Var l) {
+ return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta);
+ },
+ "tensor", kElementWise);
return topi::divide(data, sqrt_sum_up);
}
} // namespace nn
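The names and tags added above are what the new Python schedule keys on: traverse_inline walks the ops reachable from the output, inlines injective/elementwise stages, and lets a callback match the reduction by its "sqr_sum" tag. A small sketch of that tag-based matching, assuming a hypothetical NCHW input (topi.nn.lrn is the Python wrapper over the C++ compute shown in this hunk):

from tvm import te, topi

data = te.placeholder((1, 8, 16, 16), name="data")  # hypothetical NCHW input
out = topi.nn.lrn(data, size=3)

s = te.create_schedule(out.op)

def show(op):
    # Visited ops include the "sqr_sum" reduction tagged in this hunk.
    print(op.name, "tag =", op.tag)

# traverse_inline inlines injective stages as a side effect and invokes
# the callback once per visited op.
topi.utils.traverse_inline(s, out.op, show)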
diff --git a/include/tvm/topi/rocm/normalization.h b/include/tvm/topi/rocm/normalization.h
deleted file mode 100644
index 2fbb880..0000000
--- a/include/tvm/topi/rocm/normalization.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file rocm/normalization.h
- * \brief rocm schedule for LRN and l2 normalization operations
- */
-#ifndef TVM_TOPI_ROCM_NORMALIZATION_H_
-#define TVM_TOPI_ROCM_NORMALIZATION_H_
-
-#include <tvm/target/generic_func.h>
-#include <tvm/te/operation.h>
-#include <tvm/topi/tags.h>
-
-namespace tvm {
-namespace topi {
-
-using namespace tvm::te;
-namespace rocm {
-/*!
- * \brief Create a rocm schedule for LRN
- * \param outs The output tensors.
- * \return A schedule for the given ops.
- */
-inline Schedule schedule_lrn(const Array<Tensor>& outs) { return topi::cuda::schedule_lrn(outs); }
-
-} // namespace rocm
-} // namespace topi
-} // namespace tvm
-#endif // TVM_TOPI_ROCM_NORMALIZATION_H_
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 64373dc..8d9c28b 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -27,13 +27,6 @@ from .. import op as _op
from .cuda import judge_winograd, naive_schedule
-@schedule_lrn.register("rocm")
-def schedule_lrn_rocm(attrs, outs, target):
- """schedule LRN for rocm"""
- with target:
- return topi.rocm.schedule_lrn(outs)
-
-
@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
diff --git a/python/tvm/topi/cuda/nn.py b/python/tvm/topi/cuda/nn.py
index 0de3777..e29bb44 100644
--- a/python/tvm/topi/cuda/nn.py
+++ b/python/tvm/topi/cuda/nn.py
@@ -18,7 +18,9 @@
"""scheduler functions for cuda backend"""
from __future__ import absolute_import as _abs
-from .. import cpp
+import tvm
+from tvm import te
+from ..utils import traverse_inline
def schedule_lrn(outs):
@@ -35,4 +37,19 @@ def schedule_lrn(outs):
sch: Schedule
The computation schedule for the op.
"""
- return cpp.cuda.schedule_lrn(outs)
+ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+ s = te.create_schedule([x.op for x in outs])
+ max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+ def _callback(op):
+ if "sqr_sum" in op.tag:
+ pad = op.input_tensors[0]
+ s[pad].compute_inline()
+ fused_axis = s[outs[0]].fuse(*s[outs[0]].op.axis)
+ bx, tx = s[outs[0]].split(fused_axis, factor=max_threads)
+ s[outs[0]].bind(bx, te.thread_axis("blockIdx.x"))
+ s[outs[0]].bind(tx, te.thread_axis("threadIdx.x"))
+ s[op].compute_at(s[outs[0]], tx)
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
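With the C++ path removed, schedule_lrn above is an ordinary TE schedule and can be exercised directly; rocm targets now fall through to the generic GPU strategy registration (rocm targets carry the "gpu" key) and pick up this same schedule. A hypothetical end-to-end use, assuming a CUDA-enabled build of TVM (the input shape is illustrative):

import tvm
from tvm import te, topi

data = te.placeholder((1, 8, 16, 16), name="data")  # hypothetical NCHW input
out = topi.nn.lrn(data, size=3)

target = tvm.target.Target("cuda")
with target:
    # The schedule queries max_num_threads from the current target,
    # hence the `with target:` scope.
    s = topi.cuda.schedule_lrn(out)
func = tvm.build(s, [data, out], target)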
diff --git a/python/tvm/topi/rocm/__init__.py b/python/tvm/topi/rocm/__init__.py
index 1ea4c79..f61039a 100644
--- a/python/tvm/topi/rocm/__init__.py
+++ b/python/tvm/topi/rocm/__init__.py
@@ -22,4 +22,3 @@ from __future__ import absolute_import as _abs
from .batch_matmul import *
from .conv2d import *
from .dense import *
-from .nn import *
diff --git a/python/tvm/topi/rocm/nn.py b/python/tvm/topi/rocm/nn.py
deleted file mode 100644
index c963375..0000000
--- a/python/tvm/topi/rocm/nn.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""scheduler for normalization functions on rocm backend"""
-from __future__ import absolute_import as _abs
-
-from .. import cpp
-
-
-def schedule_lrn(outs):
- return cpp.rocm.schedule_lrn(outs)
diff --git a/src/topi/schedule.cc b/src/topi/schedule.cc
index f9400bf..21f863b 100644
--- a/src/topi/schedule.cc
+++ b/src/topi/schedule.cc
@@ -29,7 +29,6 @@
#include <tvm/target/generic_func.h>
#include <tvm/topi/cuda/dense.h>
#include <tvm/topi/cuda/injective.h>
-#include <tvm/topi/cuda/normalization.h>
#include <tvm/topi/cuda/pooling.h>
#include <tvm/topi/cuda/reduction.h>
#include <tvm/topi/cuda/softmax.h>
@@ -39,7 +38,6 @@
#include <tvm/topi/generic/injective.h>
#include <tvm/topi/rocm/dense.h>
#include <tvm/topi/rocm/injective.h>
-#include <tvm/topi/rocm/normalization.h>
#include <tvm/topi/rocm/pooling.h>
#include <tvm/topi/rocm/reduction.h>
#include <tvm/topi/rocm/softmax.h>
@@ -139,10 +137,6 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMR
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args,
TVMRetValue* rv) {
- *rv = topi::rocm::schedule_lrn(args[0]);
-});
-
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args,
TVMRetValue* rv) {
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
@@ -177,10 +171,6 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMR
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
-TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args,
TVMRetValue* rv) {
- *rv = topi::cuda::schedule_lrn(args[0]);
-});
-
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape").set_body([](TVMArgs args,
TVMRetValue* rv) {
*rv = topi::detail::is_empty_shape(args[0]);