This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 310e5b1  ARROW-1757: [C++] Add DictionaryArray::FromArrays alternate 
ctor that can check or sanitize "untrusted" indices
310e5b1 is described below

commit 310e5b1683d91f76ba3ee7d6f2096fe07b03789a
Author: Panchen Xue <[email protected]>
AuthorDate: Tue Feb 6 11:17:47 2018 -0500

    ARROW-1757: [C++] Add DictionaryArray::FromArrays alternate ctor that can 
check or sanitize "untrusted" indices
    
    Add static member function DictionaryArray::FromArrays to create 
DictionaryArray from a given type and an indices array. This method calls 
ValidateDictionaryIndices to check that all non-null indices are within the valid range.
    
    Author: Panchen Xue <[email protected]>
    
    Closes #1535 from xuepanchen/ARROW-1757 and squashes the following commits:
    
    88da7381 [Panchen Xue] Rename ValidateArray method and move to array.cc
    93e26a64 [Panchen Xue] Modify code based on comments
    ebb6897c [Panchen Xue] ARROW-1757: [C++] Lint format and add comment
    e6801ea3 [Panchen Xue] ARROW-1757: [C++] Add test case for 
DictionaryArray::FromArrays()
    5f4944c6 [Panchen Xue] ARROW-1757: [C++] Add DictionaryArray::FromArray and 
SanityCheck methods
---
 cpp/src/arrow/array-test.cc | 31 +++++++++++++++++++
 cpp/src/arrow/array.cc      | 75 ++++++++++++++++++++++++++++++++++++++++++++-
 cpp/src/arrow/array.h       | 13 ++++++++
 3 files changed, 118 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc
index c53da85..fa64d46 100644
--- a/cpp/src/arrow/array-test.cc
+++ b/cpp/src/arrow/array-test.cc
@@ -2384,6 +2384,37 @@ TEST(TestDictionary, Validate) {
   // ASSERT_OK(ValidateArray(*arr3));
 }
 
+// Exercises DictionaryArray::FromArrays: indices must lie in [0, dict length)
+// except in null slots, whose stored values are not checked.
+TEST(TestDictionary, FromArray) {
+  // Three-entry dictionary, so the valid index values are 0, 1 and 2.
+  std::shared_ptr<Array> dict;
+  vector<string> dict_values = {"foo", "bar", "baz"};
+  ArrayFromVector<StringType, string>(dict_values, &dict);
+  std::shared_ptr<DataType> dict_type = dictionary(int16(), dict);
+
+  // All indices in range -> accepted.
+  std::shared_ptr<Array> indices1;
+  vector<int16_t> indices_values1 = {1, 2, 0, 0, 2, 0};
+  ArrayFromVector<Int16Type, int16_t>(indices_values1, &indices1);
+
+  // Slot 3 holds 3, which equals the dictionary length -> rejected.
+  std::shared_ptr<Array> indices2;
+  vector<int16_t> indices_values2 = {1, 2, 0, 3, 2, 0};
+  ArrayFromVector<Int16Type, int16_t>(indices_values2, &indices2);
+
+  // The out-of-range -1 sits in a null slot, which is skipped -> accepted.
+  std::shared_ptr<Array> indices3;
+  vector<bool> is_valid3 = {true, true, false, true, true, true};
+  vector<int16_t> indices_values3 = {1, 2, -1, 0, 2, 0};
+  ArrayFromVector<Int16Type, int16_t>(is_valid3, indices_values3, &indices3);
+
+  // Nulls are present, but non-null slot 3 still holds 3 -> rejected.
+  std::shared_ptr<Array> indices4;
+  vector<bool> is_valid4 = {true, true, false, true, true, true};
+  vector<int16_t> indices_values4 = {1, 2, 1, 3, 2, 0};
+  ArrayFromVector<Int16Type, int16_t>(is_valid4, indices_values4, &indices4);
+
+  std::shared_ptr<Array> arr1, arr2, arr3, arr4;
+  ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices1, &arr1));
+  ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices2, 
&arr2));
+  ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices3, &arr3));
+  ASSERT_RAISES(Invalid, DictionaryArray::FromArrays(dict_type, indices4, 
&arr4));
+}
+
 // ----------------------------------------------------------------------
 // Struct tests
 
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 3d72761..a8043d6 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -476,6 +476,39 @@ const Array* UnionArray::UnsafeChild(int i) const {
 // ----------------------------------------------------------------------
 // DictionaryArray
 
+namespace {
+
+/// \brief Check that every non-null dictionary index lies in [0, upper_bound)
+///
+/// Null slots are skipped, so whatever value they store is not checked.
+///
+/// \param[in] indices array of dictionary indices; must actually be an
+///   instance of TypeTraits<ArrowType>::ArrayType (the caller dispatches on
+///   indices->type_id() before instantiating this template)
+/// \param[in] upper_bound exclusive upper bound, i.e. the dictionary length
+/// \return Status::OK() if all non-null indices are in range, otherwise
+///   Status::Invalid naming the first offending position and value
+template <typename ArrowType>
+Status ValidateDictionaryIndices(const std::shared_ptr<Array>& indices,
+                                 const int64_t upper_bound) {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  const auto& array = static_cast<const ArrayType&>(*indices);
+  const typename ArrowType::c_type* data = array.raw_values();
+  const int64_t size = array.length();
+  // Consult the validity bitmap only when there actually are nulls to skip.
+  const bool has_nulls = array.null_count() != 0;
+
+  for (int64_t idx = 0; idx < size; ++idx) {
+    if (has_nulls && array.IsNull(idx)) {
+      continue;
+    }
+    // Widen to int64_t so every index width is compared in one signed type.
+    const int64_t value = static_cast<int64_t>(data[idx]);
+    if (value < 0 || value >= upper_bound) {
+      std::stringstream ss;
+      ss << "Dictionary has out-of-bound index " << value << " at position " << idx
+         << " (valid range is [0, " << upper_bound << "))";
+      return Status::Invalid(ss.str());
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
 DictionaryArray::DictionaryArray(const std::shared_ptr<ArrayData>& data)
     : dict_type_(static_cast<const DictionaryType*>(data->type.get())) {
   DCHECK_EQ(data->type->id(), Type::DICTIONARY);
@@ -492,11 +525,51 @@ DictionaryArray::DictionaryArray(const 
std::shared_ptr<DataType>& type,
   SetData(data);
 }
 
+/// \brief Build a DictionaryArray after validating untrusted indices.
+///
+/// Type mismatches are reported as TypeError, out-of-range non-null indices
+/// as Invalid, and unsupported index types as NotImplemented. `out` is only
+/// written on success. An empty indices array is trivially valid.
+Status DictionaryArray::FromArrays(const std::shared_ptr<DataType>& type,
+                                   const std::shared_ptr<Array>& indices,
+                                   std::shared_ptr<Array>* out) {
+  // The inputs are untrusted, so bad types must surface as errors rather than
+  // DCHECKs: DCHECKs compile away in release builds, and the static_cast below
+  // would then be undefined behavior on a non-dictionary type.
+  if (type->id() != Type::DICTIONARY) {
+    std::stringstream ss;
+    ss << "Expected a dictionary type, got " << type->ToString();
+    return Status::TypeError(ss.str());
+  }
+  const auto& dict = static_cast<const DictionaryType&>(*type);
+  if (indices->type_id() != dict.index_type()->id()) {
+    std::stringstream ss;
+    ss << "Dictionary index type " << dict.index_type()->ToString()
+       << " does not match indices array type " << indices->type()->ToString();
+    return Status::TypeError(ss.str());
+  }
+
+  const int64_t upper_bound = dict.dictionary()->length();
+  Status is_valid;
+
+  // Dispatch to the range check for the concrete signed index width.
+  switch (indices->type_id()) {
+    case Type::INT8:
+      is_valid = ValidateDictionaryIndices<Int8Type>(indices, upper_bound);
+      break;
+    case Type::INT16:
+      is_valid = ValidateDictionaryIndices<Int16Type>(indices, upper_bound);
+      break;
+    case Type::INT32:
+      is_valid = ValidateDictionaryIndices<Int32Type>(indices, upper_bound);
+      break;
+    case Type::INT64:
+      is_valid = ValidateDictionaryIndices<Int64Type>(indices, upper_bound);
+      break;
+    default: {
+      std::stringstream ss;
+      ss << "Categorical index type not supported: " << indices->type()->ToString();
+      return Status::NotImplemented(ss.str());
+    }
+  }
+
+  if (!is_valid.ok()) {
+    return is_valid;
+  }
+
+  *out = std::make_shared<DictionaryArray>(type, indices);
+  return Status::OK();
+}
+
 void DictionaryArray::SetData(const std::shared_ptr<ArrayData>& data) {
   this->Array::SetData(data);
+  // Derive the indices array from a copy of this array's ArrayData whose only
+  // change is the type: the dictionary's index type replaces the dictionary type.
   auto indices_data = data_->Copy();
   indices_data->type = dict_type_->index_type();
-  std::shared_ptr<Array> result;
   indices_ = MakeArray(indices_data);
 }
 
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index f0a7861..5b9ce9a 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -726,6 +726,19 @@ class ARROW_EXPORT DictionaryArray : public Array {
   DictionaryArray(const std::shared_ptr<DataType>& type,
                   const std::shared_ptr<Array>& indices);
 
+  /// \brief Construct a DictionaryArray from a dictionary data type and an
+  /// indices array, validating the indices first
+  ///
+  /// Unlike the constructor above, this checks that every non-null index is
+  /// non-negative and smaller than the size of the dictionary. Null slots are
+  /// not checked.
+  ///
+  /// \param[in] type a DictionaryType; supplies the dictionary values and the
+  ///   expected index type
+  /// \param[in] indices an array of signed integer indices (int8, int16, int32
+  ///   or int64) whose type matches the dictionary type's index type
+  /// \param[out] out the resulting DictionaryArray instance; only assigned
+  ///   when validation succeeds
+  /// \return Status::Invalid if a non-null index is out of range;
+  ///   Status::NotImplemented for index types other than the four listed above
   static Status FromArrays(const std::shared_ptr<DataType>& type,
                            const std::shared_ptr<Array>& indices,
                            std::shared_ptr<Array>* out);
+
+
   std::shared_ptr<Array> indices() const;
   std::shared_ptr<Array> dictionary() const;
 

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to