eric-haibin-lin closed pull request #9675: Add contrib.compute_accidental_hits operator for candidate sampling
URL: https://github.com/apache/incubator-mxnet/pull/9675
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/mshadow b/mshadow
index f5b67f380c..1a83023848 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit f5b67f380cb0588be11e6f440f92f013139380ee
+Subproject commit 1a830238481578e480e202adb258c988c4b2528e
diff --git a/src/operator/contrib/compute_acc_hits-inl.h b/src/operator/contrib/compute_acc_hits-inl.h
new file mode 100644
index 0000000000..34cf0b8fc8
--- /dev/null
+++ b/src/operator/contrib/compute_acc_hits-inl.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file compute_acc_hits-inl.h
+ * \brief implementation of compute_accidental_hits operator
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_COMPUTE_ACC_HITS_INL_H_
+#define MXNET_OPERATOR_CONTRIB_COMPUTE_ACC_HITS_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename xpu>
+void AccidentalHitComputeCsrImpl(mshadow::Stream<xpu> *s,
+                                 const TBlob& label,
+                                 const TBlob& sample,
+                                 const OpReqType req,
+                                 const NDArray& output);
+
+template<typename xpu>
+void AccidentalHitComputeEx(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+      outputs[0].storage_type() == kCSRStorage) {
+    AccidentalHitComputeCsrImpl(s, inputs[0].data(), inputs[1].data(), req[0],
+                                outputs[0]);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
+inline bool AccidentalHitShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_attrs,
+                               std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  for (size_t i = 0; i < 2; ++i) {
+    CHECK_EQ(in_attrs->at(i).ndim(), 1);
+  }
+  TShape out_attr{in_attrs->at(0)[0], in_attrs->at(1)[0]};
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_attr);
+  return true;
+}
+
+inline bool AccidentalHitStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  auto& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) &&
+      dev_mask == Context::kCPU) {
+    // dns, dns -> csr
+    dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
+                                     DispatchMode::kFComputeEx);
+  }
+  return dispatched;
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CONTRIB_COMPUTE_ACC_HITS_INL_H_
diff --git a/src/operator/contrib/compute_acc_hits.cc b/src/operator/contrib/compute_acc_hits.cc
new file mode 100644
index 0000000000..6e3944e872
--- /dev/null
+++ b/src/operator/contrib/compute_acc_hits.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file compute_acc_hits.cc
+ * \brief
+ */
+#include <list>
+#include <unordered_map>
+#include "./compute_acc_hits-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/* \brief the kernel to compute accidental hit on CPU
+ * \param i i-th        thread
+ * \param out_data      the output csr's data
+ * \param out_idx       the output csr's column indices
+ * \param label         the true classes
+ * \param out_indptr    the output csr's indptr
+ * \param map           the hash map that stores positions of sampled candidates
+ */
+struct accidental_hit {
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *out_data, IType *out_idx,
+                                  const DType* label, const IType* out_indptr,
+                                  const std::unordered_map<DType, std::list<IType>> *map) {
+    const auto it = map->find(label[i]);
+    const DType one = static_cast<DType>(1);
+    IType j = out_indptr[i];
+    if (it != map->end()) {
+      for (const IType idx : it->second) {
+        out_data[j] = one;
+        out_idx[j++] = idx;
+      }
+    }
+  }
+};
+
+template<>
+void AccidentalHitComputeCsrImpl<cpu>(mshadow::Stream<cpu> *s,
+                                      const TBlob& label,
+                                      const TBlob& sample,
+                                      const OpReqType req,
+                                      const NDArray& output) {
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "Unexpected req for compute accidental hits operator";
+  using nnvm::dim_t;
+  using namespace csr;
+  using namespace mxnet_op;
+  const dim_t num_sample = sample.shape_.Size();
+  const dim_t num_label = label.shape_.Size();
+  MSHADOW_TYPE_SWITCH(label.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), IType, {
+      std::unordered_map<DType, std::list<IType>> sample_map;
+      const DType *label_data = label.dptr<DType>();
+      const DType *sample_data = sample.dptr<DType>();
+      for (IType i = 0; i < num_sample; i++) {
+        sample_map[sample_data[i]].push_back(i);
+      }
+      output.CheckAndAllocAuxData(kIndPtr, mshadow::Shape1(num_label + 1));
+      IType *out_indptr = output.aux_data(kIndPtr).dptr<IType>();
+      out_indptr[0] = 0;
+      // compute the number of matches for each row
+      for (dim_t i = 1; i < num_label + 1; i++) {
+        IType count = 0;
+        const auto it = sample_map.find(label_data[i - 1]);
+        // accidental match found
+        if (it != sample_map.end()) {
+          count = it->second.size();
+        }
+        out_indptr[i] = out_indptr[i - 1] + count;
+      }
+      // allocate the memory based on nnz
+      const IType nnz = out_indptr[num_label];
+      output.CheckAndAllocData(mshadow::Shape1(nnz));
+      output.CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnz));
+      DType *out_data = output.data().dptr<DType>();
+      IType *out_idx = output.aux_data(kIdx).dptr<IType>();
+      Kernel<accidental_hit, cpu>::Launch(s, num_label, out_data,
+             out_idx, label_data, out_indptr, &sample_map);
+    });
+  });
+}
+
+NNVM_REGISTER_OP(_contrib_compute_accidental_hits)
+.describe(R"code(Compute the indices in ``sampled_candidates`` which match ``true_classes``
+and return the mask for the matching positions.
+
+The operator is used for removing sampled classes which happen to match target classes
+(i.e. accidental hits) for sampled softmax and sampled logistic. The mask has 0
+for non-matching positions and 1 for matching ones.
+
+Both inputs are expected to be 1-D. For example, let's say ``true_classes`` has shape (M,),
+and ``sampled_candidates`` has (N,), then the resulting mask will have shape (M, N).
+mask[i][j] = 1 iff true_classes[i] == sampled_candidates[j].
+
+Example::
+
+   true = [1,5,11]
+   sampled = [5,8,1,5,24]
+   compute_accidental_hits(true, sampled) = [[0, 0, 1, 0, 0],
+                                             [1, 0, 0, 1, 0],
+                                             [0, 0, 0, 0, 0]]
+
+.. note:: `compute_accidental_hits` is only available on CPU and returns a compressed sparse row mask.
+
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", AccidentalHitShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", AccidentalHitStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", AccidentalHitComputeEx<cpu>)
+.add_argument("true_classes", "NDArray-or-Symbol", "True Classes of 1-D shape.")
+.add_argument("sampled_candidates", "NDArray-or-Symbol", "Sampled Candidates of 1-D shape.");
+
+}  // namespace op
+}  // namespace mxnet
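
For reference (not part of the diff above), here is a minimal NumPy sketch of the two-pass CSR construction that AccidentalHitComputeCsrImpl performs on CPU: a hash map from each sampled class to its positions is built first, indptr is accumulated from the per-row match counts, the output is allocated with nnz = indptr[-1] entries, and a second pass writes the column indices and ones. The helper name accidental_hits_csr is hypothetical and only illustrates the layout.

    import numpy as np
    from collections import defaultdict

    def accidental_hits_csr(true_classes, sampled_candidates):
        # Map each sampled class to the positions where it appears.
        sample_map = defaultdict(list)
        for j, s in enumerate(sampled_candidates):
            sample_map[s].append(j)
        # First pass: per-row match counts accumulated into indptr.
        num_label = len(true_classes)
        indptr = np.zeros(num_label + 1, dtype=np.int64)
        for i, t in enumerate(true_classes):
            indptr[i + 1] = indptr[i] + len(sample_map.get(t, []))
        # Allocate based on nnz, then fill column indices and data (all ones).
        nnz = int(indptr[-1])
        indices = np.zeros(nnz, dtype=np.int64)
        data = np.ones(nnz, dtype=np.float32)
        for i, t in enumerate(true_classes):
            cols = sample_map.get(t, [])
            indices[indptr[i]:indptr[i] + len(cols)] = cols
        return data, indices, indptr

    # Matches the docstring example: true = [1, 5, 11], sampled = [5, 8, 1, 5, 24]
    data, indices, indptr = accidental_hits_csr([1, 5, 11], [5, 8, 1, 5, 24])
    print(indptr.tolist())   # [0, 1, 3, 3]
    print(indices.tolist())  # [2, 0, 3]
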
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index ab389b6d03..3863a202ae 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -843,6 +843,46 @@ def check_sparse_nd_norm(stype, shape, density):
         for density in densities:
             check_sparse_nd_norm(stype, shape, density)
 
+def test_sparse_nd_accidental_hit():
+    """ test contrib.compute_accidental_hits on cpu """
+    def compute_hits(label, sample):
+        num_label = len(label)
+        num_sample = len(sample)
+        label_indices = {}
+        # record all true classes indices
+        for i in range(num_label):
+            if label[i] not in label_indices:
+                label_indices[label[i]] = []
+            label_indices[label[i]].append(i)
+        accidental_hits = []
+        for i in range(num_sample):
+            s = sample[i]
+            if s in label_indices:
+                hits = label_indices[s]
+                accidental_hits += [(h, i) for h in hits]
+        return accidental_hits
+
+    n = 20
+    mx.random.seed(1)
+    np.random.seed(1)
+    num_label = np.random.randint(1, 10)
+    num_sample = np.random.randint(1, 10)
+    label = np.random.randint(0, n, size=num_label)
+    sample = np.random.randint(0, n, size=num_sample)
+    accidental_hits = compute_hits(label, sample)
+    for dtype in [np.float16, np.int32, np.int64, np.float32, np.float64]:
+        label_nd = mx.nd.array(label, dtype=dtype)
+        sample_nd = mx.nd.array(sample, dtype=dtype)
+        hit_mask = mx.nd.contrib.compute_accidental_hits(true_classes=label_nd,
+                                                         sampled_candidates=sample_nd)
+        hit_mask_dns = hit_mask.tostype('default')
+        # check total number of hits
+        assert(hit_mask_dns.sum() == len(accidental_hits))
+        # check individual hit
+        for p in accidental_hits:
+            hit_mask_dns[p] = 0
+        assert(hit_mask_dns.sum() == 0)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
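
A quick usage sketch of the operator registered above (assuming an MXNet build that includes this PR, so the op is exposed as mx.nd.contrib.compute_accidental_hits); the dense view of the returned CSR mask should match the example in the operator docstring:

    import mxnet as mx

    true_classes = mx.nd.array([1, 5, 11])
    sampled = mx.nd.array([5, 8, 1, 5, 24])
    # Returns a CSRNDArray mask of shape (3, 5); only available on CPU.
    mask = mx.nd.contrib.compute_accidental_hits(true_classes=true_classes,
                                                 sampled_candidates=sampled)
    print(mask.tostype('default').asnumpy())
    # [[0. 0. 1. 0. 0.]
    #  [1. 0. 0. 1. 0.]
    #  [0. 0. 0. 0. 0.]]
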


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
