pingsutw commented on a change in pull request #487:
URL: https://github.com/apache/submarine/pull/487#discussion_r553259128
##########
File path: submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
##########
@@ -181,3 +182,72 @@ def fm_layer(inputs, **kwargs):
square_sum = tf.reduce_sum(tf.square(inputs), 1)
fm_out = 0.5 * tf.reduce_sum(tf.subtract(sum_square, square_sum), 1)
return fm_out
+
+
+class NoMask(tf.keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super(NoMask, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ # Be sure to call this somewhere!
+ super(NoMask, self).build(input_shape)
+
+ def call(self, x, mask=None, **kwargs):
+ return x
+
+ def compute_mask(self, inputs, mask):
+ return None
+
+
+class KMaxPooling(Layer):
+ """K Max pooling that selects the k biggest value along the specific axis.
+ Input shape
+ - nD tensor with shape: ``(batch_size, ..., input_dim)``.
+ Output shape
+ - nD tensor with shape: ``(batch_size, ..., output_dim)``.
+ Arguments
+ - **k**: positive integer, number of top elements to look for along the ``axis`` dimension.
+ - **axis**: positive integer, the dimension to look for elements.
+ """
+
+ def __init__(self, k=1, axis=-1, **kwargs):
+
+ self.k = k
+ self.axis = axis
+ super(KMaxPooling, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+
+ if self.axis < 1 or self.axis > len(input_shape):
+ raise ValueError("axis must be 1~%d,now is %d" %
+ (len(input_shape), self.axis))
+
+ if self.k < 1 or self.k > input_shape[self.axis]:
+ raise ValueError("k must be in 1 ~ %d,now k is %d" %
+ (input_shape[self.axis], self.k))
+ self.dims = len(input_shape)
+ # Be sure to call this somewhere!
Review comment:
```suggestion
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]