This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch maint-6.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 76dfc385fa1a7b0dc8172ced95f311c21af1430f
Author: David Li <[email protected]>
AuthorDate: Sat Nov 6 09:48:09 2021 -0400

    ARROW-14519: [C++] Properly error if joining on unsupported type
    
    Instead of DCHECK-ing, return an Invalid status for unsupported types.
    
    Closes #11625 from lidavidm/arrow-14519
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 cpp/src/arrow/compute/exec/hash_join.h            |  1 +
 cpp/src/arrow/compute/exec/hash_join_node.cc      | 33 ++++++++++++++++++-----
 cpp/src/arrow/compute/exec/hash_join_node_test.cc | 33 +++++++++++++++++++++++
 cpp/src/arrow/compute/exec/schema_util.h          | 32 +++-------------------
 4 files changed, 63 insertions(+), 36 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h
index 11b36d9..6520e4a 100644
--- a/cpp/src/arrow/compute/exec/hash_join.h
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -66,6 +66,7 @@ class ARROW_EXPORT HashJoinSchema {
   SchemaProjectionMaps<HashJoinProjection> proj_maps[2];
 
  private:
+  static bool IsTypeSupported(const DataType& type);
   static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
                                                   const std::vector<FieldRef>& 
a,
                                                   const std::vector<FieldRef>& 
b);
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 583ac9a..4bccb76 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -34,6 +34,15 @@ using internal::checked_cast;
 
 namespace compute {
 
+// Check if a type is supported in a join (as either a key or non-key column)
+bool HashJoinSchema::IsTypeSupported(const DataType& type) {
+  const Type::type id = type.id();
+  if (id == Type::DICTIONARY) {
+    return IsTypeSupported(*checked_cast<const 
DictionaryType&>(type).value_type());
+  }
+  return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
+}
+
 Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
                                                          const 
std::vector<FieldRef>& a,
                                                          const 
std::vector<FieldRef>& b) {
@@ -141,8 +150,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
   // 2. Same number of key fields on left and right
   // 3. At least one key field
   // 4. Equal data types for corresponding key fields
-  // 5. Dictionary type is not supported in a key field
-  // 6. Some other data types may not be allowed in a key field
+  // 5. Some data types may not be allowed in a key field or non-key field
   //
   if (left_keys.size() != right_keys.size()) {
     return Status::Invalid("Different number of key fields on left (", 
left_keys.size(),
@@ -164,11 +172,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
     const FieldPath& match = result.ValueUnsafe();
     const std::shared_ptr<DataType>& type =
         (left_side ? left_schema.fields() : 
right_schema.fields())[match[0]]->type();
-    if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) &&
-         !is_binary_like(type->id())) ||
-        is_large_binary_like(type->id())) {
-      return Status::Invalid("Data type ", type->ToString(),
-                             " is not supported in join key field");
+    if (!IsTypeSupported(*type)) {
+      return Status::Invalid("Data type ", *type, " is not supported in join 
key field");
     }
   }
   for (size_t i = 0; i < left_keys.size(); ++i) {
@@ -185,6 +190,20 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
           right_ref.ToString(), " of type ", right_type->ToString());
     }
   }
+  for (const auto& field : left_schema.fields()) {
+    const auto& type = *field->type();
+    if (!IsTypeSupported(type)) {
+      return Status::Invalid("Data type ", type,
+                             " is not supported in join non-key field");
+    }
+  }
+  for (const auto& field : right_schema.fields()) {
+    const auto& type = *field->type();
+    if (!IsTypeSupported(type)) {
+      return Status::Invalid("Data type ", type,
+                             " is not supported in join non-key field");
+    }
+  }
 
   // Check for output fields:
   // 1. Output field refs must match exactly one input field
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index d20b456..9afddf3 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -1656,5 +1656,38 @@ TEST(HashJoin, DictNegative) {
   }
 }
 
+TEST(HashJoin, UnsupportedTypes) {
+  // ARROW-14519
+  const bool parallel = false;
+  const bool slow = false;
+
+  auto l_schema = schema({field("l_i32", int32()), field("l_list", 
list(int32()))});
+  auto l_schema_nolist = schema({field("l_i32", int32())});
+  auto r_schema = schema({field("r_i32", int32()), field("r_list", 
list(int32()))});
+  auto r_schema_nolist = schema({field("r_i32", int32())});
+
+  std::vector<std::pair<std::shared_ptr<Schema>, std::shared_ptr<Schema>>> 
cases{
+      {l_schema, r_schema}, {l_schema_nolist, r_schema}, {l_schema, 
r_schema_nolist}};
+  std::vector<FieldRef> l_keys{{"l_i32"}};
+  std::vector<FieldRef> r_keys{{"r_i32"}};
+
+  for (const auto& schemas : cases) {
+    BatchesWithSchema l_batches = GenerateBatchesFromString(schemas.first, 
{R"([])"});
+    BatchesWithSchema r_batches = GenerateBatchesFromString(schemas.second, 
{R"([])"});
+
+    ExecContext exec_ctx;
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
+
+    HashJoinNodeOptions join_options{JoinType::LEFT_SEMI, l_keys, r_keys};
+    Declaration join{"hashjoin", join_options};
+    join.inputs.emplace_back(Declaration{
+        "source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel, 
slow)}});
+    join.inputs.emplace_back(Declaration{
+        "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, 
slow)}});
+
+    ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+  }
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h
index 33f4270..279cbb8 100644
--- a/cpp/src/arrow/compute/exec/schema_util.h
+++ b/cpp/src/arrow/compute/exec/schema_util.h
@@ -62,7 +62,7 @@ class SchemaProjectionMaps {
               const std::vector<ProjectionIdEnum>& projection_handles,
               const std::vector<const std::vector<FieldRef>*>& projections) {
     ARROW_DCHECK(projection_handles.size() == projections.size());
-    RegisterSchema(full_schema_handle, schema);
+    ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema));
     for (size_t i = 0; i < projections.size(); ++i) {
       ARROW_RETURN_NOT_OK(
           RegisterProjectedSchema(projection_handles[i], *(projections[i]), 
schema));
@@ -76,11 +76,6 @@ class SchemaProjectionMaps {
     return static_cast<int>(schemas_[id].second.size());
   }
 
-  const KeyEncoder::KeyColumnMetadata& column_metadata(ProjectionIdEnum 
schema_handle,
-                                                       int field_id) const {
-    return field(schema_handle, field_id).column_metadata;
-  }
-
   const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) 
const {
     return field(schema_handle, field_id).field_name;
   }
@@ -105,10 +100,9 @@ class SchemaProjectionMaps {
     int field_path;
     std::string field_name;
     std::shared_ptr<DataType> data_type;
-    KeyEncoder::KeyColumnMetadata column_metadata;
   };
 
-  void RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
+  Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) {
     std::vector<FieldInfo> out_fields;
     const FieldVector& in_fields = schema.fields();
     out_fields.resize(in_fields.size());
@@ -118,9 +112,9 @@ class SchemaProjectionMaps {
       out_fields[i].field_path = static_cast<int>(i);
       out_fields[i].field_name = name;
       out_fields[i].data_type = type;
-      out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
     }
     schemas_.push_back(std::make_pair(handle, out_fields));
+    return Status::OK();
   }
 
   Status RegisterProjectedSchema(ProjectionIdEnum handle,
@@ -137,7 +131,6 @@ class SchemaProjectionMaps {
       out_fields[i].field_path = match[0];
       out_fields[i].field_name = name;
       out_fields[i].data_type = type;
-      out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
     }
     schemas_.push_back(std::make_pair(handle, out_fields));
     return Status::OK();
@@ -153,25 +146,6 @@ class SchemaProjectionMaps {
     }
   }
 
-  KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType(
-      const std::shared_ptr<DataType>& type) {
-    if (type->id() == Type::DICTIONARY) {
-      auto bit_width = checked_cast<const FixedWidthType&>(*type).bit_width();
-      ARROW_DCHECK(bit_width % 8 == 0);
-      return KeyEncoder::KeyColumnMetadata(true, bit_width / 8);
-    } else if (type->id() == Type::BOOL) {
-      return KeyEncoder::KeyColumnMetadata(true, 0);
-    } else if (is_fixed_width(type->id())) {
-      return KeyEncoder::KeyColumnMetadata(
-          true, checked_cast<const FixedWidthType&>(*type).bit_width() / 8);
-    } else if (is_binary_like(type->id())) {
-      return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t));
-    } else {
-      ARROW_DCHECK(false);
-      return KeyEncoder::KeyColumnMetadata(true, 0);
-    }
-  }
-
   int schema_id(ProjectionIdEnum schema_handle) const {
     for (size_t i = 0; i < schemas_.size(); ++i) {
       if (schemas_[i].first == schema_handle) {

Reply via email to