bkietz commented on a change in pull request #7026:
URL: https://github.com/apache/arrow/pull/7026#discussion_r417342700



##########
File path: cpp/src/arrow/dataset/filter.cc
##########
@@ -1261,7 +1264,212 @@ Result<std::shared_ptr<RecordBatch>> 
TreeEvaluator::Filter(
   return batch->Slice(0, 0);
 }
 
-std::shared_ptr<ScalarExpression> scalar(bool value) { return 
scalar(MakeScalar(value)); }
+std::shared_ptr<Expression> scalar(bool value) { return 
scalar(MakeScalar(value)); }
+
+struct SerializeImpl {
+  Result<std::shared_ptr<StructArray>> ToArray(const Expression& expr) const {
+    return VisitExpression(expr, *this);
+  }
+
+  Result<std::shared_ptr<StructArray>> TaggedWithChildren(const Expression& 
expr,
+                                                          ArrayVector 
children) const {
+    children.emplace_back();
+    ARROW_ASSIGN_OR_RAISE(children.back(),
+                          MakeArrayFromScalar(Int32Scalar(expr.type()), 1));
+
+    return StructArray::Make(children, 
std::vector<std::string>(children.size(), ""));
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const FieldExpression& expr) 
const {
+    ARROW_ASSIGN_OR_RAISE(auto name, 
MakeArrayFromScalar(StringScalar(expr.name()), 1));
+    return TaggedWithChildren(expr, {name});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const ScalarExpression& 
expr) const {
+    ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1));
+    return TaggedWithChildren(expr, {value});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const UnaryExpression& expr) 
const {
+    ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand()));
+    return TaggedWithChildren(expr, {operand});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const CastExpression& expr) 
const {
+    ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand()));
+
+    std::shared_ptr<Array> is_like_expr, to;
+    if (const auto& to_type = expr.to_type()) {
+      ARROW_ASSIGN_OR_RAISE(is_like_expr, 
MakeArrayFromScalar(BooleanScalar(false), 1));
+      ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1));
+    }
+    if (const auto& like_expr = expr.like_expr()) {
+      ARROW_ASSIGN_OR_RAISE(is_like_expr, 
MakeArrayFromScalar(BooleanScalar(true), 1));
+      ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr));
+    }
+
+    return TaggedWithChildren(expr, {operand, is_like_expr, to});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const BinaryExpression& 
expr) const {
+    ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand()));
+    ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand()));
+    return TaggedWithChildren(expr, {left_operand, right_operand});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(
+      const ComparisonExpression& expr) const {
+    ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand()));
+    ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand()));
+    ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 
1));
+    return TaggedWithChildren(expr, {left_operand, right_operand, op});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const InExpression& expr) 
const {
+    ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand()));
+
+    auto set_type = list(expr.set()->type());
+
+    ARROW_ASSIGN_OR_RAISE(auto set_offsets, AllocateBuffer(sizeof(int32_t) * 
2));
+    reinterpret_cast<int32_t*>(set_offsets->mutable_data())[0] = 0;
+    reinterpret_cast<int32_t*>(set_offsets->mutable_data())[1] =
+        static_cast<int32_t>(expr.set()->length());
+
+    auto set_values = expr.set();
+
+    auto set = std::make_shared<ListArray>(std::move(set_type), 1, 
std::move(set_offsets),
+                                           std::move(set_values));
+    return TaggedWithChildren(expr, {operand, set});
+  }
+
+  Result<std::shared_ptr<StructArray>> operator()(const Expression& expr) 
const {
+    return Status::NotImplemented("serialization of ", expr.ToString());
+  }
+};
+
+Result<std::shared_ptr<Buffer>> Expression::Serialize() const {
+  ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(*this));
+  ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array));
+  ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create());
+  ARROW_ASSIGN_OR_RAISE(auto writer, ipc::NewFileWriter(stream.get(), 
batch->schema()));
+  RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+  RETURN_NOT_OK(writer->Close());
+  return stream->Finish();
+}
+
+struct DeserializeImpl {
+  Result<std::shared_ptr<Expression>> FromArray(const Array& array) const {
+    if (array.type_id() != Type::STRUCT || array.length() != 1) {
+      return Status::Invalid("can only deserialize expressions from 
unit-length",
+                             " StructArray, got ", array);
+    }
+    const auto& struct_array = checked_cast<const StructArray&>(array);
+
+    ARROW_ASSIGN_OR_RAISE(auto expression_type, 
GetExpressionType(struct_array));
+    switch (expression_type) {
+      case ExpressionType::FIELD: {
+        ARROW_ASSIGN_OR_RAISE(auto name, GetView<StringType>(struct_array, 0));
+        return field_ref(name.to_string());
+      }
+
+      case ExpressionType::SCALAR: {
+        ARROW_ASSIGN_OR_RAISE(auto value,
+                              Scalar::FromArraySlot(*struct_array.field(0), 
0));
+        return scalar(std::move(value));
+      }
+
+      case ExpressionType::NOT: {
+        ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0)));
+        return not_(std::move(operand));
+      }
+
+      case ExpressionType::CAST: {
+        ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0)));
+        ARROW_ASSIGN_OR_RAISE(auto is_like_expr, 
GetView<BooleanType>(struct_array, 1));
+        if (is_like_expr) {
+          ARROW_ASSIGN_OR_RAISE(auto like_expr, 
FromArray(*struct_array.field(2)));
+          return operand->CastLike(std::move(like_expr)).Copy();
+        }
+        return operand->CastTo(struct_array.field(2)->type()).Copy();
+      }
+
+      case ExpressionType::AND: {
+        ARROW_ASSIGN_OR_RAISE(auto left_operand, 
FromArray(*struct_array.field(0)));
+        ARROW_ASSIGN_OR_RAISE(auto right_operand, 
FromArray(*struct_array.field(1)));
+        return and_(std::move(left_operand), std::move(right_operand));
+      }
+
+      case ExpressionType::OR: {
+        ARROW_ASSIGN_OR_RAISE(auto left_operand, 
FromArray(*struct_array.field(0)));
+        ARROW_ASSIGN_OR_RAISE(auto right_operand, 
FromArray(*struct_array.field(1)));
+        return or_(std::move(left_operand), std::move(right_operand));
+      }
+
+      case ExpressionType::COMPARISON: {
+        ARROW_ASSIGN_OR_RAISE(auto left_operand, 
FromArray(*struct_array.field(0)));
+        ARROW_ASSIGN_OR_RAISE(auto right_operand, 
FromArray(*struct_array.field(1)));
+        ARROW_ASSIGN_OR_RAISE(auto op, GetView<Int32Type>(struct_array, 2));
+        return std::make_shared<ComparisonExpression>(
+            static_cast<compute::CompareOperator>(op), std::move(left_operand),
+            std::move(right_operand));
+      }
+
+      case ExpressionType::IS_VALID: {
+        ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0)));
+        return std::make_shared<IsValidExpression>(std::move(operand));
+      }
+
+      case ExpressionType::IN: {
+        ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0)));
+        if (struct_array.field(1)->type_id() != Type::LIST) {
+          return Status::TypeError("expected field 1 of ", struct_array,
+                                   " to have list type");
+        }
+        auto set = checked_cast<const 
ListArray&>(*struct_array.field(1)).values();
+        return std::make_shared<InExpression>(std::move(operand), 
std::move(set));
+      }
+
+      default:
+        break;
+    }
+
+    return Status::Invalid("non-deserializable ExpressionType ", 
expression_type);
+  }
+
+  template <typename T, typename A = typename TypeTraits<T>::ArrayType>
+  static Result<decltype(std::declval<A>().GetView(0))> GetView(const 
StructArray& array,
+                                                                int index) {
+    if (index >= array.num_fields()) {
+      return Status::IndexError("expected ", array, " to have a child at index 
", index);
+    }
+
+    const auto& child = *array.field(index);
+    if (child.type_id() != T::type_id) {
+      return Status::TypeError("expected child ", index, " of ", array, " to 
have type ",
+                               T::type_id);
+    }
+
+    return checked_cast<const A&>(child).GetView(0);
+  }
+
+  static Result<ExpressionType::type> GetExpressionType(const StructArray& 
array) {
+    if (array.struct_type()->num_children() == 0) {

Review comment:
       Do you mean `< 1`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to