haojin2 commented on a change in pull request #17311: Added beamsearch_set_finished Operator
URL: https://github.com/apache/incubator-mxnet/pull/17311#discussion_r367028141
 
 

 ##########
 File path: src/operator/contrib/beamsearch_set_finished.cc
 ##########
 @@ -0,0 +1,66 @@
+#include <mxnet/base.h>
+#include "./beamsearch_set_finished-inl.h"
+#include "../tensor/elemwise_unary_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(BeamsearchSetFinishedParam);
+
+NNVM_REGISTER_OP(_contrib_beamsearch_set_finished)
+.describe(R"code(Sets finished beams of the beam to a mask value (aside from 
the score index) and forces beams at max length to output the EOS sequence.
+
+Returns an array of the same shape of the input data array and the same values 
except for the designated rows whose elements are to be masked or be replaced 
by beam scores/EOS probabilities.
+
+Example::
+
+    x = [[ -1.,  -2.,  -3.,  -4.],
+         [ -5.,  -6.,  -7.,  -8.],
+         [ -9., -10., -11., -12.],
+         [-13., -14., -15., -16.]]
+
+    scores = [[-17.],
+              [-18.],
+              [-19.],
+              [-20.]]
+
+    finished = [0, 1, 0, 1]
+
+    over_max = [0, 0, 1, 1]
+
+    beamsearch_set_finished(x, scores, finished, over_max, score_idx=0,
+                            eos_idx=2, mask_val=-1e15) = [[  -1.,   -2.,   
-3.,   -4.],
+                                                          [ -18., -1e15, 
-1e15, -1e15],
+                                                          [-1e15, -1e15,  
-11., -1e15],
+                                                          [ -20., -1e15, 
-1e15, -1e15]]
+
+.. Note::
+    This operator only supports forward propagation. DO NOT use it in training.
+
+)code")
+.set_attr_parser(ParamParser<BeamsearchSetFinishedParam>)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) 
{
+    return std::vector<std::string>{"data", "scores", "finished", "over_max"}; 
})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& 
attrs) {
+    return std::vector<std::string>{"output"}; })
+.set_attr<FInferShape>("FInferShape", BeamsearchSetFinishedShape)
+.set_attr<nnvm::FInferType>("FInferType", BeamsearchSetFinishedType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int>>{{0, 0}}; })
+.set_attr<FCompute>("FCompute<cpu>", BeamsearchSetFinishedForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"beamsearch_noop_grad"})
+.add_argument("data", "NDArray-or-Symbol", "Input distribution of tokens")
+.add_argument("scores", "NDArray-or-Symbol", "Running scores for the 
sequences")
+.add_argument("finished", "NDArray-or-Symbol", "Finished beams")
+.add_argument("over_max", "NDArray-or-Symbol", "Beams at or exceeding maximum 
length")
+.add_arguments(BeamsearchSetFinishedParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_beamsearch_noop_grad)
+.set_num_inputs(1)
+.set_num_outputs(4)
+.set_attr<FCompute>("FCompute<cpu>", NoopGrad<cpu>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true);
+}
 
 Review comment:
   one more blank line below.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to