This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eebd5a9 [FastMath] Add cuda & x86 schedules for fast_softmax (#8150)
eebd5a9 is described below
commit eebd5a94e88e77a1a6adc8a5726c3f55e63ae8e1
Author: Chenfan <[email protected]>
AuthorDate: Mon May 31 16:22:33 2021 +0800
[FastMath] Add cuda & x86 schedules for fast_softmax (#8150)
* Add cuda & x86 schedules for fast_softmax
* Bug fix
* Re-trigger CI
---
python/tvm/relay/op/strategy/cuda.py | 12 ++++++++++++
python/tvm/relay/op/strategy/x86.py | 12 ++++++++++++
python/tvm/topi/cuda/softmax.py | 21 ++++++++++++++++++---
python/tvm/topi/x86/nn.py | 11 +++++++++++
tests/python/relay/test_op_fast_math.py | 11 +++++++----
5 files changed, 60 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index d820283..8367a68 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -89,6 +89,18 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
return strategy
+@fast_softmax_strategy.register(["cuda", "gpu"])
+def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
+ """Op strategy for fast_softmax on cuda/gpu targets: topi.nn.fast_softmax compute with the cuda softmax schedule."""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_softmax(topi.nn.fast_softmax),
+ wrap_topi_schedule(topi.cuda.schedule_softmax),  # reused: schedule_softmax handles the "fast_softmax_output" tag
+ name="fast_softmax.cuda",
+ )
+ return strategy
+
+
@schedule_log_softmax.register(["cuda", "gpu"])
def schedule_log_softmax_cuda(attrs, outs, target):
"""scheudle log_softmax for cuda"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 60bd92e..c21ec4d 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -79,6 +79,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target):
return strategy
+@fast_softmax_strategy.register("cpu")
+def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
+ """Op strategy for fast_softmax on cpu (x86) targets: topi.nn.fast_softmax compute with the x86 softmax schedule."""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_softmax(topi.nn.fast_softmax),
+ wrap_topi_schedule(topi.x86.schedule_softmax),  # reused: schedule_softmax handles the "fast_softmax_output" tag
+ name="fast_softmax.x86",
+ )
+ return strategy
+
+
@schedule_log_softmax.register("cpu")
def schedule_log_softmax_cpu(attrs, outs, target):
"""schedule log_softmax op for x86"""
diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index 99fbdd0..b743aef 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -47,8 +47,15 @@ def schedule_softmax(outs):
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
max_elem = s[exp].op.input_tensors[1]
+ delta = None
+ elif op_tag == "fast_softmax_output":
+ expsum = softmax.op.input_tensors[1]
+ exp = softmax.op.input_tensors[0]
+ delta = s[exp].op.input_tensors[0]
+ max_elem = s[delta].op.input_tensors[1]
elif op_tag == "log_softmax_output":
exp = None
+ delta = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
else:
@@ -73,6 +80,8 @@ def schedule_softmax(outs):
if len(softmax.shape) > 2:
ops = [max_elem.op, expsum.op, softmax.op]
+ if delta is not None:
+ ops.append(delta.op)
if exp is not None:
ops.append(exp.op)
@@ -99,7 +108,10 @@ def schedule_softmax(outs):
s[expsum].compute_at(s[softmax], xo)
# (2) exp
- if exp is not None:
+ if delta is not None:
+ s[exp].compute_inline()
+ s[delta].compute_inline()
+ elif exp is not None:
xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
_, xii = s[exp].split(xi, factor=4)
s[exp].vectorize(xii)
@@ -112,7 +124,7 @@ def schedule_softmax(outs):
k = max_elem.op.reduce_axis[0]
ko, _ = s[max_elem].split(k, nparts=num_thread)
s[max_elem].bind(ko, thread_x)
- if exp is not None:
+ if exp is not None and delta is None:
s[max_elem].compute_at(s[exp], xo)
else:
s[max_elem].bind(ko, thread_x)
@@ -123,7 +135,10 @@ def schedule_softmax(outs):
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
- if exp is not None:
+ if delta is not None:
+ s[exp].compute_inline()
+ s[delta].compute_inline()
+ elif exp is not None:
s[exp].bind(exp.op.axis[0], block_x)
s[max_elem].bind(max_elem.op.axis[0], block_x)
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 0994700..4c39f2a 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -42,9 +42,17 @@ def schedule_softmax(outs):
exp = softmax.op.input_tensors[0]
expsum = softmax.op.input_tensors[1]
max_elem = s[exp].op.input_tensors[1]
+ delta = None
+ axis = int(softmax.op.attrs["axis"])
+ elif op_tag == "fast_softmax_output":
+ exp = softmax.op.input_tensors[0]
+ expsum = softmax.op.input_tensors[1]
+ delta = s[exp].op.input_tensors[0]
+ max_elem = s[delta].op.input_tensors[1]
axis = int(softmax.op.attrs["axis"])
elif op_tag == "log_softmax_output":
exp = None
+ delta = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
axis = 1
@@ -65,6 +73,9 @@ def schedule_softmax(outs):
s[max_elem].compute_at(s[softmax], fused_outer_axes)
s[expsum].compute_at(s[softmax], fused_outer_axes)
+ if delta is not None:
+ s[exp].compute_inline()
+ s[delta].compute_inline()
if exp is not None:
s[exp].compute_at(s[softmax], fused_outer_axes)
diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
index c9314fa..f968dbe 100644
--- a/tests/python/relay/test_op_fast_math.py
+++ b/tests/python/relay/test_op_fast_math.py
@@ -23,9 +23,11 @@ import tvm.relay as relay
from tvm import topi
from tvm import te
from tvm.contrib import graph_executor
+from tvm.topi import testing
-def test_fastmath():
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_fastmath(target, dev):
def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
b_np = f_numpy(a_np)
@@ -36,13 +38,14 @@ def test_fastmath():
mod = tvm.IRModule.from_expr(func)
with tvm.transform.PassContext(opt_level=3,
required_pass=["FastMath"]):
- graph, lib, params = relay.build(mod, target="llvm", params=None)
+ graph, lib, params = relay.build(mod, target=target, params=None)
# Check that the op related to fast math have been convered to function in lib
func_name = "fused_" + name
- assert lib.get_function(func_name)
+ # When there're multiple targets in tvm.testing.parametrize_targets, the function
+ # built will have a "_1" in function name
+ assert func_name in graph
- dev = tvm.cpu(0)
m = graph_executor.create(graph, lib, dev)
# Set inputs
m.set_input("x", tvm.nd.array(a_np, dev))