yongwww commented on a change in pull request #5963:
URL: https://github.com/apache/incubator-tvm/pull/5963#discussion_r451332088



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1990,6 +1990,66 @@ def _impl(inputs, attr, params, mod):
         return  _res
     return _impl
 
+def _LSTMBlockCell():
+    def _impl(inputs, attr, params, mod):
+        """LSTM Block cell.
+        Calculations and return values are described in:
+        https://github.com/tensorflow/tensorflow/blob/
+        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
+
+        Parameters
+        ----------
+        inputs : relay.Expr
+            Input data
+        in_state_c: list of relay.Expr
+            Cell state input values for all the layers
+        in_state_h: list of relay.Expr
+            Hidden state input values for all the layers
+        attrs : dict
+            Dict of operator attributes
+        params : dict
+            Dict of pretrained weights and bias
+
+        Returns
+        -------
+        relay.Expr.TupleWrapper
+            [dummy, cs, dummy, dummy, dummy, dummy, h]
+            Only cs and h which are useful are returned
+        """
+        in_data = inputs[0]
+        in_state_c = inputs[1]
+        in_state_h = inputs[2]
+        in_weight = inputs[3]
+        in_bias = inputs[7]
+        forget_bias = attr.pop('forget_bias')
+        input_shape = _infer_shape(inputs[0], mod)
+        weight_shape = _infer_shape(inputs[3], mod)
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
+
+        in_data = _op.reshape(in_data,
+                              newshape=(batch_size, input_size))
+        ixh = _op.concatenate([in_data, in_state_h], axis=1)
+        in_weight = _op.transpose(in_weight, axes=None)
+        gates = _op.nn.dense(ixh, in_weight,
+                             units=num_hidden_layers)
+        gates_bias = _op.add(gates, in_bias)
+        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
+        in_gate = _op.sigmoid(gate_list[0])
+        in_transform = _op.tanh(gate_list[1])
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.sigmoid(forget_gate)
+        out_gate = _op.sigmoid(gate_list[3])
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
+                         _op.multiply(in_gate, in_transform))
+        next_h = out_gate * _op.tanh(next_c)
+        # Return dummy for those unused values
+        dummy = tvm.relay.const(0)
+        return tvm.relay.TupleWrapper(
+            tvm.relay.Tuple([dummy, next_c, dummy, dummy, dummy, dummy, next_h]), 7)

Review comment:
       ```
  def lstm_block_cell(x, cs_prev, h_prev, w, wci, wcf, wco, b, forget_bias=1, cell_clip=3, use_peephole=False, name=None):
     r"""Computes the LSTM cell forward propagation for 1 time step.
   
     This implementation uses 1 weight matrix and 1 bias vector, and there's an
     optional peephole connection.
   
     This kernel op implements the following mathematical equations:
   
     xh = [x, h_prev]
     [i, f, ci, o] = xh * w + b
     f = f + forget_bias
   
     if not use_peephole:
       wci = wcf = wco = 0
   
     i = sigmoid(cs_prev * wci + i)
     f = sigmoid(cs_prev * wcf + f)
     ci = tanh(ci)
   
     cs = ci .* i + cs_prev .* f
     cs = clip(cs, cell_clip)
   
     o = sigmoid(cs * wco + o)
     co = tanh(cs)
     h = co .* o
     ...
     Returns:
       A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).  
    ...
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to