This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 87366b56ed Oneflow frontend support more model and fix bug (#11321)
87366b56ed is described below
commit 87366b56ed25456c2d1984183e9fa28e6958f93e
Author: Xiaoyu Zhang <[email protected]>
AuthorDate: Sun May 15 03:50:25 2022 +0800
Oneflow frontend support more model and fix bug (#11321)
* add relay.f.frontend.fm_oneflow support cnns
* support cuda
* fix mobilenetv2 and reviews
* fix: model without meta info
* support eager and yolo, add test
* fix: license
* add: tutorials
* fix: support new graph
* fix some comments
* refine
* fix concat op convert bug
* refine
* refine
* change cuda to cpu
* fix bug
* fix ci error in tvm
* fix pylint check
* delete useless file
* add skimage package in docker
* fix ci error
* fix bug
* add oneflow frontend test in ci
* merge conflict
* fix tutorial
* try to find error in ci
* revert
* merge conflict
* black oneflow
* Delete from_oneflow.py
* restructure oneflow frontend
* support vision-transformer
* black format
* update black version and reformat
* fix ci error
* fix doc error
* fix gpu frontend test failed
Co-authored-by: hhhfccz <[email protected]>
---
python/tvm/relay/frontend/oneflow.py | 418 ++++++++++++++-------
tests/python/frontend/oneflow/test_forward.py | 199 ++++++++++
.../python/frontend/oneflow/test_vision_models.py | 150 ++++++++
3 files changed, 630 insertions(+), 137 deletions(-)
diff --git a/python/tvm/relay/frontend/oneflow.py
b/python/tvm/relay/frontend/oneflow.py
index a1a7d513f8..ff4b5a5bcc 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -21,7 +21,7 @@
import os
import re
import copy
-import warnings
+from collections import OrderedDict
import numpy as np
import tvm
@@ -38,7 +38,6 @@ from .common import (
Renamer,
fold_constant,
get_relay_op,
- infer_channels,
infer_shape,
infer_type,
new_var,
@@ -97,7 +96,6 @@ def _dtype_shape_promotion(inputs):
"""Promote data type and shape for list of tensors."""
dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32",
"float64"]
-
ranks = [len(infer_shape(x)) for x in inputs]
if set(ranks) == set([1, 0]):
for i, r in enumerate(ranks):
@@ -497,19 +495,26 @@ class Flatten(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- axis = attrs.get("axis", 1)
- ishape = _op.shape_of(inputs[0])
- ndim = infer_shape(ishape)[0]
- if axis < 0:
- axis = axis + ndim
-
- if axis == 1:
- out = _op.nn.batch_flatten(inputs[0])
- else:
- pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]),
keepdims=True)
- post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim],
[1]), keepdims=True)
- newshape = _op.concatenate([pre_shape, post_shape], axis=0)
- out = _op.reshape(inputs[0], newshape)
+ x = inputs[0]
+ input_shape = list(infer_shape(x))
+
+ start = attrs["start_dim"]
+ end = attrs["end_dim"]
+ ndim = len(input_shape)
+ if end < 0:
+ end += ndim
+ new_shape = [0] * start
+
+ new_shape.append(-1)
+ squeeze_axes = []
+ for i in range(start + 1, end + 1):
+ new_shape.append(1)
+ squeeze_axes.append(i)
+ for _ in range(end + 1, ndim):
+ new_shape.append(0)
+ out = _op.reshape(x, new_shape)
+ if squeeze_axes:
+ out = _op.squeeze(out, axis=squeeze_axes)
return out
@@ -518,36 +523,119 @@ class MatMul(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- assert len(inputs) == 2, "Gemm op take 2 inputs, {}
given".format(len(inputs))
- # Similar to 'class Conv'
- true_names = ["weight"]
- false_names = ["_input."]
- for i in inputs:
- T_NAMES = any(x in str(i) for x in true_names)
- F_NAMES = any(x in str(i) for x in false_names)
- if T_NAMES and not F_NAMES:
- matmul_b = i
- else:
- matmul_a = i
-
- dtype = infer_type(matmul_a).checked_type.dtype
+ assert len(inputs) == 2, "MatMul op take 2 inputs, {}
given".format(len(inputs))
+ dtype = infer_type(inputs[0]).checked_type.dtype
# Y = alpha * A * B
alpha = float(attrs.get("alpha", 1.0))
transA = bool(attrs.get("transpose_a", False))
transB = bool(attrs.get("transpose_b", False))
- # get number of channels
- channels = infer_channels(matmul_b, not transB)
- if transA:
- matmul_a = _op.transpose(matmul_a, axes=(1, 0))
- if not transB:
- matmul_b = _op.transpose(matmul_b, axes=(1, 0))
- matmul_a = _op.nn.batch_flatten(matmul_a)
- if alpha != 1.0:
- matmul_a *= _expr.const(alpha, dtype=dtype)
+ a_shape = infer_shape(inputs[0])
+ b_shape = infer_shape(inputs[1])
+ if (
+ (transA and transB and a_shape[-2] != b_shape[-1])
+ or (transA and not transB and a_shape[-2] != b_shape[-2])
+ or (transB and not transA and a_shape[-1] != b_shape[-1])
+ or (not transB and not transA and a_shape[-1] != b_shape[-2])
+ ):
+ matmul_a = inputs[1]
+ matmul_b = inputs[0]
+ else:
+ matmul_a = inputs[0]
+ matmul_b = inputs[1]
- return _op.nn.dense(matmul_a, matmul_b, units=channels)
+ if transA:
+ perm = list(range(len(a_shape)))
+ perm[-2] = len(a_shape) - 1
+ perm[-1] = len(a_shape) - 2
+ matmul_a = _op.transpose(matmul_a, axes=perm)
+ if transB:
+ perm = list(range(len(b_shape)))
+ perm[-2] = len(b_shape) - 1
+ perm[-1] = len(b_shape) - 2
+ matmul_b = _op.transpose(matmul_b, axes=perm)
+
+ # This implementation is almost the same as the ONNX frontend's
+ # Need to check input shape as batch matmul must be supported.
+ a_shape = shape_of(matmul_a, dtype="int32")
+ a_rank = infer_shape(a_shape)[0]
+ b_shape = shape_of(matmul_b, dtype="int32")
+ b_rank = infer_shape(b_shape)[0]
+ # When performing a batch matmul, we need to properly handle N-dim
shapes.
+ if a_rank > 2 or b_rank > 2:
+
+ def flatten_to_nd(x, x_shape, nd=3):
+ ndims = infer_shape(x_shape)[0]
+ if ndims == nd:
+ return x
+ newshape = _op.concatenate(
+ [
+ _expr.const([-1],
dtype=infer_type(x_shape).checked_type.dtype),
+ _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+ ],
+ 0,
+ )
+ out = _op.reshape(x, fold_constant(newshape))
+ return out
+
+ b_type = infer_type(matmul_b)
+ # Convert to dense if the second matrix is 2d and non-dynamic
+ if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+ a = flatten_to_nd(matmul_a, a_shape, 2)
+ b = _op.transpose(matmul_b)
+ output = _op.nn.dense(a, b)
+ else:
+ # Convert a and b into 3 dimensional tensors.
+ a = flatten_to_nd(matmul_a, a_shape, 3)
+ b = flatten_to_nd(matmul_b, b_shape, 3)
+ # Transpose matrix dimensions of b.
+ b = _op.transpose(b, [0, 2, 1])
+ # Perform a batch matmul.
+ output = _op.nn.batch_matmul(a, b)
+ # Determine the output batch dimension.
+ if a_rank > b_rank:
+ out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+ elif a_rank < b_rank:
+ out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+ # If it's unclear how broadcasting should be applied, the output
+ # shape is determined by choosing the maximum value from each
input.
+ else:
+ out_batch = _op.concatenate(
+ [
+ _op.maximum(
+ _op.strided_slice(a_shape, [i], [i + 1]),
+ _op.strided_slice(b_shape, [i], [i + 1]),
+ )
+ for i in range(a_rank - 2)
+ ],
+ 0,
+ )
+ # Reshape output to original dimensions.
+ final_shape = _op.concatenate(
+ [
+ out_batch,
+ _op.strided_slice(
+ a_shape, [infer_shape(a_shape)[0] - 2],
[infer_shape(a_shape)[0] - 1]
+ ),
+ _op.strided_slice(
+ b_shape, [infer_shape(b_shape)[0] - 1],
[infer_shape(b_shape)[0]]
+ ),
+ ],
+ 0,
+ )
+ out = _op.reshape(output, fold_constant(final_shape))
+ else:
+ if b_rank == 1:
+ matmul_b = _op.expand_dims(matmul_b, 1, 1)
+ # Otherwise a simple dense op will get the job done.
+ input_1_t = _op.transpose(matmul_b, axes=(1, 0))
+ out = _op.nn.dense(matmul_a, input_1_t)
+ if b_rank == 1:
+ out = _op.squeeze(out, axis=[-1])
+ if not np.isclose(alpha, 1.0):
+ out = out * _expr.const(alpha, dtype=dtype)
+ return out
class Reduce(OneFlowOpConverter):
@@ -635,15 +723,34 @@ class Expand(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- input_shape = infer_shape(inputs[0])
- assert input_shape == attrs["in_shape"], "shape wrong"
-
- new_shape = attrs["out_shape"]
- out = _op.broadcast_to(inputs[0], shape=new_shape)
+ data_in = inputs[0]
+ shape = list(infer_shape(data_in))
+
+ ndims = len(shape)
+ sizes = attrs["logical_expand_shape"]
+ out = data_in
+ out_dims = len(sizes)
+ if ndims < out_dims:
+ num_newaxis = out_dims - ndims
+ out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
+ shape = [1] * num_newaxis + shape
+
+ for i in range(out_dims):
+ if sizes[i] != -1 and shape[i] == 1:
+ out = _op.repeat(out, sizes[i], axis=i)
return out
+class Transpose(OneFlowOpConverter):
+ """Operator converter for transpose."""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attrs, params):
+ perm = attrs["perm"]
+ return _op.transpose(inputs[0], axes=perm)
+
+
class ExpandDim(OneFlowOpConverter):
"""Operator converter for ExpandDim"""
@@ -718,12 +825,25 @@ class BroadcastDiv(BroadcastMath):
name = "divide"
-class Greater(OneFlowOpConverter):
+class LogicalGreater(OneFlowOpConverter):
"""Operator converter for greater"""
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- return _op.greater(inputs[0], inputs[1])
+ res = None
+ if attrs.get("has_int_operand", True):
+ value = attrs.get("int_operand", 0.0)
+ res = _op.greater(inputs[0], _op.full_like(inputs[0],
fill_value=_expr.const(value)))
+ elif attrs.get("has_float_operand", True):
+ value = float(attrs.get("float_operand", 0.0))
+ res = _op.greater(
+ inputs[0], _op.full_like(inputs[0],
fill_value=_expr.const(value)).astype("float32")
+ )
+ else:
+ raise AttributeError(
+ "please check if has_int_operand or has_float_operand in your
attrs"
+ )
+ return res
class Log1p(OneFlowOpConverter):
@@ -734,6 +854,15 @@ class Log1p(OneFlowOpConverter):
return _op.log(inputs[0] + _expr.const(1.0))
+class Pow(OneFlowOpConverter):
+ """Operator converter for Power"""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attrs, params):
+ inputs = _dtype_shape_promotion(inputs)
+ return get_relay_op(cls.name)(inputs[0], inputs[1])
+
+
class Expm1(OneFlowOpConverter):
"""Operator converter for Expm1"""
@@ -812,14 +941,35 @@ class ScalarMul(OneFlowOpConverter):
return res
+class ScalarDiv(OneFlowOpConverter):
+ """Operator convert for Div_scalar"""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attrs, params):
+ assert len(inputs) == 1, "div_scalar take == 1 inputs, but {}
given.".format(len(inputs))
+
+ if attrs.get("has_int_operand", True):
+ res = inputs[0] / _expr.const(attrs["int_operand"],
dtype="float32")
+ elif attrs.get("has_float_operand", True):
+ res = inputs[0] / _expr.const(attrs["float_operand"])
+ else:
+ raise AttributeError(
+ "please check if has_int_operand or has_float_operand in your
attrs"
+ )
+
+ return res
+
+
class ScalarPow(OneFlowOpConverter):
"""Operator convert for Pow_scalar"""
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- exponent = attrs.get("exponent", 1.0)
- exponent = _expr.const(exponent, dtype="float32")
- return _op.power(inputs[0], exponent)
+ if attrs.get("has_int_operand", True):
+ coeff = _expr.const(attrs["int_operand"])
+ elif attrs.get("has_float_operand", True):
+ coeff = _expr.const(attrs["float_operand"])
+ return _op.power(inputs[0], coeff)
class MaxPool2d(Pool):
@@ -857,15 +1007,12 @@ class Softmax(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
- axis = attrs.get("axis", 1)
- ndim = len(infer_shape(inputs[0]))
- if axis < 0:
- axis += ndim
- axes = list(range(axis, ndim))
- x = inputs[0]
- m = _op.max(x, axes, keepdims=True)
- e = _op.exp(x - m)
- return e / _op.sum(e, axes, keepdims=True)
+ axis = attrs.get("axis", -1)
+ data = inputs[0]
+ if isinstance(axis, str):
+ axis = int(axis)
+
+ return _op.nn.softmax(data, axis=axis)
class LogSoftmax(OneFlowOpConverter):
@@ -1000,6 +1147,17 @@ class Softsign(OneFlowOpConverter):
return inputs[0] / (_expr.const(1.0) +
Absolute.get_converter()(inputs, attrs, params))
+class Variance(OneFlowOpConverter):
+ """Operator converter for Variance"""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attrs, params):
+ axis = attrs["dim"]
+ keepdims = attrs["keepdim"]
+ unbiased = bool(attrs["unbiased"])
+ return _op.reduce.variance(inputs[0], axis=axis, keepdims=keepdims,
unbiased=unbiased)
+
+
class Concat(OneFlowOpConverter):
"""Operator converter for Concat"""
@@ -1234,6 +1392,7 @@ def get_convert_map():
"bias_add": Add.get_converter(),
"scalar_add": ScalarAdd.get_converter(),
"scalar_mul": ScalarMul.get_converter(),
+ "scalar_div": ScalarDiv.get_converter(),
"scalar_pow": ScalarPow.get_converter(),
"reduce_sum": ReduceSum.get_converter(),
"reduce_max": ReduceMax.get_converter(),
@@ -1243,7 +1402,7 @@ def get_convert_map():
"broadcast_mul": BroadcastMul.get_converter(),
"broadcast_sub": BroadcastSub.get_converter(),
"broadcast_div": BroadcastDiv.get_converter(),
- "broadcast_greater": Greater.get_converter(),
+ "scalar_logical_greater": LogicalGreater.get_converter(),
"log": Renamer("log"),
"log1p": Log1p.get_converter(),
"acos": Renamer("acos"),
@@ -1258,7 +1417,7 @@ def get_convert_map():
"sinh": Renamer("sinh"),
"tan": Renamer("tan"),
"tanh": Renamer("tanh"),
- "pow": Renamer("power"),
+ "pow": Pow.get_converter(),
"exp": Renamer("exp"),
"expm1": Expm1.get_converter(),
"floor": Renamer("floor"),
@@ -1271,7 +1430,7 @@ def get_convert_map():
"sign": Sign.get_converter(),
"erf": Erf.get_converter(),
"erfc": Erfc.get_converter(),
- "reciprocal_no_nan": Reciprocal.get_converter(),
+ "reciprocal": Reciprocal.get_converter(),
# defs/activation
"softmax": Softmax.get_converter(),
"softsign": Softsign.get_converter(),
@@ -1295,24 +1454,29 @@ def get_convert_map():
"upsample_bilinear_2d": UpsampleBiLinear.get_converter(),
# defs/tensor
"matmul": MatMul.get_converter(),
+ "batch_matmul": MatMul.get_converter(),
+ "broadcast_matmul": MatMul.get_converter(),
"concat": Concat.get_converter(),
"clip_by_scalar": Clip.get_converter(),
"slice": Slice.get_converter(),
"expand": Expand.get_converter(),
- "transpose": AttrCvt("transpose", {"perm": "axes"}),
+ "transpose": Transpose.get_converter(),
"expand_dims": ExpandDim.get_converter(),
"range": Range.get_converter(),
"cast": Cast.get_converter(),
# defs/others
"reshape": Reshape.get_converter(),
"constant": Constant.get_converter(),
- # "where": Where.get_converter(),
+ "where": Where.get_converter(),
"flatten": Flatten.get_converter(),
"sigmoid": Renamer("sigmoid"),
"sigmoid_v2": Renamer("sigmoid"),
"hardsigmoid": HardSigmoid.get_converter(),
+ "softplus": Softplus.get_converter(),
"squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"unsqueeze": Unsqueeze.get_converter(),
+ "identity": Renamer("copy"),
+ "var": Variance.get_converter(),
}
@@ -1402,7 +1566,7 @@ def deal_parameter_convert(
):
"""deal with parameter(weight) convert in oneflow."""
for node_input_path in node_input_paths:
- node_path = os.path.join(model_dir_path, node_input_path.replace("m.",
""))
+ node_path = os.path.join(model_dir_path, node_input_path.replace("m.",
"", 1))
node_input_name = node_input_path.split("/")[0]
_input_path_2_name[node_path] = node_input_name
for param_name in _model_array:
@@ -1503,7 +1667,11 @@ class OneflowGraph(object):
print("{} should be defined by
user".format(self._init_variable_node))
def _parse_input(self, node, model_dir_path):
+ input_user_conf_list = []
for input_name in node.user_conf.input:
+ input_user_conf_list.append(input_name)
+ input_user_conf_list.sort()
+ for input_name in input_user_conf_list:
node_input_paths = getattr(node.user_conf.input[input_name], "s")
for i in node_input_paths:
node_input = i.split("/")[0]
@@ -1548,58 +1716,11 @@ class OneflowGraph(object):
return outputs
- def from_oneflow(self, nodes, model_dir_path, freeze_params=True,
user_input=None):
+ def from_oneflow(self, nodes, model_dir_path):
"""
- Parameters
- ----------
- nodes : dict, keys: node.name, value: node
- contain the graph
- model_dir_path: str
- The path of parameter
- freeze_params: bool
- If freeze_params is True,
- the computational graph input is the input of the first layer of
the network,
- which cannot be specified by the user, e.g.
- Default input is: %v_ResNetGraph_0_input.0: Tensor[(1, 3, 224,
224), float32]
- User-defined input is: %_0_input.0: Tensor[(1, 3, 640, 480),
float32]
- If freeze_params is on, then conv1-in will be the graph input, not
Input_0
- user_input: dict
- User-defined input information for the graph
- {
- node1_name:
- {
- 'name': node1_name, # str, like
"%v_ResNetGraph_0_input.0"
- 'shape': node1_shape, # tuple
- 'dtype': node1_dtype # str, like "float32"
- }
- ...
- }
- We recommend that users specify the input by specifying the job
function,
- rather than by this function
-
- Returns
- -------
- mod : tvm.IRModule
- The returned relay module
- params : dict
- A dict of name: tvm.nd.array pairs, used as pretrained weights
+ Implementation of converting the OneFlow model into an equivalent Relay
Function.
"""
- # step 1: get the graph input
- if not freeze_params:
- for node_init_name in user_input:
- if "_input." not in node_init_name:
- raise KeyError(
- "user_input['name'] should contain '_input.' "
- + "to let program know that this is input node"
- )
- self._nodes[node_init_name] = new_var(
- node_init_name,
- shape=user_input[node_init_name]["shape"],
- dtype=user_input[node_init_name]["dtype"],
- )
- self._inputs[node_init_name] = self._nodes[node_init_name]
-
- # step 2: find out if unsupported ops are used
+ # step 1: find out if unsupported ops are used
convert_map = get_convert_map()
unsupported_ops = set()
for node_name in nodes:
@@ -1619,7 +1740,7 @@ class OneflowGraph(object):
msg += ", ".join(unsupported_ops)
raise tvm.error.OpNotImplemented(msg)
- # step 3: convert op
+ # step 2: convert op
for node_name in nodes:
node = nodes[node_name]
if is_user_op(node):
@@ -1633,7 +1754,11 @@ class OneflowGraph(object):
self._parse_input(node, model_dir_path=model_dir_path)
node_inputs = oneflow_input()
+ input_user_conf_list = []
for input_name in node.user_conf.input:
+ input_user_conf_list.append(input_name)
+ input_user_conf_list.sort()
+ for input_name in input_user_conf_list:
node_input_paths =
getattr(node.user_conf.input[input_name], "s")
for i in node_input_paths:
node_input = i.split("/")[0]
@@ -1663,7 +1788,6 @@ class OneflowGraph(object):
), "Number of output mismatch {} vs {} in {}.".format(
len(node_outputs), outputs_num, op_name
)
-
if outputs_num == 1:
op = fold_constant(op)
else:
@@ -1678,10 +1802,9 @@ class OneflowGraph(object):
else:
self._nodes[node_outputs[i]] = op_temp[i]
- # step 4: get the outputs
+ # step 3: get the outputs
outputs = []
- for node_name in nodes:
- node = nodes[node_name]
+ for node_name, node in nodes.items():
if is_output_op(node):
node_name_v2 = getattr(node.output_conf, "in").split("/")[0]
if node_name in self._nodes:
@@ -1690,13 +1813,21 @@ class OneflowGraph(object):
outputs.append(self._nodes[node_name_v2])
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
- # step 5: get the relay IR
+ # step 4: get the relay IR
free_vars = analysis.free_vars(outputs)
nodes = {v: k for k, v in self._nodes.items()}
free_vars = [nodes[var] for var in free_vars]
+ free_vars_inputs = []
+ free_vars_parameters = []
+ for x in free_vars:
+ if "_input.0" in x:
+ free_vars_inputs.append(x)
+ else:
+ free_vars_parameters.append(x)
+ free_vars = free_vars_inputs + free_vars_parameters
- # step 6: make sure the '_input.0' is the first in self._inputs
+ # step 5: make sure the '_input.0' is the first in self._inputs
for free_var in free_vars:
if free_var not in self._inputs:
self._inputs[free_var] = self._nodes[free_var]
@@ -1708,7 +1839,7 @@ class OneflowGraph(object):
else:
raise IndexError("{} is not in
self._inputs".format(input_name))
- # step 7: create a function from our output expression and all input
variables.
+ # step 6: create a function from our output expression and all input
variables.
func = _function.Function([v for _, v in self._sort_inputs.items()],
outputs)
return IRModule.from_expr(func), self._params
@@ -1740,20 +1871,38 @@ class OneflowGraph(object):
return sym
-def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None):
- """
- see OneflowGraph.from_oneflow
+def from_oneflow(graph, model_dir_path):
+ """Convert a OneFlow model into an equivalent Relay Function.
+
+ At present, there are two ways to run models in deep learning framework
+ Dynamic Graph and Static Graph, which are also called Eager Mode and Graph
+ Mode in OneFlow.
+
+ In general, dynamic graphs are easier to use and static graphs have better
performance.
+ OneFlow offers nn.Graph, so that users can use the eager-like programming
style to build
+ static graphs and train the models.
+
+ We utilize the intermediate representation of nn.Graph to convert the
OneFlow model to Relay.
+
+ Parameters
+ ----------
+ graph : flow.nn.Graph
+ The OneFlow nn.Graph to convert; its graph proto supplies the nodes
+ model_dir_path: str
+ The path of weight
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ The returned relay module
+ params : dict
+ A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
try:
import oneflow as flow
except ImportError:
raise ImportError("please check that OneFlow is installed")
- if not freeze_params and user_input is None:
- raise ValueError("if you want to specify graph input, please give the
'user_input'")
- if freeze_params and user_input is not None:
- warnings.warn("'user_input' will not work, please check the
'freeze_params'")
-
# get info of nodes
shape = {}
dtype = {}
@@ -1800,18 +1949,13 @@ def from_oneflow(graph, model_dir_path,
freeze_params=True, user_input=None):
graph_proto = graph._graph_proto
# get all nodes
- nodes = {}
+ nodes = OrderedDict()
for op in graph_proto.net.op:
nodes[op.name] = op
g = OneflowGraph(shape, dtype, nodes, model_dir_path)
# Use the graph proto as a scope so that ops can access other nodes if
needed.
- mod, params = g.from_oneflow(
- nodes=nodes,
- model_dir_path=model_dir_path,
- freeze_params=freeze_params,
- user_input=user_input,
- )
+ mod, params = g.from_oneflow(nodes=nodes, model_dir_path=model_dir_path)
return mod, params
diff --git a/tests/python/frontend/oneflow/test_forward.py
b/tests/python/frontend/oneflow/test_forward.py
index d144cdad2b..0d18a2fb5c 100644
--- a/tests/python/frontend/oneflow/test_forward.py
+++ b/tests/python/frontend/oneflow/test_forward.py
@@ -79,6 +79,16 @@ class OneFlowGraph_v2(flow.nn.Graph):
return out
+class OneFlowGraph_v3(flow.nn.Graph):
+ def __init__(self, module):
+ super().__init__()
+ self.m = module
+
+ def build(self, x1, x2):
+ out = self.m(x1, x2)
+ return out
+
+
def get_oneflow_output(model, inputs):
flow_output = model(inputs)
return flow_output.numpy()
@@ -89,6 +99,10 @@ def get_oneflow_concat_output(model, input1, input2, input3):
return flow_output
+def get_oneflow_elementwise_output(model, input1, input2):
+ return model(input1, input2).numpy()
+
+
def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm",
dtype="float32"):
inputs_numpy = inputs.numpy()
if target == "llvm":
@@ -132,6 +146,32 @@ def get_tvm_concat_output(
return tvm_output
+def get_tvm_elementwise_output(
+ graph,
+ model_path,
+ input1: flow.tensor,
+ input2: flow.tensor,
+ target="llvm",
+ dtype="float32",
+):
+ input1_numpy = input1.numpy()
+ input2_numpy = input2.numpy()
+ if target == "llvm":
+ device = tvm.cpu(0)
+ elif target == "cuda":
+ device = tvm.cuda(0)
+
+ mod, params = relay.frontend.from_oneflow(graph, model_path)
+ with tvm.transform.PassContext(opt_level=10):
+ intrp = relay.build_module.create_executor("graph", mod, device,
target)
+ tvm_output = intrp.evaluate()(
+ tvm.nd.array(input1_numpy.astype(dtype)),
+ tvm.nd.array(input2_numpy.astype(dtype)),
+ **params,
+ ).numpy()
+ return tvm_output
+
+
def verify_conv(
model,
name="",
@@ -336,6 +376,33 @@ def verify_math(
tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+def verify_matmul(
+ model,
+ name="",
+ rtol=1e-5,
+ atol=1e-5,
+ inputs1=flow.tensor(np.random.randn(2, 5), dtype=flow.float32),
+ inputs2=flow.tensor(np.random.randn(5, 2), dtype=flow.float32),
+ device="llvm",
+):
+ if device == "cuda":
+ model.to(device)
+ inputs1 = inputs1.to(device)
+ inputs2 = inputs2.to(device)
+
+ graph = OneFlowGraph_v3(model)
+ graph._compile(inputs1, inputs2)
+ mkdir(MODEL_HOME)
+ flow.save(model.state_dict(), MODEL_HOME)
+
+ out_flow = get_oneflow_elementwise_output(graph, inputs1, inputs2)
+ out_tvm = get_tvm_elementwise_output(graph, MODEL_HOME, inputs1, inputs2,
target=device)
+ rmdir(MODEL_HOME)
+
+ assert_shape(out_flow, out_tvm)
+ tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
def verify_concat(
model,
name="",
@@ -602,6 +669,23 @@ def test_activation():
x = self.active(x)
return x
+ class HardTanh(flow.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.active = flow.nn.Hardtanh()
+
+ def forward(self, x):
+ x = self.active(x)
+ return x
+
+ class TensorSoftmax(flow.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ x = x.softmax(dim=-1)
+ return x
+
if os.path.exists(MODEL_HOME):
rmdir(MODEL_HOME)
@@ -616,6 +700,8 @@ def test_activation():
model9 = SiLU().eval()
model10 = LeakyReLU().eval()
model11 = GELU().eval()
+ model12 = HardTanh().eval()
+ model13 = TensorSoftmax().eval()
for device in ["llvm"]:
verify_activation(model1, device=device)
@@ -629,6 +715,12 @@ def test_activation():
verify_activation(model9, device=device)
verify_activation(model10, device=device)
verify_activation(model11, device=device)
+ verify_activation(model12, device=device)
+ verify_activation(
+ model13,
+ device=device,
+ inputs=flow.tensor(np.random.rand(1, 12, 197,
197).astype(np.float32)),
+ )
@tvm.testing.uses_gpu
@@ -665,12 +757,19 @@ def test_math():
def forward(self, x):
return flow.expm1(x)
+ class Variance(flow.nn.Module):
+ def forward(self, x):
+ return flow.var(x, 1, unbiased=False, keepdim=True)
+
model1 = Sigmoid().eval()
model2 = Sign().eval()
model3 = Log().eval()
model4 = Log2().eval()
model5 = Exp().eval()
model6 = Exp2().eval()
+ model7 = Reciprocal().eval()
+ model8 = Pow().eval()
+ model9 = Variance().eval()
for device in ["llvm"]:
verify_math(model1, device=device)
@@ -679,6 +778,9 @@ def test_math():
verify_math(model4, device=device)
verify_math(model5, device=device)
verify_math(model6, device=device)
+ verify_math(model7, device=device)
+ verify_math(model8, device=device)
+ verify_math(model9, device=device)
@tvm.testing.uses_gpu
@@ -710,6 +812,99 @@ def test_concat():
verify_concat(model, device=device)
[email protected]_gpu
+def test_add_constant():
+ class ConstantAdd(flow.nn.Module):
+ def forward(self, x):
+ out = flow.add(1.0, x)
+ return out
+
+ model = ConstantAdd().eval()
+
+ for device in ["llvm"]:
+ verify_math(
+ model, device=device, inputs=flow.tensor(np.random.randn(3, 6,
9).astype(np.float32))
+ )
+
+
[email protected]_gpu
+def test_logical():
+ class LogicalGreater(flow.nn.Module):
+ def forward(self, x):
+ return x > 1.0
+
+ model1 = LogicalGreater().eval()
+
+ for device in ["llvm"]:
+ verify_math(
+ model1, device=device, inputs=flow.tensor(np.random.randn(3, 6,
9).astype(np.float32))
+ )
+
+
[email protected]_gpu
+def test_expand():
+ class Expand(flow.nn.Module):
+ def forward(self, x):
+ return x.expand(2, -1, -1)
+
+ model1 = Expand().eval()
+
+ for device in ["llvm"]:
+ verify_math(
+ model1, device=device, inputs=flow.tensor(np.random.randn(1, 6,
9).astype(np.float32))
+ )
+
+
[email protected]_gpu
+def test_matmul():
+ class MatMul(flow.nn.Module):
+ def forward(self, x, y):
+ return flow._C.matmul(x, y)
+
+ class MatMulTranspose(flow.nn.Module):
+ def forward(self, x, y):
+ return flow._C.matmul(x, y, transpose_b=True)
+
+ class BatchMatMul(flow.nn.Module):
+ def forward(self, x, y):
+ return flow._C.batch_matmul(x, y)
+
+ class BroadCastMatMul(flow.nn.Module):
+ def forward(self, x, y):
+ return flow._C.matmul(x, y)
+
+ model1 = MatMul().eval()
+ model2 = MatMulTranspose().eval()
+ model3 = BatchMatMul().eval()
+ model4 = BroadCastMatMul().eval()
+
+ for device in ["llvm"]:
+ verify_matmul(
+ model1,
+ device=device,
+ inputs1=flow.tensor(np.random.randn(2, 3).astype(np.float32)),
+ inputs2=flow.tensor(np.random.randn(3, 3).astype(np.float32)),
+ )
+ verify_matmul(
+ model2,
+ device=device,
+ inputs1=flow.tensor(np.random.randn(1, 2).astype(np.float32)),
+ inputs2=flow.tensor(np.random.randn(3, 2).astype(np.float32)),
+ )
+ verify_matmul(
+ model3,
+ device=device,
+ inputs1=flow.tensor(np.random.randn(2, 1, 2).astype(np.float32)),
+ inputs2=flow.tensor(np.random.randn(2, 2, 3).astype(np.float32)),
+ )
+ verify_matmul(
+ model4,
+ device=device,
+ inputs1=flow.tensor(np.random.randn(3, 8, 8,
16).astype(np.float32)),
+ inputs2=flow.tensor(np.random.randn(16, 8).astype(np.float32)),
+ )
+
+
if __name__ == "__main__":
test_conv2d()
test_pool2d()
@@ -720,4 +915,8 @@ if __name__ == "__main__":
test_math()
test_slice()
test_concat()
+ test_add_constant()
+ test_logical()
+ test_expand()
+ test_matmul()
rmdir("log")
diff --git a/tests/python/frontend/oneflow/test_vision_models.py
b/tests/python/frontend/oneflow/test_vision_models.py
new file mode 100644
index 0000000000..e8d0627001
--- /dev/null
+++ b/tests/python/frontend/oneflow/test_vision_models.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name
+# pylint: disable=arguments-differ, unused-argument, unused-import
+"""Unit tests for various models and operators"""
+import os
+import sys
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relay
+from tvm.contrib import graph_executor
+
+import oneflow as flow
+from flowvision.models.alexnet import alexnet
+from flowvision.models.squeezenet import squeezenet1_0
+from flowvision.models.shufflenet_v2 import shufflenet_v2_x0_5
+from flowvision.models.mobilenet import mobilenet_v2
+from flowvision.models.ghostnet import ghostnet
+from flowvision.models.vision_transformer import vit_base_patch16_224
+
+MODEL_HOME = "test_model"
+
+
+def mkdir(path):
+ # init
+ path = path.strip()
+ path = path.rstrip("\\")
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+ else:
+ print("{} is already here".format(path))
+
+
+def rmdir(path):
+ for root, dirs, files in os.walk(path, topdown=False):
+ for name in files:
+ os.remove(os.path.join(root, name))
+ for name in dirs:
+ os.rmdir(os.path.join(root, name))
+ os.removedirs(path)
+
+
+def assert_shape(out1, out2):
+ if out1.shape != out2.shape:
+ msg = "Output shapes {} and {} don't match"
+ raise AssertionError(msg.format(out1.shape, out2.shape))
+
+
+class OneFlowGraph(flow.nn.Graph):
+ def __init__(self, module):
+ super().__init__()
+ self.m = module
+
+ def build(self, x):
+ out = self.m(x)
+ return out
+
+
+def get_oneflow_output(model, inputs):
+ flow_output = model(inputs)
+ return flow_output.numpy()
+
+
+def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm",
dtype="float32"):
+ inputs_numpy = inputs.numpy()
+ if target == "llvm":
+ device = tvm.cpu(0)
+ elif target == "cuda":
+ device = tvm.cuda(0)
+
+ mod, params = relay.frontend.from_oneflow(graph, model_path)
+ with tvm.transform.PassContext(opt_level=10):
+ intrp = relay.build_module.create_executor("graph", mod, device,
target)
+ tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)),
**params).numpy()
+ return tvm_output
+
+
+def verify_model(
+ model,
+ name="",
+ rtol=1e-5,
+ atol=1e-5,
+ inputs=flow.tensor(
+ np.random.rand(1, 3, 224, 224),
+ dtype=flow.float32,
+ ),
+ device="llvm",
+):
+ if device == "cuda":
+ model.to(device)
+ inputs = inputs.to(device)
+
+ graph = OneFlowGraph(model)
+ graph._compile(inputs)
+
+ mkdir(MODEL_HOME)
+ flow.save(model.state_dict(), MODEL_HOME)
+
+ out_flow = get_oneflow_output(graph, inputs)
+ out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+ rmdir(MODEL_HOME)
+
+ assert_shape(out_flow, out_tvm)
+ tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
[email protected]_gpu
+def test_vision_models():
+
+ if os.path.exists(MODEL_HOME):
+ rmdir(MODEL_HOME)
+
+ vision_alexnet = alexnet().eval()
+ vision_squeezenet = squeezenet1_0().eval()
+ vision_shufflenet = shufflenet_v2_x0_5().eval()
+ vision_mobilenetv2 = mobilenet_v2().eval()
+ vision_ghostnet = ghostnet().eval()
+ vision_vit = vit_base_patch16_224().eval()
+
+ for device in ["llvm"]:
+ verify_model(vision_alexnet, device=device)
+ verify_model(vision_squeezenet, device=device)
+ verify_model(vision_shufflenet, device=device)
+ verify_model(vision_mobilenetv2, device=device)
+ verify_model(vision_ghostnet, device=device)
+ verify_model(vision_vit, device=device)
+
+
+if __name__ == "__main__":
+ test_vision_models()
+ rmdir("log")