sxjscience commented on a change in pull request #19387:
URL: https://github.com/apache/incubator-mxnet/pull/19387#discussion_r508816083



##########
File path: tests/python/unittest/test_operator.py
##########
@@ -9418,3 +9418,85 @@ def 
test_broadcast_ops_on_misaligned_input_oneside(dtype, lead_dim, both_ways):
     mx.nd.waitall()
     assert_almost_equal(f, expected)
 
+
+def test_sldwin_selfatten_operators():
+    def gen_sliding_window_mask_full(batch_size, num_heads, seq_length, w, 
symmetric, d):
+        """Generate sliding_window attention mask for the full attention 
matrix ( seq_len^2 ).
+        """
+        mask_np = np.zeros((batch_size, num_heads, seq_length, seq_length))
+        for i in range(seq_length):
+            end = (i + 1 + w * d) if symmetric else (i + 1)
+            for j in range(i - w * d, end, d):
+                if j >= 0 and j < seq_length:
+                    mask_np[:, :, i, j] = 1
+        return mask_np
+
+    def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
+                                  num_head_units, w, symmetric, d):
+        # Generate the data
+        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, 
num_head_units))
+        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, 
num_head_units))
+        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, 
num_head_units))
+        valid_length = np.zeros((batch_size,))
+        valid_length[:] = seq_length
+
+        ctx = mx.gpu(0)

Review comment:
       There is no need to specify the `ctx` explicitly here; the CI system
will automatically run this test on both GPU and CPU contexts.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to