lixiaoquan commented on a change in pull request #5963:
URL: https://github.com/apache/incubator-tvm/pull/5963#discussion_r454194463



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1990,6 +1990,66 @@ def _impl(inputs, attr, params, mod):
         return  _res
     return _impl
 
+def _LSTMBlockCell():
+    def _impl(inputs, attr, params, mod):
+        """LSTM Block cell.
+        Calculations and return values are described in:
+        https://github.com/tensorflow/tensorflow/blob/
+        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
+
+        Parameters
+        ----------
+        inputs : relay.Expr
+            Input data
+        in_state_c: list of relay.Expr
+            Cell state input values for all the layers
+        in_state_h: list of relay.Expr
+            Hidden state input values for all the layers
+        attrs : dict
+            Dict of operator attributes
+        params : dict
+            List of pretrained weights and bias
+
+        Returns
+        -------
+        relay.Expr.TupleWapper
+            [dummy, cs, dummy, dummy, dummy, dummy, h]
+            Only cs and h which are useful are returned
+        """
+        in_data = inputs[0]
+        in_state_c = inputs[1]
+        in_state_h = inputs[2]
+        in_weight = inputs[3]
+        in_bias = inputs[7]
+        forget_bias = attr.pop('forget_bias')
+        input_shape = _infer_shape(inputs[0], mod)
+        weight_shape = _infer_shape(inputs[3], mod)
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
+
+        in_data = _op.reshape(in_data,
+                              newshape=(batch_size, input_size))
+        ixh = _op.concatenate([in_data, in_state_h], axis=1)
+        in_weight = _op.transpose(in_weight, axes=None)
+        gates = _op.nn.dense(ixh, in_weight,
+                             units=num_hidden_layers)
+        gates_bias = _op.add(gates, in_bias)
+        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
+        in_gate = _op.sigmoid(gate_list[0])
+        in_transform = _op.tanh(gate_list[1])
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.sigmoid(forget_gate)
+        out_gate = _op.sigmoid(gate_list[3])
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
+                         _op.multiply(in_gate, in_transform))
+        next_h = out_gate * _op.tanh(next_c)
+        # Return dummy for those unused values
+        dummy = tvm.relay.const(0)
+        return tvm.relay.TupleWrapper(
+            tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7)

Review comment:
       cc @yongwww Could you take another look? Thanks.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to