eric-haibin-lin commented on a change in pull request #17138: Interleaved MHA for CPU path
URL: https://github.com/apache/incubator-mxnet/pull/17138#discussion_r360658933
 
 

 ##########
 File path: tests/python/unittest/test_operator.py
 ##########
 @@ -9373,6 +9373,330 @@ def check_random_uniform():
             hight = 1
             assertRaises(MXNetError, mx.nd.random_uniform, alpha, beta, shape)
 
+def check_multihead_attention_selfatt(dtype):
+    def convert_weight(F, q_weight, k_weight, v_weight, num_heads):
+        q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True)
+        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+        all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2)
+        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+        return all_weights
+
+    def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
+        q_bias = F.reshape(q_bias, shape=(num_heads, -1))
+        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+        all_bias = F.stack(q_bias, k_bias, v_bias, axis=1)
+        all_bias = F.reshape(all_bias, shape=(-1,))
+        return all_bias
+
+    batch_size = 2
+    qkv_length = 7  # length of a sequence
+    qkv_dim = 9     # dimension of encoding
+    num_heads = 3   # number of attention head
+    head_dim = 5    # head size
+    out_dim = 13 * num_heads
+    qkv_units = num_heads * head_dim
+
+    arg_params = {
+        'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'out_weight': mx.nd.array(np.random.rand(*(out_dim, 
qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 
0.1, dtype=dtype),
+        }
+
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+    qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, 
num_heads)
+    qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads)
+    qkv = mx.sym.transpose(qkv, axes=(1, 0, 2))
+    qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, 
flatten=False,
+                                     num_hidden=qkv_units * 3, no_bias=False)
+    att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk(
+            qkv_proj, heads=num_heads)
+    att_score = att_score + sonde
+    weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
+            qkv_proj, att_score, heads=num_heads)
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, 
bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.transpose(output, axes=(1, 0, 2))
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  q_weight=(qkv_units, qkv_dim),
+                                  q_bias=(qkv_units,),
+                                  k_weight=(qkv_units, qkv_dim),
+                                  k_bias=(qkv_units,),
+                                  v_weight=(qkv_units, qkv_dim),
+                                  v_bias=(qkv_units,),
+                                  type_dict={'qkv': dtype,
+                                             'q_weight': dtype,
+                                             'k_weight': dtype,
+                                             'v_weight': dtype,
+                                             'q_bias': dtype,
+                                             'k_bias': dtype,
+                                             'v_bias': dtype,
+                                             'sonde': dtype},
+                                  grad_req='write', force_rebind=True)
+    output_shape = executor.outputs[0].shape
+    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_opti = executor.outputs[0].asnumpy()
+    att_score_opti = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+
+    q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+    q = mx.sym.contrib.div_sqrt_dim(q)
+    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+    att_score = att_score + sonde
+    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+    weighted_value = mx.sym.batch_dot(att_score, v)
+    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 
0),
+                                    reverse=True)
+    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, 
bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'qkv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert(grads_orig[k].dtype == grads_opti[k].dtype)
+        assert(grads_orig[k].shape == grads_opti[k].shape)
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+
+@with_seed(12345)
+def test_multihead_attention_selfatt():
+    dtypes = ['float32']
+    if default_context().device_type == 'gpu':
+        dtypes += ['float16']
+
+    for dtype in dtypes:
+        check_multihead_attention_selfatt(dtype=dtype)
+
+def check_multihead_attention_encdec(dtype):
+    def convert_weight(F, k_weight, v_weight, num_heads):
+        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+        all_weights = F.concat(k_weight, v_weight, dim=-2)
+        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+        return all_weights
+
+    def convert_bias(F, k_bias, v_bias, num_heads):
+        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+        all_bias = F.stack(k_bias, v_bias, axis=1)
+        all_bias = F.reshape(all_bias, shape=(-1,))
+        return all_bias
+
+    batch_size = 2
+    qkv_length = 7  # length of a sequence
+    qkv_dim = 9     # dimension of encoding
+    num_heads = 3   # number of attention head
+    head_dim = 5    # head size
+    out_dim = 13 * num_heads
+    qkv_units = num_heads * head_dim
+
+    arg_params = {
+        'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, 
qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 
0.1, dtype=dtype),
+        'out_weight': mx.nd.array(np.random.rand(*(out_dim, 
qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 
0.1, dtype=dtype),
+        }
+
+    q = mx.sym.Variable('q')
+    kv = mx.sym.Variable('kv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+    kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads)
+    kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads)
+    kv = mx.sym.transpose(kv, axes=(1, 0, 2))
+    kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, 
flatten=False,
+                                    num_hidden=qkv_units * 2, no_bias=False)
+    q = mx.sym.transpose(q, axes=(1, 0, 2))
+    q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, 
flatten=False,
+                                   num_hidden=qkv_units, no_bias=False)
+    att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
+            q_proj, kv_proj, heads=num_heads) 
+    att_score = att_score + sonde
+    weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
+            kv_proj, att_score, heads=num_heads)
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, 
bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.transpose(output, axes=(1, 0, 2))
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  q=(batch_size, qkv_length, qkv_dim),
+                                  kv=(batch_size, qkv_length, qkv_dim),
+                                  q_weight=(qkv_units, qkv_dim),
+                                  q_bias=(qkv_units,),
+                                  k_weight=(qkv_units, qkv_dim),
+                                  k_bias=(qkv_units,),
+                                  v_weight=(qkv_units, qkv_dim),
+                                  v_bias=(qkv_units,),
+                                  out_weight=(out_dim, qkv_units),
+                                  out_bias=(out_dim,),
+                                  type_dict={'q': dtype,
+                                             'kv': dtype,
+                                             'q_weight': dtype,
+                                             'q_bias': dtype,
+                                             'k_weight': dtype,
+                                             'k_bias': dtype,
+                                             'v_weight': dtype,
+                                             'v_bias': dtype,
+                                             'out_weight': dtype,
+                                             'out_bias': dtype,
+                                              },
+                                  grad_req='write', force_rebind=True)
+    output_shape = executor.outputs[0].shape
+    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_opti = executor.outputs[0].asnumpy()
+    att_score_opti = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype), 
mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+
+    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+
+    q = mx.sym.Variable('q')
+    kv = mx.sym.Variable('kv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+
+    q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+    q = mx.sym.contrib.div_sqrt_dim(q)
+    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+    att_score = att_score + sonde
+    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+    weighted_value = mx.sym.batch_dot(att_score, v)
+    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 
0),
+                                    reverse=True)
+    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, 
bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  q=(batch_size, qkv_length, qkv_dim),
+                                  kv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'q': dtype,
+                                             'kv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype), 
mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert(grads_orig[k].dtype == grads_opti[k].dtype)
+        assert(grads_orig[k].shape == grads_opti[k].shape)
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+@with_seed(12345)
+def test_multihead_attention_encdec():
 
 Review comment:
   Do we still need `check_multihead_attention_selfatt` in
`test_operator_gpu.py`, now that this PR adds it to `test_operator.py`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to