This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc1324635a tanh float16 (#12165)
dc1324635a is described below
commit dc1324635a52e9cf0a8dbf535c58e60b92fe4736
Author: Aakanksha Verma <[email protected]>
AuthorDate: Mon Jul 25 02:58:09 2022 +0530
tanh float16 (#12165)
Co-authored-by: aakaverm <[email protected]>
---
python/tvm/topi/hexagon/slice_ops/__init__.py | 1 +
python/tvm/topi/hexagon/slice_ops/tanh.py | 56 +++++++++++
.../contrib/test_hexagon/topi/test_tanh_slice.py | 109 +++++++++++++++++++++
3 files changed, 166 insertions(+)
diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py
b/python/tvm/topi/hexagon/slice_ops/__init__.py
index f6c30c2500..c178aeeb0e 100644
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -32,3 +32,4 @@ from .cast import (
from .conv2d import *
from .reshape import reshape_compute, reshape_stir_schedule
from .relu import relu_compute, relu_stir_schedule
+from .tanh import tanh_te_compute, tanhf16_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/tanh.py
b/python/tvm/topi/hexagon/slice_ops/tanh.py
new file mode 100644
index 0000000000..3e10ec599c
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/tanh.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" Hexagon tanh slice op compute and schedule """
+import tvm
+from tvm import te, tir
+from ..utils import get_layout_transform_fn
+
+
def tanh_te_compute(in_tensor):
    """Element-wise tanh over a rank-4 (n, h, w, c) tensor.

    Returns a new te.Tensor named "tanhf16" with the same shape as
    *in_tensor*, where each element is tvm.tir.tanh of the corresponding
    input element.
    """
    return te.compute(
        in_tensor.shape,
        lambda n, h, w, c: tvm.tir.tanh(in_tensor[n, h, w, c]),
        name="tanhf16",
    )
+
+
def tanhf16_stir_sched_nhwc(func, in_layout, out_layout, h_split_factor=8):
    """STIR schedule for the fp16 tanh slice op on NHWC input.

    Tiles the (n, h, w, c) loops of the "tanhf16" compute block, applies
    the given input/output layout-transform functions, and vectorizes the
    fused innermost c/w loop.
    """
    # debug_mask="all" turns on all schedule self-verification checks.
    sch = tir.Schedule(func, debug_mask="all")
    block_name = "tanhf16"
    n, h, w, c = sch.get_loops(sch.get_block(block_name))
    # Split each spatial axis into outer/inner tiles. Inner extents are
    # h_split_factor x 4 x 32, with the inner w split again into 2 x 2 —
    # presumably chosen to match the nhwc-8h2w32c2w layout used by the
    # callers' layout strings (TODO confirm against get_layout_transform_fn).
    h_outer, h_inner = sch.split(h, [None, h_split_factor])
    w_outer, w_inner = sch.split(w, [None, 4])
    c_outer, c_inner = sch.split(c, [None, 32])
    w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
    # All outer tile loops first, then the inner intra-tile loops.
    sch.reorder(n, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
    # Rewrite the input buffer "A" and the output buffer (which shares the
    # block's name) into the requested physical layouts.
    sch.transform_layout(block_name, "A", in_layout)
    sch.transform_layout(block_name, block_name, out_layout)
    # Fuse the two innermost loops (extent 32 * 2 = 64) and vectorize.
    fused = sch.fuse(c_inner, w_inner_i)
    sch.vectorize(fused)
    return sch
+
+
def tanhf16_schedule(tanh_func, in_layout_str, out_layout_str, h_split_factor=8):
    """Build the fp16 tanh STIR schedule from layout-name strings.

    Resolves the named layouts to transform functions and delegates to
    tanhf16_stir_sched_nhwc.

    Parameters
    ----------
    tanh_func : tir.PrimFunc
        PrimFunc created from tanh_te_compute's output.
    in_layout_str, out_layout_str : str
        Layout names understood by get_layout_transform_fn
        (e.g. "nhwc-8h2w32c2w-2d" — see the slice-op tests).
    h_split_factor : int, optional
        Inner tile size for the h axis, forwarded to the underlying
        schedule. Defaults to 8, matching the previous hard-wired
        behavior, so existing callers are unaffected.
        (Fix: the wrapper previously dropped this knob entirely.)
    """
    in_layout_transform_func = get_layout_transform_fn(in_layout_str)
    out_layout_transform_func = get_layout_transform_fn(out_layout_str)
    return tanhf16_stir_sched_nhwc(
        tanh_func,
        in_layout_transform_func,
        out_layout_transform_func,
        h_split_factor,
    )
diff --git a/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py
b/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py
new file mode 100644
index 0000000000..b1e85971a2
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test for Hexagon slice tanh op """
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import tvm.topi.hexagon.slice_ops as sl
+import tvm.contrib.hexagon
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+# pylint: disable=invalid-name
+
+
class TestTanhSlice:
    """Parametrized test of the Hexagon fp16 tanh slice op.

    Runs the TE compute + STIR schedule on a Hexagon session and compares
    against numpy.tanh in the transformed physical layout.
    """

    # Each tuple: logical shape, logical layout, physical in/out layout
    # names, and the axis separators used when allocating the 2-D arrays.
    input_shape, orig_layout, input_layout, output_layout, axis_sep = tvm.testing.parameters(
        ((1, 8, 4, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
        ((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
        ((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
    )
    dtype = tvm.testing.parameter("float16")
    # All buffers live in Hexagon VTCM.
    working_scope = tvm.testing.parameter("global.vtcm")

    @tvm.testing.fixture
    def input_np(self, input_shape, dtype):
        # Random values in [0, 1); tanh is well-behaved in this range.
        return np.random.uniform(size=input_shape).astype(dtype)

    @tvm.testing.fixture
    def transformed_input_np(self, input_np, orig_layout, input_layout):
        # Input rearranged from logical nhwc into the physical layout.
        return transform_numpy(input_np, orig_layout, input_layout)

    @tvm.testing.fixture
    def expected_output_np(self, input_np):
        # Reference result in logical layout.
        ref_np = np.tanh(input_np)
        return ref_np

    @tvm.testing.fixture
    def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
        # Reference result rearranged into the physical output layout.
        return transform_numpy(expected_output_np, orig_layout, output_layout)

    @tvm.testing.requires_hexagon
    def test_tanh(
        self,
        input_shape,
        dtype,
        input_layout,
        output_layout,
        transformed_input_np,
        transformed_expected_output_np,
        axis_sep,
        hexagon_session,
        working_scope,
    ):
        """Top Level testing function for tanh fp16 op"""

        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        # Build the TE graph and its scheduled TIR module.
        A = te.placeholder(input_shape, name="A", dtype=dtype)
        M = sl.tanh_te_compute(A)
        tanhf16_func = te.create_prim_func([A, M])
        tir_s = sl.tanhf16_schedule(tanhf16_func, input_layout, output_layout)
        # Device-side input, pre-transformed to the physical layout.
        A_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        # Uninitialized device-side output with matching shape/dtype.
        M_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        with tvm.transform.PassContext(opt_level=3):
            tir_irm = tvm.lower(tir_s.mod, [A, M], name="tanhf16")
            runtime_module = tvm.build(tir_irm, target=target, name="tanhf16")
        mod = hexagon_session.load_module(runtime_module)

        mod(A_data, M_data)
        output_np = M_data.numpy()
        # Loose fp16 tolerance (rtol/atol = 1e-3).
        tvm.testing.assert_allclose(
            output_np,
            transformed_expected_output_np,
            1e-3,
            1e-3,
        )
+
+
if __name__ == "__main__":
    # Fix: the original called sys.exit(pytest.main(sys.argv)) but the file
    # never imported sys, so running it as a script raised NameError.
    # Import locally to keep the fix self-contained.
    import sys

    sys.exit(pytest.main(sys.argv))