eric-haibin-lin commented on a change in pull request #17138: Interleaved MHA for CPU path
URL: https://github.com/apache/incubator-mxnet/pull/17138#discussion_r360658933
##########
File path: tests/python/unittest/test_operator.py
##########
@@ -9373,6 +9373,330 @@ def check_random_uniform():
hight = 1
assertRaises(MXNetError, mx.nd.random_uniform, alpha, beta, shape)
+def check_multihead_attention_selfatt(dtype):
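+ # The helpers below pack the separate per-head Q/K/V projection weights and biases
+ # into the single interleaved layout consumed by the fused interleaved_matmul_selfatt_* operators.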
+ def convert_weight(F, q_weight, k_weight, v_weight, num_heads):
+ q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True)
+ k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+ v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+ all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2)
+ all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+ return all_weights
+
+ def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
+ q_bias = F.reshape(q_bias, shape=(num_heads, -1))
+ k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+ v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+ all_bias = F.stack(q_bias, k_bias, v_bias, axis=1)
+ all_bias = F.reshape(all_bias, shape=(-1,))
+ return all_bias
+
+ batch_size = 2
+ qkv_length = 7 # length of a sequence
+ qkv_dim = 9 # dimension of encoding
+ num_heads = 3 # number of attention heads
+ head_dim = 5 # head size
+ out_dim = 13 * num_heads
+ qkv_units = num_heads * head_dim
+
+ arg_params = {
+ 'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+ 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
+ }
+
+ qkv = mx.sym.Variable('qkv')
+ sonde = mx.sym.Variable('sonde')
+ q_weight = mx.sym.Variable('q_weight')
+ k_weight = mx.sym.Variable('k_weight')
+ v_weight = mx.sym.Variable('v_weight')
+ q_bias = mx.sym.Variable('q_bias')
+ k_bias = mx.sym.Variable('k_bias')
+ v_bias = mx.sym.Variable('v_bias')
+ out_weight = mx.sym.Variable('out_weight')
+ out_bias = mx.sym.Variable('out_bias')
+ qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads)
+ qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads)
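+ # Optimized path: one packed QKV projection feeding the fused
+ # interleaved_matmul_selfatt_qk / interleaved_matmul_selfatt_valatt operators.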
+ qkv = mx.sym.transpose(qkv, axes=(1, 0, 2))
+ qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False,
+ num_hidden=qkv_units * 3, no_bias=False)
+ att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk(
+ qkv_proj, heads=num_heads)
+ att_score = att_score + sonde
+ weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
+ qkv_proj, att_score, heads=num_heads)
+ output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+ num_hidden=out_dim, no_bias=False)
+ output = mx.sym.transpose(output, axes=(1, 0, 2))
+ output = mx.sym.Group([output, att_score])
+ executor = output.simple_bind(ctx=default_context(),
+ qkv=(batch_size, qkv_length, qkv_dim),
+ q_weight=(qkv_units, qkv_dim),
+ q_bias=(qkv_units,),
+ k_weight=(qkv_units, qkv_dim),
+ k_bias=(qkv_units,),
+ v_weight=(qkv_units, qkv_dim),
+ v_bias=(qkv_units,),
+ type_dict={'qkv': dtype,
+ 'q_weight': dtype,
+ 'k_weight': dtype,
+ 'v_weight': dtype,
+ 'q_bias': dtype,
+ 'k_bias': dtype,
+ 'v_bias': dtype,
+ 'sonde': dtype},
+ grad_req='write', force_rebind=True)
+ output_shape = executor.outputs[0].shape
+ output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+ executor.copy_params_from(arg_params, {})
+ executor.arg_dict['sonde'][:] = 0.
+ executor.arg_dict['sonde'].wait_to_read()
+ executor.forward(is_train=True)
+ output_opti = executor.outputs[0].asnumpy()
+ att_score_opti = executor.outputs[1].asnumpy()
+ executor.backward([mx.nd.array(output_grads, dtype=dtype),
+ mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+ grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
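+ # Reference path: separate Q/K/V projections with explicit reshape/transpose and batch_dot,
+ # used to check the outputs and gradients of the fused operators above.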
+ qkv = mx.sym.Variable('qkv')
+ sonde = mx.sym.Variable('sonde')
+ q_weight = mx.sym.Variable('q_weight')
+ k_weight = mx.sym.Variable('k_weight')
+ v_weight = mx.sym.Variable('v_weight')
+ q_bias = mx.sym.Variable('q_bias')
+ k_bias = mx.sym.Variable('k_bias')
+ v_bias = mx.sym.Variable('v_bias')
+ out_weight = mx.sym.Variable('out_weight')
+ out_bias = mx.sym.Variable('out_bias')
+
+ q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+ q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+ q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+ k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+ k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+ k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+ q = mx.sym.contrib.div_sqrt_dim(q)
+ att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+ att_score = att_score + sonde
+ v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+ v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+ v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+ weighted_value = mx.sym.batch_dot(att_score, v)
+ weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
+ reverse=True)
+ weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+ weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+ output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+ num_hidden=out_dim, no_bias=False)
+ output = mx.sym.Group([output, att_score])
+ executor = output.simple_bind(ctx=default_context(),
+ qkv=(batch_size, qkv_length, qkv_dim),
+ type_dict={'qkv': dtype},
+ grad_req='write', force_rebind=True)
+ executor.copy_params_from(arg_params, {})
+ executor.arg_dict['sonde'][:] = 0.
+ executor.arg_dict['sonde'].wait_to_read()
+ executor.forward(is_train=True)
+ output_orig = executor.outputs[0].asnumpy()
+ att_score_orig = executor.outputs[1].asnumpy()
+ executor.backward([mx.nd.array(output_grads, dtype=dtype),
+ mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+ grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+ assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+ assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+ for k in grads_opti.keys():
+ assert(grads_orig[k].dtype == grads_opti[k].dtype)
+ assert(grads_orig[k].shape == grads_opti[k].shape)
+ assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+
+@with_seed(12345)
+def test_multihead_attention_selfatt():
+ dtypes = ['float32']
+ if default_context().device_type == 'gpu':
+ dtypes += ['float16']
+
+ for dtype in dtypes:
+ check_multihead_attention_selfatt(dtype=dtype)
+
+def check_multihead_attention_encdec(dtype):
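+ # Same comparison for the encoder-decoder attention variant: only K and V are packed into
+ # an interleaved projection; Q is projected separately.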
+ def convert_weight(F, k_weight, v_weight, num_heads):
+ k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+ v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+ all_weights = F.concat(k_weight, v_weight, dim=-2)
+ all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+ return all_weights
+
+ def convert_bias(F, k_bias, v_bias, num_heads):
+ k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+ v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+ all_bias = F.stack(k_bias, v_bias, axis=1)
+ all_bias = F.reshape(all_bias, shape=(-1,))
+ return all_bias
+
+ batch_size = 2
+ qkv_length = 7 # length of a sequence
+ qkv_dim = 9 # dimension of encoding
+ num_heads = 3 # number of attention heads
+ head_dim = 5 # head size
+ out_dim = 13 * num_heads
+ qkv_units = num_heads * head_dim
+
+ arg_params = {
+ 'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+ 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+ 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+ 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
+ }
+
+ q = mx.sym.Variable('q')
+ kv = mx.sym.Variable('kv')
+ sonde = mx.sym.Variable('sonde')
+ q_weight = mx.sym.Variable('q_weight')
+ k_weight = mx.sym.Variable('k_weight')
+ v_weight = mx.sym.Variable('v_weight')
+ q_bias = mx.sym.Variable('q_bias')
+ k_bias = mx.sym.Variable('k_bias')
+ v_bias = mx.sym.Variable('v_bias')
+ out_weight = mx.sym.Variable('out_weight')
+ out_bias = mx.sym.Variable('out_bias')
+ kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads)
+ kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads)
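+ # Optimized path: packed KV projection plus a separate Q projection, consumed by the fused
+ # interleaved_matmul_encdec_qk / interleaved_matmul_encdec_valatt operators.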
+ kv = mx.sym.transpose(kv, axes=(1, 0, 2))
+ kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False,
+ num_hidden=qkv_units * 2, no_bias=False)
+ q = mx.sym.transpose(q, axes=(1, 0, 2))
+ q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
+ q_proj, kv_proj, heads=num_heads)
+ att_score = att_score + sonde
+ weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
+ kv_proj, att_score, heads=num_heads)
+ output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+ num_hidden=out_dim, no_bias=False)
+ output = mx.sym.transpose(output, axes=(1, 0, 2))
+ output = mx.sym.Group([output, att_score])
+ executor = output.simple_bind(ctx=default_context(),
+ q=(batch_size, qkv_length, qkv_dim),
+ kv=(batch_size, qkv_length, qkv_dim),
+ q_weight=(qkv_units, qkv_dim),
+ q_bias=(qkv_units,),
+ k_weight=(qkv_units, qkv_dim),
+ k_bias=(qkv_units,),
+ v_weight=(qkv_units, qkv_dim),
+ v_bias=(qkv_units,),
+ out_weight=(out_dim, qkv_units),
+ out_bias=(out_dim,),
+ type_dict={'q': dtype,
+ 'kv': dtype,
+ 'q_weight': dtype,
+ 'q_bias': dtype,
+ 'k_weight': dtype,
+ 'k_bias': dtype,
+ 'v_weight': dtype,
+ 'v_bias': dtype,
+ 'out_weight': dtype,
+ 'out_bias': dtype,
+ },
+ grad_req='write', force_rebind=True)
+ output_shape = executor.outputs[0].shape
+ output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+ executor.copy_params_from(arg_params, {})
+ executor.arg_dict['sonde'][:] = 0.
+ executor.arg_dict['sonde'].wait_to_read()
+ executor.forward(is_train=True)
+ output_opti = executor.outputs[0].asnumpy()
+ att_score_opti = executor.outputs[1].asnumpy()
+ executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+
+ grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+
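+ # Reference path for encoder-decoder attention, built from basic operators for comparison.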
+ q = mx.sym.Variable('q')
+ kv = mx.sym.Variable('kv')
+ sonde = mx.sym.Variable('sonde')
+ q_weight = mx.sym.Variable('q_weight')
+ k_weight = mx.sym.Variable('k_weight')
+ v_weight = mx.sym.Variable('v_weight')
+ q_bias = mx.sym.Variable('q_bias')
+ k_bias = mx.sym.Variable('k_bias')
+ v_bias = mx.sym.Variable('v_bias')
+ out_weight = mx.sym.Variable('out_weight')
+ out_bias = mx.sym.Variable('out_bias')
+
+ q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False,
+ num_hidden=qkv_units, no_bias=False)
+ q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+ q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+ q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+ k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+ k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+ k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+ q = mx.sym.contrib.div_sqrt_dim(q)
+ att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+ att_score = att_score + sonde
+ v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+ v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+ v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+ weighted_value = mx.sym.batch_dot(att_score, v)
+ weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
+ reverse=True)
+ weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+ weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+ output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+ num_hidden=out_dim, no_bias=False)
+ output = mx.sym.Group([output, att_score])
+ executor = output.simple_bind(ctx=default_context(),
+ q=(batch_size, qkv_length, qkv_dim),
+ kv=(batch_size, qkv_length, qkv_dim),
+ type_dict={'q': dtype,
+ 'kv': dtype},
+ grad_req='write', force_rebind=True)
+ executor.copy_params_from(arg_params, {})
+ executor.arg_dict['sonde'][:] = 0.
+ executor.arg_dict['sonde'].wait_to_read()
+ executor.forward(is_train=True)
+ output_orig = executor.outputs[0].asnumpy()
+ att_score_orig = executor.outputs[1].asnumpy()
+ executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+ grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+ assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+ assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+ for k in grads_opti.keys():
+ assert(grads_orig[k].dtype == grads_opti[k].dtype)
+ assert(grads_orig[k].shape == grads_opti[k].shape)
+ assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+@with_seed(12345)
+def test_multihead_attention_encdec():
Review comment:
do we still need `check_multihead_attention_selfatt` in test_operator_gpu.py?
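For reference, a separate copy there would only be needed if the GPU suite does not already pick up the CPU tests; a minimal sketch, assuming test_operator_gpu.py still uses its usual wholesale import (shown for illustration only):

    # tests/python/gpu/test_operator_gpu.py (assumed existing import pattern)
    from test_operator import *  # picks up check_/test_multihead_attention_selfatt defined above

Run under the GPU context, the imported test then also exercises the float16 branch guarded by `default_context().device_type == 'gpu'`.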