josephevans commented on a change in pull request #19653:
URL: https://github.com/apache/incubator-mxnet/pull/19653#discussion_r543044850



##########
File path: python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
##########
@@ -2280,6 +2289,200 @@ def convert_layer_norm(node, **kwargs):
     return nodes
 
 
+def make_tensor(shape_list, shape_name, initializer, dtype='int64'):

Review comment:
     Please rename to create_tensor() or something similar, so there's a clear
distinction between this function and onnx.helper.make_tensor().

##########
File path: python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
##########
@@ -2280,6 +2289,200 @@ def convert_layer_norm(node, **kwargs):
     return nodes
 
 
def make_tensor(shape_list, shape_name, initializer, dtype='int64'):
    """Append a constant tensor to the ONNX graph's initializer list.

    Parameters
    ----------
    shape_list : list
        The constant values to store (typically shape components).
    shape_name : str
        Name under which the tensor is registered in the graph.
    initializer : list
        The graph's initializer list; mutated in place.
    dtype : str, optional
        Numpy dtype of the stored values (default 'int64').

    NOTE(review): consider renaming to create_tensor() to avoid confusion
    with onnx.helper.make_tensor(); the name is kept here so existing
    callers keep working.
    """
    shape_np = np.array(shape_list, dtype=dtype)
    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype]
    dims = np.shape(shape_np)
    # The make_tensor_value_info() result previously computed here was never
    # used or returned, so the dead call has been removed.
    initializer.append(
        onnx.helper.make_tensor(
            name=shape_name,
            data_type=data_type,
            dims=dims,
            vals=shape_list,
            raw=False,
        )
    )
+
+
@mx_op.register("_contrib_interleaved_matmul_selfatt_qk")
def convert_matmul_selfatt_qk(node, **kwargs):
    """Map MXNet ``_contrib_interleaved_matmul_selfatt_qk`` to ONNX nodes.

    Computes the scaled Q*K^T attention-score matrix from an interleaved
    QKV projection tensor. All shape components are recovered at runtime
    via Shape/Slice nodes, so the export supports dynamic seq_len and
    batch_size.

    Dead initializers removed: ``_const_heads`` (duplicate of ``_c``),
    ``_4`` and ``_5`` were created but never referenced by any node.
    """
    from onnx.helper import make_node
    from onnx import TensorProto
    name, input_nodes, attrs = get_inputs(node, kwargs)

    heads = int(attrs.get('heads'))

    # a, b, c, d, e are seq_len, batch_size, num_heads, 3, head_dim respectively
    make_tensor([0], name+"_0", kwargs["initializer"])
    make_tensor([1], name+"_1", kwargs["initializer"])
    make_tensor([1], name+"_1_f", kwargs["initializer"], dtype='float32')
    make_tensor([2], name+"_2", kwargs["initializer"])
    make_tensor([3], name+"_3", kwargs["initializer"])
    make_tensor([heads], name+"_c", kwargs["initializer"])
    make_tensor([3], name+"_d", kwargs["initializer"])

    nodes = [
            # Recover the individual shape components from the input:
            # input is (a, b, c*d*e); divide out heads (c) then 3 (d) to get e.
            make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
            make_node('Slice', [name+'_data_shape', name+'_0', name+'_1'], [name+"_a"]),
            make_node('Slice', [name+'_data_shape', name+'_1', name+'_2'], [name+"_b"]),
            make_node('Slice', [name+'_data_shape', name+'_2', name+'_3'], [name+"_cde"]),
            make_node('Div', [name+'_cde', name+'_c'], [name+'_de']),
            make_node('Div', [name+'_de', name+'_d'], [name+'_e']),
            # 1/sqrt(head_dim) scaling factor, computed in float.
            make_node('Cast', [name+'_e'], [name+'_e_f'], to=int(TensorProto.FLOAT)),
            make_node('Sqrt', [name+'_e_f'], [name+'_sqrt_e']),
            make_node('Div', [name+'_1_f', name+'_sqrt_e'], [name+'_1_over_sqrt_e']),
            make_node('Mul', [name+'_b', name+'_c'], [name+'_bc']),

            # Target shapes / slice bounds assembled from runtime components.
            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_d', name+'_e'], \
                      [name+'_shape0'], axis=0),
            make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_0', name+'_0'], \
                      [name+'_slice_start0'], axis=0),
            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_1', name+'_e'], \
                      [name+'_slice_end0'], axis=0),
            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_e'], \
                      [name+'_shape1'], axis=0),
            make_node("Concat", [name+'_bc', name+'_a', name+'_e'], \
                      [name+'_shape2'], axis=0),
            make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_1', name+'_0'], \
                      [name+'_slice_start1'], axis=0),
            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_2', name+'_e'], \
                      [name+'_slice_end1'], axis=0),

            # Q path: slice index 0 of the interleave dim, reshape, transpose,
            # flatten heads into batch, then scale by 1/sqrt(e).
            make_node('Reshape', [input_nodes[0], name+'_shape0'], [name+'_reshape0_out']),
            make_node('Slice', [name+'_reshape0_out', name+'_slice_start0', name+'_slice_end0'], \
                      [name+'_slice0_out']),
            make_node('Reshape', [name+'_slice0_out', name+'_shape1'], [name+'_reshape1_out']),
            make_node('Transpose', [name+'_reshape1_out'], [name+'_transpose0_out'], \
                      perm=(1, 2, 0, 3)),
            make_node('Reshape', [name+'_transpose0_out', name+'_shape2'], [name+'_reshape2_out']),
            make_node('Mul', [name+'_reshape2_out', name+'_1_over_sqrt_e'], [name+'_mul0_out']),
            # K path: slice index 1 of the interleave dim, same layout, then
            # transpose the last two dims for the matmul.
            make_node('Slice', [name+'_reshape0_out', name+'_slice_start1', name+'_slice_end1'], \
                      [name+'_slice1_out']),
            make_node('Reshape', [name+'_slice1_out', name+'_shape1'], [name+'_reshape3_out']),
            make_node('Transpose', [name+'_reshape3_out'], [name+'_transpose1_out'], \
                      perm=(1, 2, 0, 3)),
            make_node('Reshape', [name+'_transpose1_out', name+'_shape2'], [name+'_reshape4_out']),
            make_node('Transpose', [name+'_reshape4_out'], [name+'_transpose2_out'], \
                      perm=(0, 2, 1)),
            make_node('MatMul', [name+'_mul0_out', name+'_transpose2_out'], [name], name=name)
        ]

    return nodes
+
+
@mx_op.register("broadcast_axis")
def convert_broadcast_axis(node, **kwargs):
    """Map MXNet ``broadcast_axis`` to an ONNX ``Expand``.

    For every (axis, size) pair the target shape is computed at runtime:
    a one-hot vector over the input rank selects the broadcast axis, that
    entry of the current shape is scaled by ``size``, and the final shape
    feeds a single Expand node.

    Fixes vs. previous revision: removed the unused ``data_shape_list``
    local, stopped the loop variable from shadowing the ``axis`` list,
    and deduplicated per-axis constant tensors so a repeated axis value
    does not register the same initializer name twice.
    """
    from onnx.helper import make_node
    from onnx import TensorProto
    name, input_nodes, attrs = get_inputs(node, kwargs)

    axes = convert_string_to_list(attrs.get('axis', '()'))
    sizes = convert_string_to_list(attrs.get('size', '()'))
    assert len(axes) == len(sizes)

    make_tensor([0], name+'_0', kwargs["initializer"])
    make_tensor([1], name+'_1', kwargs["initializer"])
    make_tensor([], name+'_void', kwargs["initializer"])
    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)

    shape_name = name+'_shape_0'
    nodes = [
            make_node('Shape', [input_nodes[0]], [shape_name]),
            # Shape of the shape = rank; reshape to scalar to feed Range.
            make_node('Shape', [shape_name], [name+'_in_dim']),
            make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
            make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range']),
        ]

    # Constants for axes 0 and 1 already exist; track other axis values so a
    # repeated axis does not create a duplicate initializer name.
    registered_axes = set()
    for i, axis in enumerate(axes):
        if axis not in (0, 1) and axis not in registered_axes:
            make_tensor([axis], name+'_'+str(axis), kwargs["initializer"])
            registered_axes.add(axis)
        make_tensor([sizes[i]-1], name+'_size_'+str(i), kwargs["initializer"])
        nodes += [
             # "one-hot" tensor selecting the broadcast axis
             make_node('Equal', [name+'_range', name+'_'+str(axis)], [name+'_equal_'+str(i)]),
             make_node('Cast', [name+'_equal_'+str(i)], [name+'_cast_'+str(i)], to=int(TensorProto.INT64)),
             # shape_{i+1} = shape_i * ((size-1) * onehot + 1)
             make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]),
             make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]),
             make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)])
            ]
        shape_name = name+'_shape_'+str(i+1)

    nodes += [make_node('Expand', [input_nodes[0], shape_name], [name], name=name)]

    return nodes
+
@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt")
def convert_interleaved_matmul_selfatt_valatt(node, **kwargs):
    # TODO(review): unimplemented stub — returning an empty node list silently
    # drops this op from the exported graph, producing a model that will not
    # match MXNet's output. Implement the conversion (or raise) before merging.
    return []
+
+
+@mx_op.register("SequenceMask")
+def convert_sequencemask(node, **kwargs):
+    from onnx.helper import make_node
+    from onnx import TensorProto
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    use_sequence_length = attrs.get('use_sequence_length', 'False')
+    mask_val = float(attrs.get('value', '0'))
+    axis = int(attrs.get('axis', '0'))
+
+    if(use_sequence_length == 'False'):
+        return [make_node('Identity', [input_nodes[0]], [name], name=name)]
+
+    make_tensor([], name+'_void', kwargs["initializer"])
+    make_tensor([0], name+'_0', kwargs["initializer"])
+    make_tensor([1], name+'_1', kwargs["initializer"])
+    make_tensor([2], name+'_2', kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+    create_const_scalar_node(name+'_2_s', np.int64(2), kwargs)
+    make_tensor([mask_val], name+'_mask_val', kwargs["initializer"], 
dtype='float32')
+    #create_const_scalar_node(name+'_mask_val', np.float32(mask_val), kwargs),
+
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
+        make_node('Slice', [name+'_in_shape', name+'_0', name+'_1'], 
[name+'_slice_0']),
+        make_node('Slice', [name+'_in_shape', name+'_1', name+'_2'], 
[name+'_slice_1']),
+        make_node('Concat', [name+'_slice_0', name+'_1'], [name+'_shape_0'], 
axis=0),
+        make_node('Shape', [name+'_in_shape'], [name+'_in_dim']),
+        make_node('Reshape', [name+'_in_dim', name+'_void'], 
[name+'_in_dim_s']),
+        make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], 
[name+'_range_0']),
+        make_node('Less', [name+'_range_0', name+'_2'], [name+'_less_0']),
+        make_node('Where', [name+'_less_0', name+'_in_shape', name+'_1'], 
[name+'_shape_1'])
+        ]
+
+    if(axis == 0):
+        nodes += [
+            make_node('Reshape', [name+'_slice_0', name+'_void'], 
[name+'_max_len'], name = '111'),

Review comment:
       Do we need to specify the name parameter?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to