kevinthesun commented on a change in pull request #5103: [Relay][ADT]Static Tensor Array
URL: https://github.com/apache/incubator-tvm/pull/5103#discussion_r401099670
 
 

 ##########
 File path: python/tvm/relay/prelude.py
 ##########
 @@ -27,6 +27,538 @@
 from . import op
 
 
+class StaticTensorArrayOps(object):
+    """Contains tensor array related ops for fixed rank tensor array"""
+
+    def __init__(self, prelude, dtype, shape):
+        """Create tensor array ops registry"""
+        self.prelude = prelude
+        self.dtype = dtype
+        self.shape = shape
+
+    def get_name(self, canonical):
+        """Get name corresponding to the canonical name"""
+        shape_str = str(self.shape).replace('[', '').replace(']', '')\
+            .replace('(', '').replace(')', '').replace(', ', '_')\
+            .replace(',', '')
+        if len(shape_str) == 0:
+            shape_str = "scalar"
+        if canonical == 'tensor_t':
+            return 'static_tensor_{}_{}_t'.format(self.dtype, shape_str)
+        return "{}_{}_{}".format(canonical, self.dtype, shape_str)
+
+    def get_var(self, canonical):
+        """Get var corresponding to the canonical name"""
+        name = self.get_name(canonical)
+        return getattr(self.prelude, name)
+
+    def define_tensor_adt(self):
+        """Defines the dynamic tensor ADT, which is the container for tensors
+        with variable shapes."""
+        tensor_type_name = self.get_name('tensor_t')
+        # Skip registration if the tensor type is already registered.
+        global_type_names = set()
+        for g_ty_var in self.prelude.mod.get_global_type_vars():
+            global_type_names.add(g_ty_var.name_hint)
+        if tensor_type_name in global_type_names:
+            return
+
+        tensor_type_var = GlobalTypeVar(tensor_type_name)
+        setattr(self.prelude, tensor_type_name, tensor_type_var)
+        tensor_type = TensorType(self.shape, self.dtype)
+        tensor_constructor_name = self.get_name('tensor_constructor')
+
+        tensor_nil_name = self.get_name('tensor_nil')
+        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
+        tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)
+
+        setattr(self.prelude, tensor_nil_name, tensor_nil_case)
+        setattr(self.prelude, tensor_constructor_name, tensor_case)
+        self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var,
+                                                     [],
+                                                     [tensor_nil_case, tensor_case])
+
+    def define_tensor_array(self):
+        """Defines a function to create a tensor array with size n.
+        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
+        """
+        tensor_array_constructor_name = self.get_name("tensor_array")
+        tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
+        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
+        tensor_nil_var = self.get_var('tensor_nil')
+        tensor_type_var = self.get_var('tensor_t')
+        n = Var("x", scalar_type('int32'))
+        body = If(equal(n, const(0)),
+                  self.prelude.nil(),
+                  self.prelude.cons(tensor_nil_var(),
+                                    tensor_array_constructor_var(subtract(n, const(1)))))
+        self.prelude.mod[tensor_array_constructor_var] = \
+            Function([n], body, self.prelude.l(tensor_type_var()), [])
+
+    def define_tensor_take(self):
+        """Defines a function to return a range of tensor_t on axis 0.
+            tensor_take(t, lower, upper) :
+            tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
+        """
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        take_name = self.get_name("tensor_take")
+        take_var = self._create_global_var(take_name)
+        setattr(self.prelude, take_name, take_var)
+
+        output_shape = [Any(),] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape(output_shape)
+
+        t = Var('tensor', self.get_var('tensor_t')())
+        lower = Var('lower', scalar_type('int32'))
+        upper = Var('upper', scalar_type('int32'))
+        tvar = Var('t')
+        case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]),
+                      tensor_constructor(op.take(tvar,
+                                                 op.arange(lower, upper, dtype='int32'),
+                                                 axis=0)))
+        self.prelude.mod[take_var] = \
+            Function([t, lower, upper],
+                     Match(t, [case], False), tensor_type_var(), [])
+
+    def define_tensor_concatenate(self):
+        """Defines a function to concatenate two tensor_t on axis 0.
+        tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
+        """
+        ndim = len(self.shape)
+        if ndim == 0:
+            return
+
+        concat_name = self.get_name("tensor_concatenate")
+        concat_var = self._create_global_var(concat_name)
+        setattr(self.prelude, concat_name, concat_var)
+        output_shape = [Any(),] + list(self.shape[1:])
+        tensor_type_var, tensor_constructor = \
+            self._get_adt_by_shape(output_shape)
+
+        origin_tensor_constructor = self.get_var('tensor_constructor')
 
 Review comment:
   The reason is that the input shape and the output shape can differ, so we need a different constructor for each. The take operator uses the same idea.
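
   To make this concrete, here is a minimal sketch (illustrative only, not part of the patch) of how the element types diverge: `tensor_take` slices axis 0 with `op.take` over `op.arange(lower, upper)`, so the first dimension of its result is `Any()`, and the sliced tensors belong to a different static ADT than the input's.

```python
# Illustrative only: why the output of take/concatenate needs a
# constructor from a different ADT than the input's.
from tvm import relay

dtype = 'float32'
input_shape = [3, 4]                              # fixed element shape
output_shape = [relay.Any()] + input_shape[1:]    # axis 0 becomes dynamic

in_ty = relay.TensorType(input_shape, dtype)      # Tensor[(3, 4), float32]
out_ty = relay.TensorType(output_shape, dtype)    # Tensor[(?, 4), float32]

# The two element types differ, so the result cannot be wrapped with the
# input ADT's tensor_constructor; _get_adt_by_shape(output_shape) supplies
# the constructor of the dynamic-first-dimension ADT instead.
```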

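   For readers following along, a hypothetical usage sketch of the ops defined in this patch (names follow the `get_name` scheme above; the exact registration workflow may differ once merged):

```python
# Hypothetical usage sketch, assuming this patch is applied.
import tvm
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

mod = tvm.IRModule()
prelude = Prelude(mod)

# One registry per (dtype, shape); generated names carry both as a suffix,
# e.g. tensor_t becomes static_tensor_float32_3_4_t here.
ops = StaticTensorArrayOps(prelude, 'float32', (3, 4))
ops.define_tensor_adt()      # the tensor_t container type
ops.define_tensor_array()    # tensor_array(n): array of n nil tensors
ops.define_tensor_take()     # tensor_take(t, lower, upper) along axis 0
```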