shengxinhu commented on a change in pull request #9438:
URL: https://github.com/apache/tvm/pull/9438#discussion_r745221748
##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -3125,6 +3125,223 @@ def _impl_v1(cls, inputs, attr, params):
return ret
+class Scan(OnnxOpConverter):
+ """Operator converter for Scan"""
+
+ @classmethod
+ def _impl_v8(cls, inputs, attr, params):
+ new_inputs = inputs[1:]
+ batch_num = infer_shape(inputs[1])[0]
+ out = []
+ for i in range(batch_num):
+ v9_inputs = [
+ _op.take(new_inputs[j], _expr.const(i), axis=0) for j in
range(len(new_inputs))
+ ]
+ results = cls._impl_v9(v9_inputs, attr, params)
+ results = [_op.expand_dims(results[j], axis=0) for j in
range(len(results))]
+ if i == 0:
+ out = results
+ else:
+ out = [_op.concatenate([out[j], results[j]], axis=0) for j in
range(len(results))]
+
+ out = _expr.TupleWrapper(_expr.Tuple(out), len(out))
+ return out
+
+ @classmethod
+ def _impl_v9(cls, inputs, attr, params):
+ body = attr.get("body")
+ num_scan_inputs = attr.get("num_scan_inputs")
+ num_all_inputs = len(inputs)
+ num_state_inputs = len(body.input) - num_scan_inputs
+ num_state_outputs = num_state_inputs
+ num_all_outputs = len(body.output)
+ num_scan_outputs = num_all_outputs - num_state_outputs
+ scan_input_axes = attr.get("scan_input_axes", [0] * num_scan_inputs)
+ scan_input_directions = attr.get("scan_input_directions", [0] *
num_scan_inputs)
+ scan_output_axes = attr.get("scan_output_axes", [0] * num_scan_outputs)
+ scan_output_directions = attr.get("scan_output_directions", [0] *
num_scan_outputs)
+ # loop count are the same for all scan inputs, so get loop count by
first input scan
+ # strided_slice not support dynamic axes, so assume input shape are
static
+ max_loop_count =
infer_shape(inputs[num_state_inputs])[scan_input_axes[0]]
+
+ # Create a copy of the body function to prevent the original
+ # from being modified.
+ body = copy.copy(attr["body"])
+
+ # Loop inputs will be packed as
+ # [iter_count, loop_deps, scan_outputs]
+ def cond_fn(*loop_inputs):
+ i = loop_inputs[0]
+ return _op.less(i, relay.const(max_loop_count, "int32"))
+
+ # Get the current graph proto and create a clone for the subgraph
+ graph_scope = GraphProto.current
+ subgraph_scope = GraphProto(
+ graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+ )
+ # Load nodes from outer graph into inner graph.
+ subgraph_scope._nodes = graph_scope._nodes.copy()
+
+ # Create a list of variables for each value updated in the loop.
+ def get_var(name, val, scan=False):
+ checked_type = infer_type(val)
+ if hasattr(checked_type, "type_annotation"):
+ checked_type = checked_type.type_annotation
+ if hasattr(checked_type, "checked_type"):
+ checked_type = checked_type.checked_type
+ shape = get_const_tuple(checked_type.shape)
+ actual_shape = []
+ for dim in shape:
+ if isinstance(dim, int) and dim == 0:
+ actual_shape.append(_ty.Any())
+ else:
+ actual_shape.append(dim)
+ if scan:
+ return _expr.var(name, shape=[_ty.Any()] + actual_shape,
dtype=checked_type.dtype)
+
+ return _expr.var(name, shape=actual_shape,
dtype=checked_type.dtype)
+
+ # Construct variables and initial empty tensors for any scan outputs.
+ # To do this, we'll figure out the output shapes of the body subgraph
by importing
+ # it and doing type inference.
+ scan_output_vars = []
+ scan_output_init = []
+ if num_scan_outputs > 0:
+ with subgraph_scope:
+ loop_outputs = subgraph_scope.from_onnx(
+ body, graph_scope.opset, get_output_expr=True
+ )
+ loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))
+
+ for i in range(num_scan_outputs):
+ name, _, _, _ = get_info(body.output[i + num_state_outputs])
+ output_node = infer_type(loop_outputs[i + num_state_outputs])
+ shape = list(get_const_tuple(output_node.checked_type.shape))
+ shape.insert(scan_output_axes[i], max_loop_count)
+ dtype = output_node.checked_type.dtype
+ scan_output_vars.append(_expr.var(name, shape=shape, dtype=dtype))
+ scan_output_init.append(_op.zeros(shape, dtype))
+
Review comment:
It seems there is a more elegant method (similar to the existing
implementation of the Loop operator):
scan_output_vars.append(_expr.var(name, shape=[_ty.Any()] * len(shape),
dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const(np.array([]).astype(dtype)),
[0] + [1] * (len(shape) - 1)))
so the following _op.strided_slice could be removed, since scan_output_init
would be empty.
But I could not make it work — could you take a look?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]