lebeg commented on a change in pull request #13224: [WIP] Quantize/digitize 
operator
URL: https://github.com/apache/incubator-mxnet/pull/13224#discussion_r235914956
 
 

 ##########
 File path: src/operator/tensor/digitize_op.h
 ##########
 @@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file digitize_op.h
+ * \brief Quantize operator a la numpy.digitize.
+ */
+#ifndef MXNET_OPERATOR_TENSOR_DIGITIZE_H_
+#define MXNET_OPERATOR_TENSOR_DIGITIZE_H_
+
+#include <mxnet/operator_util.h>
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+#include <mxnet/base.h>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Parameter set of the digitize operator.
+ */
+struct DigitizeParam : public dmlc::Parameter<DigitizeParam> {
+  /*! \brief Whether the intervals include the right or the left bin edge (default: false). */
+  bool right;
+
+  DMLC_DECLARE_PARAMETER(DigitizeParam) {
+    // The description must stay on one source line: a string literal cannot
+    // contain a raw line break.
+    DMLC_DECLARE_FIELD(right)
+        .set_default(false)
+        .describe("Whether the intervals include the right or the left bin edge.");
+  }
+};
+
+class DigitizeOp {
+public:
+  /*!
+   * \brief Shape inference: the output shape and the data shape are assigned to each other,
+   *        and the bins shape is validated against the data.
+   * \param attrs node attributes; not read directly in this body
+   * \param in_attrs input shapes; exactly two entries are expected: data (0) and bins (1)
+   * \param out_attrs output shapes; exactly one entry is expected
+   * \return true when every check has passed
+   */
+  bool InferShape(const nnvm::NodeAttrs &attrs,
+                  std::vector<TShape> *in_attrs,
+                  std::vector<TShape> *out_attrs) {
+    using namespace mshadow;
+
+    CHECK_EQ(in_attrs->size(), 2); // Size 2: data and bins
+    CHECK_EQ(out_attrs->size(), 1); // Only one output tensor
+
+    // The limit applies to the number of dimensions of bins, hence ndim().
+    // Size() is the total element count and would reject any bins tensor
+    // holding more than two values.
+    const auto bins_ndim = in_attrs->at(1).ndim();
+    CHECK_LE(bins_ndim, 2); // At most 2 dimensions for bins
+
+    auto &input_shape = (*in_attrs)[0];
+    auto &output_shape = (*out_attrs)[0];
+
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, input_shape); // First arg is a shape array
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, output_shape);
+
+    // If bins has two dimensions, the first one corresponds to the batch axis, so we need to
+    // verify # batches in X = # batches in bins
+    if (bins_ndim == 2) {
+      auto &input_batches = (*in_attrs)[0][0];
+      auto &bin_batches = (*in_attrs)[1][0];
+
+      CHECK_EQ(input_batches, bin_batches)
+        << "If bins has 2 dimensions, the first one should be the same as that of the input data";
+      //TODO: Reword the message above
+    }
+
+    return true;
+  }
+
+  /*!
+   * \brief Kernel functor for the forward pass. Only the declaration of Map is
+   *        visible in this part of the file; its definition lives elsewhere.
+   */
+  struct ForwardKernel {
+    /*!
+     * \brief Map function of the kernel.
+     * \tparam xpu NOTE(review): presumably the device type (cpu/gpu) -- confirm
+     * \param i NOTE(review): presumably the index of the element being processed -- confirm
+     * \param ctx operator context
+     * \param input_data blob with the values to be binned
+     * \param bins blob with the bin edges
+     * \param out_data blob receiving the result (passed by non-const reference)
+     * \param right same flag as DigitizeParam::right
+     */
+    template<typename xpu>
+    MSHADOW_XINLINE static void Map(int i,
+                                    const OpContext &ctx,
+                                    const TBlob &input_data,
+                                    const TBlob &bins,
+                                    TBlob &out_data,
+                                    const bool right);
+  };
+
+  /*!
+   * \brief Aborts with a fatal check failure unless [begin, end) is strictly increasing.
+   * \tparam ForwardIterator iterator type of the range
+   * \tparam DType element type handed to std::greater_equal; it does not appear in the
+   *         argument list, so callers have to name it explicitly
+   * \param begin first element of the bins range
+   * \param end one past the last element of the bins range
+   */
+  template<class ForwardIterator, typename DType>
+  void CheckMonotonic(ForwardIterator begin, ForwardIterator end) {
+    // adjacent_find returns the first element that is >= its successor, or `end` when
+    // no such pair exists (which also covers empty and single-element ranges).
+    // CHECK with an explicit == is used instead of CHECK_EQ: CHECK_EQ streams both
+    // operands into the failure message, which requires operator<< on the iterator type.
+    CHECK(std::adjacent_find(begin, end, std::greater_equal<DType>()) == end)
+      << "Bins vector must be strictly monotonically increasing";
+  }
+
+  // Based on http://www.cplusplus.com/reference/algorithm/is_sorted/
 
 Review comment:
   This can be safely removed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to