This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d14e778 ARROW-8216: [C++][Compute] Filter out nulls by default
d14e778 is described below
commit d14e778dd831db972b6326b8917ac7e822182276
Author: Benjamin Kietzman <[email protected]>
AuthorDate: Thu Apr 2 15:41:16 2020 -0700
ARROW-8216: [C++][Compute] Filter out nulls by default
In C++ and R, Filter's default is now to drop a value from the output when its
selection slot is null. In Python, `Array.filter` now has a keyword argument
`drop_nulls=True`. In R, since the expected behavior is the opposite, the
argument is `keep_na = TRUE`. Old behavior is preserved in c_glib; I'm not
familiar enough with that binding to add an optional parameter in an idiomatic
fashion (@mrkn ?)
Closes #6732 from bkietz/8216-Filtering-returns-all-mis
Lead-authored-by: Benjamin Kietzman <[email protected]>
Co-authored-by: Kenta Murata <[email protected]>
Co-authored-by: Neal Richardson <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
---
c_glib/arrow-glib/compute.cpp | 318 ++++++++++++++++++++---
c_glib/arrow-glib/compute.h | 37 +++
c_glib/arrow-glib/compute.hpp | 3 +
c_glib/test/test-compare.rb | 7 +
c_glib/test/test-count.rb | 7 +
c_glib/test/test-filter.rb | 102 +++++++-
cpp/src/arrow/array/diff_test.cc | 52 ++--
cpp/src/arrow/compute/kernel.h | 13 +
cpp/src/arrow/compute/kernels/filter.cc | 261 +++++++++----------
cpp/src/arrow/compute/kernels/filter.h | 129 ++--------
cpp/src/arrow/compute/kernels/filter_test.cc | 339 ++++++++++++++++---------
cpp/src/arrow/dataset/filter.cc | 6 +-
python/pyarrow/array.pxi | 27 +-
python/pyarrow/includes/libarrow.pxd | 12 +-
python/pyarrow/tests/test_compute.py | 9 +-
r/R/array.R | 12 +-
r/R/arrowExports.R | 24 +-
r/R/chunked-array.R | 8 +-
r/R/dplyr.R | 2 +-
r/R/record-batch.R | 6 +-
r/R/table.R | 8 +-
r/man/ChunkedArray.Rd | 2 +-
r/man/RecordBatch.Rd | 2 +-
r/man/Table.Rd | 2 +-
r/man/array.Rd | 2 +-
r/src/arrowExports.cpp | 66 ++---
r/src/compute.cpp | 80 ++++--
r/tests/testthat/test-dplyr.R | 35 ++-
ruby/red-arrow/lib/arrow/generic-filterable.rb | 10 +-
ruby/red-arrow/lib/arrow/table.rb | 4 +-
ruby/red-arrow/test/test-array.rb | 8 +-
ruby/red-arrow/test/test-chunked-array.rb | 8 +-
ruby/red-arrow/test/test-table.rb | 12 +-
33 files changed, 1073 insertions(+), 540 deletions(-)
diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp
index b31855b..3557078 100644
--- a/c_glib/arrow-glib/compute.cpp
+++ b/c_glib/arrow-glib/compute.cpp
@@ -122,6 +122,14 @@ G_DEFINE_TYPE_WITH_PRIVATE(GArrowCastOptions,
GARROW_CAST_OPTIONS(object)))
static void
+garrow_cast_options_finalize(GObject *object)
+{
+ auto priv = GARROW_CAST_OPTIONS_GET_PRIVATE(object);
+ priv->options.~CastOptions();
+ G_OBJECT_CLASS(garrow_cast_options_parent_class)->finalize(object);
+}
+
+static void
garrow_cast_options_set_property(GObject *object,
guint prop_id,
const GValue *value,
@@ -178,6 +186,8 @@ garrow_cast_options_get_property(GObject *object,
static void
garrow_cast_options_init(GArrowCastOptions *object)
{
+ auto priv = GARROW_CAST_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options) arrow::compute::CastOptions;
}
static void
@@ -185,6 +195,7 @@ garrow_cast_options_class_init(GArrowCastOptionsClass
*klass)
{
auto gobject_class = G_OBJECT_CLASS(klass);
+ gobject_class->finalize = garrow_cast_options_finalize;
gobject_class->set_property = garrow_cast_options_set_property;
gobject_class->get_property = garrow_cast_options_get_property;
@@ -279,6 +290,14 @@ G_DEFINE_TYPE_WITH_PRIVATE(GArrowCountOptions,
GARROW_COUNT_OPTIONS(object)))
static void
+garrow_count_options_finalize(GObject *object)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+ priv->options.~CountOptions();
+ G_OBJECT_CLASS(garrow_count_options_parent_class)->finalize(object);
+}
+
+static void
garrow_count_options_set_property(GObject *object,
guint prop_id,
const GValue *value,
@@ -318,6 +337,8 @@ garrow_count_options_get_property(GObject *object,
static void
garrow_count_options_init(GArrowCountOptions *object)
{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options)
arrow::compute::CountOptions(arrow::compute::CountOptions::COUNT_ALL);
}
static void
@@ -325,6 +346,7 @@ garrow_count_options_class_init(GArrowCountOptionsClass
*klass)
{
auto gobject_class = G_OBJECT_CLASS(klass);
+ gobject_class->finalize = garrow_count_options_finalize;
gobject_class->set_property = garrow_count_options_set_property;
gobject_class->get_property = garrow_count_options_get_property;
@@ -360,6 +382,119 @@ garrow_count_options_new(void)
}
+typedef struct GArrowFilterOptionsPrivate_ {
+ arrow::compute::FilterOptions options;
+} GArrowFilterOptionsPrivate;
+
+enum {
+ PROP_NULL_SELECTION_BEHAVIOR = 1,
+};
+
+G_DEFINE_TYPE_WITH_PRIVATE(GArrowFilterOptions,
+ garrow_filter_options,
+ G_TYPE_OBJECT)
+
+#define GARROW_FILTER_OPTIONS_GET_PRIVATE(object) \
+ static_cast<GArrowFilterOptionsPrivate *>( \
+ garrow_filter_options_get_instance_private( \
+ GARROW_FILTER_OPTIONS(object)))
+
+static void
+garrow_filter_options_finalize(GObject *object)
+{
+ auto priv = GARROW_FILTER_OPTIONS_GET_PRIVATE(object);
+ priv->options.~FilterOptions();
+ G_OBJECT_CLASS(garrow_filter_options_parent_class)->finalize(object);
+}
+
+static void
+garrow_filter_options_set_property(GObject *object,
+ guint prop_id,
+ const GValue *value,
+ GParamSpec *pspec)
+{
+ auto priv = GARROW_FILTER_OPTIONS_GET_PRIVATE(object);
+
+ switch (prop_id) {
+ case PROP_NULL_SELECTION_BEHAVIOR:
+ priv->options.null_selection_behavior =
+
static_cast<arrow::compute::FilterOptions::NullSelectionBehavior>(g_value_get_enum(value));
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
+ break;
+ }
+}
+
+static void
+garrow_filter_options_get_property(GObject *object,
+ guint prop_id,
+ GValue *value,
+ GParamSpec *pspec)
+{
+ auto priv = GARROW_FILTER_OPTIONS_GET_PRIVATE(object);
+
+ switch (prop_id) {
+ case PROP_NULL_SELECTION_BEHAVIOR:
+ g_value_set_enum(value, priv->options.null_selection_behavior);
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
+ break;
+ }
+}
+
+static void
+garrow_filter_options_init(GArrowFilterOptions *object)
+{
+ auto priv = GARROW_FILTER_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options) arrow::compute::FilterOptions;
+}
+
+static void
+garrow_filter_options_class_init(GArrowFilterOptionsClass *klass)
+{
+ auto gobject_class = G_OBJECT_CLASS(klass);
+
+ gobject_class->finalize = garrow_filter_options_finalize;
+ gobject_class->set_property = garrow_filter_options_set_property;
+ gobject_class->get_property = garrow_filter_options_get_property;
+
+ arrow::compute::FilterOptions default_options;
+
+ GParamSpec *spec;
+ /**
+ * GArrowFilterOptions:null_selection_behavior:
+ *
+ * How to handle values whose filter slot is null.
+ *
+ * Since: 1.0.0
+ */
+ spec = g_param_spec_enum("null_selection_behavior",
+ "Null selection behavior",
+ "How to handle filtered values",
+ GARROW_TYPE_FILTER_NULL_SELECTION_BEHAVIOR,
+ static_cast<GArrowFilterNullSelectionBehavior>(
+ default_options.null_selection_behavior),
+ static_cast<GParamFlags>(G_PARAM_READWRITE));
+ g_object_class_install_property(gobject_class, PROP_NULL_SELECTION_BEHAVIOR,
spec);
+}
+
+/**
+ * garrow_filter_options_new:
+ *
+ * Returns: A newly created #GArrowFilterOptions.
+ *
+ * Since: 1.0.0
+ */
+GArrowFilterOptions *
+garrow_filter_options_new(void)
+{
+ auto filter_options = g_object_new(GARROW_TYPE_FILTER_OPTIONS, NULL);
+ return GARROW_FILTER_OPTIONS(filter_options);
+}
+
+
typedef struct GArrowTakeOptionsPrivate_ {
arrow::compute::TakeOptions options;
} GArrowTakeOptionsPrivate;
@@ -374,13 +509,26 @@ G_DEFINE_TYPE_WITH_PRIVATE(GArrowTakeOptions,
GARROW_TAKE_OPTIONS(object)))
static void
+garrow_take_options_finalize(GObject *object)
+{
+ auto priv = GARROW_TAKE_OPTIONS_GET_PRIVATE(object);
+ priv->options.~TakeOptions();
+ G_OBJECT_CLASS(garrow_take_options_parent_class)->finalize(object);
+}
+
+static void
garrow_take_options_init(GArrowTakeOptions *object)
{
+ auto priv = GARROW_TAKE_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options) arrow::compute::TakeOptions;
}
static void
garrow_take_options_class_init(GArrowTakeOptionsClass *klass)
{
+ auto gobject_class = G_OBJECT_CLASS(klass);
+
+ gobject_class->finalize = garrow_take_options_finalize;
}
/**
@@ -416,6 +564,14 @@ G_DEFINE_TYPE_WITH_PRIVATE(GArrowCompareOptions,
GARROW_COMPARE_OPTIONS(object)))
static void
+garrow_compare_options_finalize(GObject *object)
+{
+ auto priv = GARROW_COMPARE_OPTIONS_GET_PRIVATE(object);
+ priv->options.~CompareOptions();
+ G_OBJECT_CLASS(garrow_compare_options_parent_class)->finalize(object);
+}
+
+static void
garrow_compare_options_set_property(GObject *object,
guint prop_id,
const GValue *value,
@@ -455,6 +611,8 @@ garrow_compare_options_get_property(GObject *object,
static void
garrow_compare_options_init(GArrowCompareOptions *object)
{
+ auto priv = GARROW_COMPARE_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options) arrow::compute::CompareOptions(arrow::compute::EQUAL);
}
static void
@@ -462,6 +620,7 @@ garrow_compare_options_class_init(GArrowCompareOptionsClass
*klass)
{
auto gobject_class = G_OBJECT_CLASS(klass);
+ gobject_class->finalize = garrow_compare_options_finalize;
gobject_class->set_property = garrow_compare_options_set_property;
gobject_class->get_property = garrow_compare_options_get_property;
@@ -1679,6 +1838,7 @@ garrow_double_array_compare(GArrowDoubleArray *array,
* garrow_array_filter:
* @array: A #GArrowArray.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowArray filterd
@@ -1690,20 +1850,32 @@ garrow_double_array_compare(GArrowDoubleArray *array,
GArrowArray *
garrow_array_filter(GArrowArray *array,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_array = garrow_array_get_raw(array);
- auto arrow_array_raw = arrow_array.get();
auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter));
- auto arrow_filter_raw = arrow_filter.get();
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::Array> arrow_filtered_array;
- auto status = arrow::compute::Filter(&context,
- *arrow_array_raw,
- *arrow_filter_raw,
- &arrow_filtered_array);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_array,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_array,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status, "[array][filter]")) {
+ auto arrow_filtered_array = arrow_filtered.make_array();
return garrow_array_new_raw(&arrow_filtered_array);
} else {
return NULL;
@@ -1815,6 +1987,7 @@ garrow_array_sort_to_indices(GArrowArray *array,
* garrow_table_filter:
* @table: A #GArrowTable.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowTable filterd
@@ -1826,18 +1999,32 @@ garrow_array_sort_to_indices(GArrowArray *array,
GArrowTable *
garrow_table_filter(GArrowTable *table,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_table = garrow_table_get_raw(table);
auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter));
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::Table> arrow_filtered_table;
- auto status = arrow::compute::Filter(&context,
- *arrow_table,
- *arrow_filter,
- &arrow_filtered_table);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_table,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_table,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status, "[table][filter]")) {
+ auto arrow_filtered_table = arrow_filtered.table();
return garrow_table_new_raw(&arrow_filtered_table);
} else {
return NULL;
@@ -1848,6 +2035,7 @@ garrow_table_filter(GArrowTable *table,
* garrow_table_filter_chunked_array:
* @table: A #GArrowTable.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowTable filterd
@@ -1859,18 +2047,32 @@ garrow_table_filter(GArrowTable *table,
GArrowTable *
garrow_table_filter_chunked_array(GArrowTable *table,
GArrowChunkedArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_table = garrow_table_get_raw(table);
auto arrow_filter = garrow_chunked_array_get_raw(filter);
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::Table> arrow_filtered_table;
- auto status = arrow::compute::Filter(&context,
- *arrow_table,
- *arrow_filter,
- &arrow_filtered_table);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_table,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_table,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status, "[table][filter][chunked-array]")) {
+ auto arrow_filtered_table = arrow_filtered.table();
return garrow_table_new_raw(&arrow_filtered_table);
} else {
return NULL;
@@ -1881,6 +2083,7 @@ garrow_table_filter_chunked_array(GArrowTable *table,
* garrow_chunked_array_filter:
* @chunked_array: A #GArrowChunkedArray.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowChunkedArray filterd
@@ -1892,6 +2095,7 @@ garrow_table_filter_chunked_array(GArrowTable *table,
GArrowChunkedArray *
garrow_chunked_array_filter(GArrowChunkedArray *chunked_array,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_chunked_array =
@@ -1899,12 +2103,25 @@ garrow_chunked_array_filter(GArrowChunkedArray
*chunked_array,
auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter));
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::ChunkedArray> arrow_filtered_chunked_array;
- auto status = arrow::compute::Filter(&context,
- *arrow_chunked_array,
- *arrow_filter,
- &arrow_filtered_chunked_array);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_chunked_array,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_chunked_array,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status, "[chunked-array][filter]")) {
+ auto arrow_filtered_chunked_array = arrow_filtered.chunked_array();
return garrow_chunked_array_new_raw(&arrow_filtered_chunked_array);
} else {
return NULL;
@@ -1915,6 +2132,7 @@ garrow_chunked_array_filter(GArrowChunkedArray
*chunked_array,
* garrow_chunked_array_filter_chunked_array:
* @chunked_array: A #GArrowChunkedArray.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowChunkedArray filterd
@@ -1926,6 +2144,7 @@ garrow_chunked_array_filter(GArrowChunkedArray
*chunked_array,
GArrowChunkedArray *
garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array,
GArrowChunkedArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_chunked_array =
@@ -1933,12 +2152,25 @@
garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array,
auto arrow_filter = garrow_chunked_array_get_raw(filter);
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::ChunkedArray> arrow_filtered_chunked_array;
- auto status = arrow::compute::Filter(&context,
- *arrow_chunked_array,
- *arrow_filter,
- &arrow_filtered_chunked_array);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_chunked_array,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_chunked_array,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status,
"[chunked-array][filter][chunked-array]")) {
+ auto arrow_filtered_chunked_array = arrow_filtered.chunked_array();
return garrow_chunked_array_new_raw(&arrow_filtered_chunked_array);
} else {
return NULL;
@@ -1949,6 +2181,7 @@
garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array,
* garrow_record_batch_filter:
* @record_batch: A #GArrowRecordBatch.
* @filter: The values indicates which values should be filtered out.
+ * @options: (nullable): A #GArrowFilterOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: (nullable) (transfer full): The #GArrowRecordBatch filterd
@@ -1960,6 +2193,7 @@
garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array,
GArrowRecordBatch *
garrow_record_batch_filter(GArrowRecordBatch *record_batch,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error)
{
auto arrow_record_batch =
@@ -1967,12 +2201,25 @@ garrow_record_batch_filter(GArrowRecordBatch
*record_batch,
auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter));
auto memory_pool = arrow::default_memory_pool();
arrow::compute::FunctionContext context(memory_pool);
- std::shared_ptr<arrow::RecordBatch> arrow_filtered_record_batch;
- auto status = arrow::compute::Filter(&context,
- *arrow_record_batch,
- *arrow_filter,
- &arrow_filtered_record_batch);
+ arrow::compute::Datum arrow_filtered;
+ arrow::Status status;
+ if (options) {
+ auto arrow_options = garrow_filter_options_get_raw(options);
+ status = arrow::compute::Filter(&context,
+ arrow_record_batch,
+ arrow_filter,
+ *arrow_options,
+ &arrow_filtered);
+ } else {
+ arrow::compute::FilterOptions arrow_options;
+ status = arrow::compute::Filter(&context,
+ arrow_record_batch,
+ arrow_filter,
+ arrow_options,
+ &arrow_filtered);
+ }
if (garrow_error_check(error, status, "[record-batch][filter]")) {
+ auto arrow_filtered_record_batch = arrow_filtered.record_batch();
return garrow_record_batch_new_raw(&arrow_filtered_record_batch);
} else {
return NULL;
@@ -2018,6 +2265,13 @@ garrow_count_options_get_raw(GArrowCountOptions
*count_options)
return &(priv->options);
}
+arrow::compute::FilterOptions *
+garrow_filter_options_get_raw(GArrowFilterOptions *filter_options)
+{
+ auto priv = GARROW_FILTER_OPTIONS_GET_PRIVATE(filter_options);
+ return &(priv->options);
+}
+
arrow::compute::TakeOptions *
garrow_take_options_get_raw(GArrowTakeOptions *take_options)
{
diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h
index 9354bf2..82175c1 100644
--- a/c_glib/arrow-glib/compute.h
+++ b/c_glib/arrow-glib/compute.h
@@ -68,6 +68,37 @@ GArrowCountOptions *
garrow_count_options_new(void);
+/**
+ * GArrowFilterNullSelectionBehavior:
+ * @GARROW_FILTER_NULL_SELECTION_DROP:
+ * A value whose filter slot is null is dropped from the output.
+ * @GARROW_FILTER_NULL_SELECTION_EMIT_NULL:
+ * A value whose filter slot is null is emitted as null in the output.
+ *
+ * They correspond to the
+ * `arrow::compute::FilterOptions::NullSelectionBehavior` values.
+ */
+typedef enum {
+ GARROW_FILTER_NULL_SELECTION_DROP,
+ GARROW_FILTER_NULL_SELECTION_EMIT_NULL,
+} GArrowFilterNullSelectionBehavior;
+
+#define GARROW_TYPE_FILTER_OPTIONS (garrow_filter_options_get_type())
+G_DECLARE_DERIVABLE_TYPE(GArrowFilterOptions,
+ garrow_filter_options,
+ GARROW,
+ FILTER_OPTIONS,
+ GObject)
+struct _GArrowFilterOptionsClass
+{
+ GObjectClass parent_class;
+};
+
+GARROW_AVAILABLE_IN_1_0
+GArrowFilterOptions *
+garrow_filter_options_new(void);
+
+
#define GARROW_TYPE_TAKE_OPTIONS (garrow_take_options_get_type())
G_DECLARE_DERIVABLE_TYPE(GArrowTakeOptions,
garrow_take_options,
@@ -291,6 +322,7 @@ GARROW_AVAILABLE_IN_0_15
GArrowArray *
garrow_array_filter(GArrowArray *array,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error);
GARROW_AVAILABLE_IN_0_15
GArrowBooleanArray *
@@ -310,26 +342,31 @@ GARROW_AVAILABLE_IN_1_0
GArrowTable *
garrow_table_filter(GArrowTable *table,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error);
GARROW_AVAILABLE_IN_1_0
GArrowTable *
garrow_table_filter_chunked_array(GArrowTable *table,
GArrowChunkedArray *filter,
+ GArrowFilterOptions *options,
GError **error);
GARROW_AVAILABLE_IN_1_0
GArrowChunkedArray *
garrow_chunked_array_filter(GArrowChunkedArray *chunked_array,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error);
GARROW_AVAILABLE_IN_1_0
GArrowChunkedArray *
garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array,
GArrowChunkedArray *filter,
+ GArrowFilterOptions *options,
GError **error);
GARROW_AVAILABLE_IN_1_0
GArrowRecordBatch *
garrow_record_batch_filter(GArrowRecordBatch *record_batch,
GArrowBooleanArray *filter,
+ GArrowFilterOptions *options,
GError **error);
G_END_DECLS
diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp
index 1251225..fe5022e 100644
--- a/c_glib/arrow-glib/compute.hpp
+++ b/c_glib/arrow-glib/compute.hpp
@@ -31,6 +31,9 @@ garrow_count_options_new_raw(arrow::compute::CountOptions
*arrow_count_options);
arrow::compute::CountOptions *
garrow_count_options_get_raw(GArrowCountOptions *count_options);
+arrow::compute::FilterOptions *
+garrow_filter_options_get_raw(GArrowFilterOptions *filter_options);
+
arrow::compute::TakeOptions *
garrow_take_options_get_raw(GArrowTakeOptions *take_options);
diff --git a/c_glib/test/test-compare.rb b/c_glib/test/test-compare.rb
index dcf18c8..2ffe398 100644
--- a/c_glib/test/test-compare.rb
+++ b/c_glib/test/test-compare.rb
@@ -22,6 +22,13 @@ class TestCompare < Test::Unit::TestCase
@options = Arrow::CompareOptions.new
end
+ sub_test_case("CompareOptions") do
+ def test_default_operator
+ assert_equal(Arrow::CompareOperator::EQUAL,
+ @options.operator)
+ end
+ end
+
sub_test_case("operator") do
def test_equal
@options.operator = :equal
diff --git a/c_glib/test/test-count.rb b/c_glib/test/test-count.rb
index a94853a..36390f8 100644
--- a/c_glib/test/test-count.rb
+++ b/c_glib/test/test-count.rb
@@ -19,6 +19,13 @@ class TestCount < Test::Unit::TestCase
include Helper::Buildable
include Helper::Omittable
+ sub_test_case("CountOptions") do
+ def test_default_mode
+ assert_equal(Arrow::CountMode::ALL,
+ Arrow::CountOptions.new.mode)
+ end
+ end
+
sub_test_case("mode") do
def test_default
assert_equal(2, build_int32_array([1, nil, 3]).count)
diff --git a/c_glib/test/test-filter.rb b/c_glib/test/test-filter.rb
index b099847..5ed0359 100644
--- a/c_glib/test/test-filter.rb
+++ b/c_glib/test/test-filter.rb
@@ -18,11 +18,28 @@
class TestFilter < Test::Unit::TestCase
include Helper::Buildable
+ sub_test_case("FilterOptions") do
+ def test_default_null_selection_behavior
+ assert_equal(Arrow::FilterNullSelectionBehavior::DROP,
+ Arrow::FilterOptions.new.null_selection_behavior)
+ end
+ end
+
sub_test_case("Array") do
+ def setup
+ @filter = build_boolean_array([false, true, true, nil])
+ end
+
def test_filter
- filter = build_boolean_array([false, true, true, nil])
+ assert_equal(build_int16_array([1, 0]),
+ build_int16_array([0, 1, 0, 2]).filter(@filter))
+ end
+
+ def test_filter_emit_null
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
assert_equal(build_int16_array([1, 0, nil]),
- build_int16_array([0, 1, 0, 2]).filter(filter))
+ build_int16_array([0, 1, 0, 2]).filter(@filter, options))
end
def test_invalid_array_length
@@ -50,12 +67,25 @@ class TestFilter < Test::Unit::TestCase
def test_filter
filter = build_boolean_array([false, true, nil])
arrays = [
+ build_boolean_array([false]),
+ build_boolean_array([true]),
+ ]
+ filtered_table = Arrow::Table.new(@schema, arrays)
+ assert_equal(filtered_table,
+ @table.filter(filter))
+ end
+
+ def test_filter_emit_null
+ filter = build_boolean_array([false, true, nil])
+ arrays = [
build_boolean_array([false, nil]),
build_boolean_array([true, nil]),
]
filtered_table = Arrow::Table.new(@schema, arrays)
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
assert_equal(filtered_table,
- @table.filter(filter))
+ @table.filter(filter, options))
end
def test_filter_chunked_array
@@ -65,12 +95,29 @@ class TestFilter < Test::Unit::TestCase
]
filter = Arrow::ChunkedArray.new(chunks)
arrays = [
+ build_boolean_array([false]),
+ build_boolean_array([true]),
+ ]
+ filtered_table = Arrow::Table.new(@schema, arrays)
+ assert_equal(filtered_table,
+ @table.filter_chunked_array(filter))
+ end
+
+ def test_filter_chunked_array_emit_null
+ chunks = [
+ build_boolean_array([false]),
+ build_boolean_array([true, nil]),
+ ]
+ filter = Arrow::ChunkedArray.new(chunks)
+ arrays = [
build_boolean_array([false, nil]),
build_boolean_array([true, nil]),
]
filtered_table = Arrow::Table.new(@schema, arrays)
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
assert_equal(filtered_table,
- @table.filter_chunked_array(filter))
+ @table.filter_chunked_array(filter, options))
end
def test_invalid_array_length
@@ -94,13 +141,25 @@ class TestFilter < Test::Unit::TestCase
filter = build_boolean_array([false, true, nil])
chunks = [
build_boolean_array([false]),
- build_boolean_array([nil]),
]
filtered_chunked_array = Arrow::ChunkedArray.new(chunks)
assert_equal(filtered_chunked_array,
@chunked_array.filter(filter))
end
+ def test_filter_emit_null
+ filter = build_boolean_array([false, true, nil])
+ chunks = [
+ build_boolean_array([false]),
+ build_boolean_array([nil]),
+ ]
+ filtered_chunked_array = Arrow::ChunkedArray.new(chunks)
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
+ assert_equal(filtered_chunked_array,
+ @chunked_array.filter(filter, options))
+ end
+
def test_filter_chunked_array
chunks = [
build_boolean_array([false]),
@@ -109,13 +168,29 @@ class TestFilter < Test::Unit::TestCase
filter = Arrow::ChunkedArray.new(chunks)
filtered_chunks = [
build_boolean_array([false]),
- build_boolean_array([nil]),
]
filtered_chunked_array = Arrow::ChunkedArray.new(filtered_chunks)
assert_equal(filtered_chunked_array,
@chunked_array.filter_chunked_array(filter))
end
+ def test_filter_chunked_array_emit_null
+ chunks = [
+ build_boolean_array([false]),
+ build_boolean_array([true, nil]),
+ ]
+ filter = Arrow::ChunkedArray.new(chunks)
+ filtered_chunks = [
+ build_boolean_array([false]),
+ build_boolean_array([nil]),
+ ]
+ filtered_chunked_array = Arrow::ChunkedArray.new(filtered_chunks)
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
+ assert_equal(filtered_chunked_array,
+ @chunked_array.filter_chunked_array(filter, options))
+ end
+
def test_invalid_array_length
filter = build_boolean_array([false, true, true, false])
assert_raise(Arrow::Error::Invalid) do
@@ -141,12 +216,25 @@ class TestFilter < Test::Unit::TestCase
def test_filter
filter = build_boolean_array([false, true, nil])
columns = [
+ build_boolean_array([false]),
+ build_boolean_array([true]),
+ ]
+ filtered_record_batch = Arrow::RecordBatch.new(@schema, 1, columns)
+ assert_equal(filtered_record_batch,
+ @record_batch.filter(filter))
+ end
+
+ def test_filter_emit_null
+ filter = build_boolean_array([false, true, nil])
+ columns = [
build_boolean_array([false, nil]),
build_boolean_array([true, nil]),
]
filtered_record_batch = Arrow::RecordBatch.new(@schema, 2, columns)
+ options = Arrow::FilterOptions.new
+ options.null_selection_behavior = :emit_null
assert_equal(filtered_record_batch,
- @record_batch.filter(filter))
+ @record_batch.filter(filter, options))
end
def test_invalid_array_length
diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc
index a452ec6..2e71333 100644
--- a/cpp/src/arrow/array/diff_test.cc
+++ b/cpp/src/arrow/array/diff_test.cc
@@ -117,6 +117,22 @@ class DiffTest : public ::testing::Test {
ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(int64(), run_lengths_json),
*run_lengths_);
}
+ void BaseAndTargetFromRandomFilter(std::shared_ptr<Array> values,
+ double filter_probability) {
+ compute::Datum out_datum, base_filter, target_filter;
+ do {
+ base_filter = this->rng_.Boolean(values->length(), filter_probability,
0.0);
+ target_filter = this->rng_.Boolean(values->length(), filter_probability,
0.0);
+ } while (base_filter.Equals(target_filter));
+
+ ASSERT_OK(compute::Filter(&ctx_, values, base_filter, {}, &out_datum));
+ base_ = out_datum.make_array();
+
+ ASSERT_OK(compute::Filter(&ctx_, values, target_filter, {}, &out_datum));
+ target_ = out_datum.make_array();
+ }
+
+ compute::FunctionContext ctx_;
random::RandomArrayGenerator rng_;
std::shared_ptr<StructArray> edits_;
std::shared_ptr<Array> base_, target_;
@@ -210,15 +226,10 @@ TYPED_TEST(DiffTestWithNumeric, Basics) {
}
TEST_F(DiffTest, CompareRandomInt64) {
- compute::FunctionContext ctx;
for (auto null_probability : {0.0, 0.25}) {
auto values = this->rng_.Int64(1 << 10, 0, 127, null_probability);
for (const double filter_probability : {0.99, 0.75, 0.5}) {
- auto filter_1 = this->rng_.Boolean(values->length(), filter_probability,
0.0);
- auto filter_2 = this->rng_.Boolean(values->length(), filter_probability,
0.0);
-
- ASSERT_OK(compute::Filter(&ctx, *values, *filter_1, &this->base_));
- ASSERT_OK(compute::Filter(&ctx, *values, *filter_2, &this->target_));
+ this->BaseAndTargetFromRandomFilter(values, filter_probability);
std::stringstream formatted;
this->DoDiffAndFormat(&formatted);
@@ -231,15 +242,10 @@ TEST_F(DiffTest, CompareRandomInt64) {
}
TEST_F(DiffTest, CompareRandomStrings) {
- compute::FunctionContext ctx;
for (auto null_probability : {0.0, 0.25}) {
auto values = this->rng_.StringWithRepeats(1 << 10, 1 << 8, 0, 32,
null_probability);
for (const double filter_probability : {0.99, 0.75, 0.5}) {
- auto filter_1 = this->rng_.Boolean(values->length(), filter_probability,
0.0);
- auto filter_2 = this->rng_.Boolean(values->length(), filter_probability,
0.0);
-
- ASSERT_OK(compute::Filter(&ctx, *values, *filter_1, &this->base_));
- ASSERT_OK(compute::Filter(&ctx, *values, *filter_2, &this->target_));
+ this->BaseAndTargetFromRandomFilter(values, filter_probability);
std::stringstream formatted;
this->DoDiffAndFormat(&formatted);
@@ -614,21 +620,15 @@ TEST_F(DiffTest, CompareRandomStruct) {
auto int32_values = this->rng_.Int32(length, 0, 127, null_probability);
auto utf8_values = this->rng_.String(length, 0, 16, null_probability);
for (const double filter_probability : {0.9999, 0.75}) {
- std::shared_ptr<Array> int32_base, int32_target, utf8_base, utf8_target;
- ASSERT_OK(compute::Filter(&ctx, *int32_values,
- *this->rng_.Boolean(length,
filter_probability, 0.0),
- &int32_base));
- ASSERT_OK(compute::Filter(&ctx, *utf8_values,
- *this->rng_.Boolean(length,
filter_probability, 0.0),
- &utf8_base));
- MakeSameLength(&int32_base, &utf8_base);
+ this->BaseAndTargetFromRandomFilter(int32_values, filter_probability);
+ auto int32_base = this->base_;
+ auto int32_target = this->base_;
- ASSERT_OK(compute::Filter(&ctx, *int32_values,
- *this->rng_.Boolean(length,
filter_probability, 0.0),
- &int32_target));
- ASSERT_OK(compute::Filter(&ctx, *utf8_values,
- *this->rng_.Boolean(length,
filter_probability, 0.0),
- &utf8_target));
+ this->BaseAndTargetFromRandomFilter(utf8_values, filter_probability);
+ auto utf8_base = this->base_;
+ auto utf8_target = this->base_;
+
+ MakeSameLength(&int32_base, &utf8_base);
MakeSameLength(&int32_target, &utf8_target);
auto type = struct_({field("i", int32()), field("s", utf8())});
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index 1169eb4..16dca69 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -219,6 +219,19 @@ struct ARROW_EXPORT Datum {
return kUnknownLength;
}
+ /// \brief The array chunks of the variant, if any
+ ///
+ /// \return empty if not arraylike
+ ArrayVector chunks() const {
+ if (!this->is_arraylike()) {
+ return {};
+ }
+ if (this->is_array()) {
+ return {this->make_array()};
+ }
+ return this->chunked_array()->chunks();
+ }
+
bool Equals(const Datum& other) const {
if (this->kind() != other.kind()) return false;
diff --git a/cpp/src/arrow/compute/kernels/filter.cc
b/cpp/src/arrow/compute/kernels/filter.cc
index 1b2ba31..4c4919a 100644
--- a/cpp/src/arrow/compute/kernels/filter.cc
+++ b/cpp/src/arrow/compute/kernels/filter.cc
@@ -38,6 +38,7 @@ using internal::checked_pointer_cast;
// IndexSequence which yields the indices of positions in a BooleanArray
// which are either null or true
+template <FilterOptions::NullSelectionBehavior NullSelectionBehavior>
class FilterIndexSequence {
public:
// constexpr so we'll never instantiate bounds checking
@@ -50,6 +51,14 @@ class FilterIndexSequence {
: filter_(&filter), out_length_(out_length) {}
std::pair<int64_t, bool> Next() {
+ if (NullSelectionBehavior == FilterOptions::DROP) {
+ // skip until an index is found at which the filter is true
+ while (filter_->IsNull(index_) || !filter_->Value(index_)) {
+ ++index_;
+ }
+ return std::make_pair(index_++, true);
+ }
+
// skip until an index is found at which the filter is either null or true
while (filter_->IsValid(index_) && !filter_->Value(index_)) {
++index_;
@@ -60,194 +69,188 @@ class FilterIndexSequence {
int64_t length() const { return out_length_; }
- int64_t null_count() const { return filter_->null_count(); }
+ int64_t null_count() const {
+ if (NullSelectionBehavior == FilterOptions::DROP) {
+ return 0;
+ }
+ return filter_->null_count();
+ }
private:
const BooleanArray* filter_ = nullptr;
int64_t index_ = 0, out_length_ = -1;
};
-// TODO(bkietz) this can be optimized
-static int64_t OutputSize(const BooleanArray& filter) {
+static int64_t OutputSize(FilterOptions options, const BooleanArray& filter) {
+ // TODO(bkietz) this can be optimized. Use Bitmap::VisitWords
int64_t size = 0;
- for (auto i = 0; i < filter.length(); ++i) {
- if (filter.IsNull(i) || filter.Value(i)) {
- ++size;
+ if (options.null_selection_behavior == FilterOptions::EMIT_NULL) {
+ for (auto i = 0; i < filter.length(); ++i) {
+ if (filter.IsNull(i) || filter.Value(i)) {
+ ++size;
+ }
+ }
+ } else {
+ for (auto i = 0; i < filter.length(); ++i) {
+ if (filter.IsValid(i) && filter.Value(i)) {
+ ++size;
+ }
}
}
return size;
}
-static Result<std::shared_ptr<BooleanArray>> GetFilterArray(const Datum&
filter) {
- auto filter_type = filter.type();
- if (filter_type->id() != Type::BOOL) {
- return Status::TypeError("filter array must be of boolean type, got ",
*filter_type);
+static Status CheckFilterType(const std::shared_ptr<DataType>& type) {
+ if (type->id() != Type::BOOL) {
+ return Status::TypeError("filter array must be of boolean type, got ",
*type);
}
- return checked_pointer_cast<BooleanArray>(filter.make_array());
+ return Status::OK();
+}
+
+static Status CheckFilterValuesLengths(int64_t values, int64_t filter) {
+ if (values != filter) {
+ return Status::Invalid("filter and value array must have identical
lengths");
+ }
+ return Status::OK();
}
+template <typename IndexSequence>
class FilterKernelImpl : public FilterKernel {
public:
- FilterKernelImpl(const std::shared_ptr<DataType>& type,
- std::unique_ptr<Taker<FilterIndexSequence>> taker)
- : FilterKernel(type), taker_(std::move(taker)) {}
+ FilterKernelImpl(std::shared_ptr<DataType> type,
+ std::unique_ptr<Taker<IndexSequence>> taker, FilterOptions
options)
+ : FilterKernel(std::move(type), options), taker_(std::move(taker)) {}
Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray&
filter,
int64_t out_length, std::shared_ptr<Array>* out) override {
- if (values.length() != filter.length()) {
- return Status::Invalid("filter and value array must have identical
lengths");
- }
+ RETURN_NOT_OK(CheckFilterValuesLengths(values.length(), filter.length()));
+
RETURN_NOT_OK(taker_->SetContext(ctx));
- RETURN_NOT_OK(taker_->Take(values, FilterIndexSequence(filter,
out_length)));
+ RETURN_NOT_OK(taker_->Take(values, IndexSequence(filter, out_length)));
return taker_->Finish(out);
}
- std::unique_ptr<Taker<FilterIndexSequence>> taker_;
+ static Status Make(std::shared_ptr<DataType> value_type, FilterOptions
options,
+ std::unique_ptr<FilterKernel>* out) {
+ std::unique_ptr<Taker<IndexSequence>> taker;
+ RETURN_NOT_OK(Taker<IndexSequence>::Make(value_type, &taker));
+
+ out->reset(new FilterKernelImpl(std::move(value_type), std::move(taker),
options));
+ return Status::OK();
+ }
+
+ std::unique_ptr<Taker<IndexSequence>> taker_;
};
-Status FilterKernel::Make(const std::shared_ptr<DataType>& value_type,
+Status FilterKernel::Make(std::shared_ptr<DataType> value_type, FilterOptions
options,
std::unique_ptr<FilterKernel>* out) {
- std::unique_ptr<Taker<FilterIndexSequence>> taker;
- RETURN_NOT_OK(Taker<FilterIndexSequence>::Make(value_type, &taker));
-
- out->reset(new FilterKernelImpl(value_type, std::move(taker)));
- return Status::OK();
+ if (options.null_selection_behavior == FilterOptions::EMIT_NULL) {
+ return
FilterKernelImpl<FilterIndexSequence<FilterOptions::EMIT_NULL>>::Make(
+ std::move(value_type), options, out);
+ }
+ return FilterKernelImpl<FilterIndexSequence<FilterOptions::DROP>>::Make(
+ std::move(value_type), options, out);
}
Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const
Datum& filter,
Datum* out) {
- if (!values.is_array() || !filter.is_array()) {
- return Status::Invalid("FilterKernel::Call expects array values and
filter");
+ if (!values.is_arraylike() || !filter.is_arraylike()) {
+ return Status::Invalid("FilterKernel::Call expects array-like values and
filter");
}
- auto values_array = values.make_array();
- ARROW_ASSIGN_OR_RAISE(auto filter_array, GetFilterArray(filter));
- std::shared_ptr<Array> out_array;
- RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array,
OutputSize(*filter_array),
- &out_array));
- *out = out_array;
- return Status::OK();
-}
+ RETURN_NOT_OK(CheckFilterType(filter.type()));
+ RETURN_NOT_OK(CheckFilterValuesLengths(values.length(), filter.length()));
-Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
- std::shared_ptr<Array>* out) {
- Datum out_datum;
- RETURN_NOT_OK(Filter(ctx, Datum(values.data()), Datum(filter.data()),
&out_datum));
- *out = out_datum.make_array();
+ auto chunks = internal::RechunkArraysConsistently({values.chunks(),
filter.chunks()});
+ auto value_chunks = std::move(chunks[0]);
+ auto filter_chunks = std::move(chunks[1]);
+
+ for (size_t i = 0; i < value_chunks.size(); ++i) {
+ auto filter_chunk = checked_pointer_cast<BooleanArray>(filter_chunks[i]);
+ RETURN_NOT_OK(this->Filter(ctx, *value_chunks[i], *filter_chunk,
+ OutputSize(options_, *filter_chunk),
&value_chunks[i]));
+ }
+
+ if (values.is_array() && filter.is_array()) {
+ *out = std::move(value_chunks[0]);
+ } else {
+ // drop empty chunks
+ value_chunks.erase(
+ std::remove_if(value_chunks.begin(), value_chunks.end(),
+ [](const std::shared_ptr<Array>& a) { return
a->length() == 0; }),
+ value_chunks.end());
+
+ *out = std::make_shared<ChunkedArray>(std::move(value_chunks),
values.type());
+ }
return Status::OK();
}
-Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter,
- Datum* out) {
- std::unique_ptr<FilterKernel> kernel;
- RETURN_NOT_OK(FilterKernel::Make(values.type(), &kernel));
- return kernel->Call(ctx, values, filter, out);
+Status FilterTable(FunctionContext* ctx, const Table& table, const Datum&
filter,
+ FilterOptions options, std::shared_ptr<Table>* out) {
+ auto new_columns = table.columns();
+
+ for (auto& column : new_columns) {
+ Datum out_column;
+ RETURN_NOT_OK(Filter(ctx, Datum(column), filter, options, &out_column));
+ column = out_column.chunked_array();
+ }
+
+ *out = Table::Make(table.schema(), std::move(new_columns));
+ return Status::OK();
}
-Status Filter(FunctionContext* ctx, const RecordBatch& batch, const Array&
filter,
- std::shared_ptr<RecordBatch>* out) {
- ARROW_ASSIGN_OR_RAISE(auto filter_array,
GetFilterArray(Datum(filter.data())));
+Status FilterRecordBatch(FunctionContext* ctx, const RecordBatch& batch,
+ const Array& filter, FilterOptions options,
+ std::shared_ptr<RecordBatch>* out) {
+ RETURN_NOT_OK(CheckFilterType(filter.type()));
+ const auto& filter_array = checked_cast<const BooleanArray&>(filter);
std::vector<std::unique_ptr<FilterKernel>> kernels(batch.num_columns());
for (int i = 0; i < batch.num_columns(); ++i) {
- RETURN_NOT_OK(FilterKernel::Make(batch.schema()->field(i)->type(),
&kernels[i]));
+ RETURN_NOT_OK(
+ FilterKernel::Make(batch.schema()->field(i)->type(), options,
&kernels[i]));
}
std::vector<std::shared_ptr<Array>> columns(batch.num_columns());
- auto out_length = OutputSize(*filter_array);
+ auto out_length = OutputSize(options, filter_array);
for (int i = 0; i < batch.num_columns(); ++i) {
- RETURN_NOT_OK(kernels[i]->Filter(ctx, *batch.column(i), *filter_array,
out_length,
- &columns[i]));
+ RETURN_NOT_OK(
+ kernels[i]->Filter(ctx, *batch.column(i), filter_array, out_length,
&columns[i]));
}
*out = RecordBatch::Make(batch.schema(), out_length, columns);
return Status::OK();
}
-Status Filter(FunctionContext* ctx, const ChunkedArray& values, const Array&
filter,
- std::shared_ptr<ChunkedArray>* out) {
- if (values.length() != filter.length()) {
- return Status::Invalid("filter and value array must have identical
lengths");
- }
- auto num_chunks = values.num_chunks();
- std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
- std::shared_ptr<Array> current_chunk;
- int64_t offset = 0;
- int64_t len;
+Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter,
+ FilterOptions options, Datum* out) {
+ if (values.kind() == Datum::RECORD_BATCH) {
+ if (!filter.is_array()) {
+ return Status::Invalid("Cannot filter a RecordBatch with a filter of
kind ",
+ filter.kind());
+ }
- for (int i = 0; i < num_chunks; i++) {
- current_chunk = values.chunk(i);
- len = current_chunk->length();
+ auto values_batch = values.record_batch();
+ auto filter_array = filter.make_array();
+ std::shared_ptr<RecordBatch> out_batch;
RETURN_NOT_OK(
- Filter(ctx, *current_chunk, *filter.Slice(offset, len),
&new_chunks[i]));
- offset += len;
+ FilterRecordBatch(ctx, *values_batch, *filter_array, options,
&out_batch));
+ *out = std::move(out_batch);
+ return Status::OK();
}
- *out = std::make_shared<ChunkedArray>(std::move(new_chunks));
- return Status::OK();
-}
+ if (values.kind() == Datum::TABLE) {
+ auto values_table = values.table();
-Status Filter(FunctionContext* ctx, const ChunkedArray& values,
- const ChunkedArray& filter, std::shared_ptr<ChunkedArray>* out) {
- if (values.length() != filter.length()) {
- return Status::Invalid("filter and value array must have identical
lengths");
- }
- auto num_chunks = values.num_chunks();
- std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
- std::shared_ptr<Array> current_chunk;
- std::shared_ptr<ChunkedArray> current_chunked_filter;
- std::shared_ptr<Array> current_filter;
- int64_t offset = 0;
- int64_t len;
-
- for (int i = 0; i < num_chunks; i++) {
- current_chunk = values.chunk(i);
- len = current_chunk->length();
- if (len > 0) {
- current_chunked_filter = filter.Slice(offset, len);
- if (current_chunked_filter->num_chunks() == 1) {
- current_filter = current_chunked_filter->chunk(0);
- } else {
- // Concatenate the chunks of the filter so we have an Array
- RETURN_NOT_OK(Concatenate(current_chunked_filter->chunks(),
default_memory_pool(),
- &current_filter));
- }
- RETURN_NOT_OK(Filter(ctx, *current_chunk, *current_filter,
&new_chunks[i]));
- offset += len;
- } else {
- // Put a zero length array there, which we know our current chunk to be
- new_chunks[i] = current_chunk;
- }
+ std::shared_ptr<Table> out_table;
+ RETURN_NOT_OK(FilterTable(ctx, *values_table, filter, options,
&out_table));
+ *out = std::move(out_table);
+ return Status::OK();
}
- *out = std::make_shared<ChunkedArray>(std::move(new_chunks));
- return Status::OK();
-}
-
-Status Filter(FunctionContext* ctx, const Table& table, const Array& filter,
- std::shared_ptr<Table>* out) {
- auto ncols = table.num_columns();
-
- std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
-
- for (int j = 0; j < ncols; j++) {
- RETURN_NOT_OK(Filter(ctx, *table.column(j), filter, &columns[j]));
- }
- *out = Table::Make(table.schema(), columns);
- return Status::OK();
-}
-
-Status Filter(FunctionContext* ctx, const Table& table, const ChunkedArray&
filter,
- std::shared_ptr<Table>* out) {
- auto ncols = table.num_columns();
-
- std::vector<std::shared_ptr<ChunkedArray>> columns(ncols);
-
- for (int j = 0; j < ncols; j++) {
- RETURN_NOT_OK(Filter(ctx, *table.column(j), filter, &columns[j]));
- }
- *out = Table::Make(table.schema(), columns);
- return Status::OK();
+ std::unique_ptr<FilterKernel> kernel;
+ RETURN_NOT_OK(FilterKernel::Make(values.type(), options, &kernel));
+ return kernel->Call(ctx, values, filter, out);
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/filter.h
b/cpp/src/arrow/compute/kernels/filter.h
index bc7f75d..65098b3 100644
--- a/cpp/src/arrow/compute/kernels/filter.h
+++ b/cpp/src/arrow/compute/kernels/filter.h
@@ -18,6 +18,7 @@
#pragma once
#include <memory>
+#include <utility>
#include "arrow/compute/kernel.h"
#include "arrow/record_batch.h"
@@ -32,120 +33,42 @@ namespace compute {
class FunctionContext;
-/// \brief Filter an array with a boolean selection filter
-///
-/// The output array will be populated with values from the input at positions
-/// where the selection filter is not 0. Nulls in the filter will result in
nulls
-/// in the output.
-///
-/// For example given values = ["a", "b", "c", null, "e", "f"] and
-/// filter = [0, 1, 1, 0, null, 1], the output will be
-/// = ["b", "c", null, "f"]
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] values array to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting array
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
- std::shared_ptr<Array>* out);
+struct FilterOptions {
+ /// Configure the action taken when a slot of the selection mask is null
+ enum NullSelectionBehavior {
+ /// the corresponding filtered value will be removed in the output
+ DROP,
+ /// the corresponding filtered value will be null in the output
+ EMIT_NULL,
+ };
-/// \brief Filter a chunked array with a boolean selection filter
-///
-/// The output chunked array will be populated with values from the input at
positions
-/// where the selection filter is not 0. Nulls in the filter will result in
nulls
-/// in the output.
-///
-/// For example given values = ["a", "b", "c", null, "e", "f"] and
-/// filter = [0, 1, 1, 0, null, 1], the output will be
-/// = ["b", "c", null, "f"]
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] values chunked array to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting chunked array
-/// NOTE: Experimental API
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const ChunkedArray& values, const Array&
filter,
- std::shared_ptr<ChunkedArray>* out);
+ NullSelectionBehavior null_selection_behavior = DROP;
+};
-/// \brief Filter a chunked array with a boolean selection filter
+/// \brief Filter with a boolean selection filter
///
-/// The output chunked array will be populated with values from the input at
positions
-/// where the selection filter is not 0. Nulls in the filter will result in
nulls
-/// in the output.
+/// The output will be populated with values from the input at positions
+/// where the selection filter is not 0. Nulls in the filter will be handled
+/// based on options.null_selection_behavior.
///
/// For example given values = ["a", "b", "c", null, "e", "f"] and
/// filter = [0, 1, 1, 0, null, 1], the output will be
-/// = ["b", "c", null, "f"]
+/// (null_selection_behavior == DROP) = ["b", "c", "f"]
+/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"]
///
/// \param[in] ctx the FunctionContext
-/// \param[in] values chunked array to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting chunked array
-/// NOTE: Experimental API
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const ChunkedArray& values,
- const ChunkedArray& filter, std::shared_ptr<ChunkedArray>* out);
-
-/// \brief Filter a record batch with a boolean selection filter
-///
-/// The output record batch's columns will be populated with values from
corresponding
-/// columns of the input at positions where the selection filter is not 0.
Nulls in the
-/// filter will result in nulls in the output.
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] batch record batch to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting record batch
-/// NOTE: Experimental API
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const RecordBatch& batch, const Array&
filter,
- std::shared_ptr<RecordBatch>* out);
-
-/// \brief Filter a table with a boolean selection filter
-///
-/// The output table's columns will be populated with values from corresponding
-/// columns of the input at positions where the selection filter is not 0.
Nulls in the
-/// filter will result in nulls in each column of the output.
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] table table to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting table
-/// NOTE: Experimental API
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const Table& table, const Array& filter,
- std::shared_ptr<Table>* out);
-
-/// \brief Filter a table with a boolean selection filter
-///
-/// The output record batch's columns will be populated with values from
corresponding
-/// columns of the input at positions where the selection filter is not 0.
Nulls in the
-/// filter will result in nulls in the output.
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] table record batch to filter
-/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting record batch
-/// NOTE: Experimental API
-ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const Table& table, const ChunkedArray&
filter,
- std::shared_ptr<Table>* out);
-
-/// \brief Filter an array with a boolean selection filter
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] values datum to filter
+/// \param[in] values array to filter
/// \param[in] filter indicates which values should be filtered out
-/// \param[out] out resulting datum
+/// \param[in] options configures null_selection_behavior
+/// \param[out] out resulting array
ARROW_EXPORT
-Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter,
Datum* out);
+Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter,
+ FilterOptions options, Datum* out);
/// \brief BinaryKernel implementing Filter operation
class ARROW_EXPORT FilterKernel : public BinaryKernel {
public:
- explicit FilterKernel(const std::shared_ptr<DataType>& type) : type_(type) {}
+ const FilterOptions& options() const { return options_; }
/// \brief BinaryKernel interface
///
@@ -161,7 +84,7 @@ class ARROW_EXPORT FilterKernel : public BinaryKernel {
/// \param[in] value_type constructed FilterKernel will support filtering
/// values of this type
/// \param[out] out created kernel
- static Status Make(const std::shared_ptr<DataType>& value_type,
+ static Status Make(std::shared_ptr<DataType> value_type, FilterOptions
options,
std::unique_ptr<FilterKernel>* out);
/// \brief single-array implementation
@@ -170,7 +93,11 @@ class ARROW_EXPORT FilterKernel : public BinaryKernel {
std::shared_ptr<Array>* out) = 0;
protected:
+ explicit FilterKernel(std::shared_ptr<DataType> type, FilterOptions options)
+ : type_(std::move(type)), options_(options) {}
+
std::shared_ptr<DataType> type_;
+ FilterOptions options_;
};
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/filter_test.cc
b/cpp/src/arrow/compute/kernels/filter_test.cc
index b270f3b..a75d740 100644
--- a/cpp/src/arrow/compute/kernels/filter_test.cc
+++ b/cpp/src/arrow/compute/kernels/filter_test.cc
@@ -38,63 +38,102 @@ using util::string_view;
constexpr auto kSeed = 0x0ff1ce;
+std::shared_ptr<Array> CoalesceNullToFalse(FunctionContext* ctx,
+ std::shared_ptr<Array> filter) {
+ if (filter->null_count() == 0) {
+ return filter;
+ }
+ const auto& data = *filter->data();
+ auto is_true = std::make_shared<BooleanArray>(data.length, data.buffers[1]);
+ auto is_valid = std::make_shared<BooleanArray>(data.length, data.buffers[0]);
+ Datum out_datum;
+ ARROW_EXPECT_OK(arrow::compute::And(ctx, is_true, is_valid, &out_datum));
+ return out_datum.make_array();
+}
+
template <typename ArrowType>
class TestFilterKernel : public ComputeFixture, public TestBase {
protected:
- void AssertFilterArrays(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& filter,
- const std::shared_ptr<Array>& expected) {
- std::shared_ptr<Array> actual;
- ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter, &actual));
- ASSERT_OK(actual->ValidateFull());
- AssertArraysEqual(*expected, *actual);
+ TestFilterKernel() {
+ emit_null_.null_selection_behavior = FilterOptions::EMIT_NULL;
+ drop_.null_selection_behavior = FilterOptions::DROP;
}
- void AssertFilter(const std::shared_ptr<DataType>& type, const std::string&
values,
- const std::string& filter, const std::string& expected) {
- std::shared_ptr<Array> actual;
- ASSERT_OK(this->Filter(type, values, filter, &actual));
+ void AssertFilter(std::shared_ptr<Array> values, std::shared_ptr<Array>
filter,
+ std::shared_ptr<Array> expected) {
+ // test with EMIT_NULL
+ Datum out_datum;
+ ASSERT_OK(
+ arrow::compute::Filter(&this->ctx_, values, filter, emit_null_,
&out_datum));
+ auto actual = out_datum.make_array();
ASSERT_OK(actual->ValidateFull());
- AssertArraysEqual(*ArrayFromJSON(type, expected), *actual);
+ AssertArraysEqual(*expected, *actual);
+
+ // test with DROP using EMIT_NULL and a coalesced filter
+ auto coalesced_filter = CoalesceNullToFalse(&this->ctx_, filter);
+ ASSERT_OK(arrow::compute::Filter(&this->ctx_, values, coalesced_filter,
emit_null_,
+ &out_datum));
+ expected = out_datum.make_array();
+ ASSERT_OK(arrow::compute::Filter(&this->ctx_, values, filter, drop_,
&out_datum));
+ actual = out_datum.make_array();
+ AssertArraysEqual(*expected, *actual);
}
- Status Filter(const std::shared_ptr<DataType>& type, const std::string&
values,
- const std::string& filter, std::shared_ptr<Array>* out) {
- return arrow::compute::Filter(&this->ctx_, *ArrayFromJSON(type, values),
- *ArrayFromJSON(boolean(), filter), out);
+ void AssertFilter(std::shared_ptr<DataType> type, const std::string& values,
+ const std::string& filter, const std::string& expected) {
+ AssertFilter(ArrayFromJSON(type, values), ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(type, expected));
}
void ValidateFilter(const std::shared_ptr<Array>& values,
const std::shared_ptr<Array>& filter_boxed) {
- std::shared_ptr<Array> filtered;
- ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter_boxed,
&filtered));
- ASSERT_OK(filtered->ValidateFull());
+ Datum out_datum;
+ ASSERT_OK(arrow::compute::Filter(&this->ctx_, values, filter_boxed,
emit_null_,
+ &out_datum));
+ auto filtered_emit_null = out_datum.make_array();
+ ASSERT_OK(filtered_emit_null->ValidateFull());
+
+ ASSERT_OK(
+ arrow::compute::Filter(&this->ctx_, values, filter_boxed, drop_,
&out_datum));
+ auto filtered_drop = out_datum.make_array();
+ ASSERT_OK(filtered_drop->ValidateFull());
auto filter = checked_pointer_cast<BooleanArray>(filter_boxed);
- int64_t values_i = 0, filtered_i = 0;
- for (; values_i < values->length(); ++values_i, ++filtered_i) {
+ int64_t values_i = 0, emit_null_i = 0, drop_i = 0;
+ for (; values_i < values->length(); ++values_i, ++emit_null_i, ++drop_i) {
if (filter->IsNull(values_i)) {
- ASSERT_LT(filtered_i, filtered->length());
- ASSERT_TRUE(filtered->IsNull(filtered_i));
+ ASSERT_LT(emit_null_i, filtered_emit_null->length());
+ ASSERT_TRUE(filtered_emit_null->IsNull(emit_null_i));
+ // this element was (null) filtered out; don't examine filtered_drop
+ --drop_i;
continue;
}
if (!filter->Value(values_i)) {
- // this element was filtered out; don't examine filtered
- --filtered_i;
+ // this element was filtered out; don't examine filtered_emit_null
+ --emit_null_i;
+ --drop_i;
continue;
}
- ASSERT_LT(filtered_i, filtered->length());
- ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, filtered_i,
filtered));
+ ASSERT_LT(emit_null_i, filtered_emit_null->length());
+ ASSERT_LT(drop_i, filtered_drop->length());
+ ASSERT_TRUE(
+ values->RangeEquals(values_i, values_i + 1, emit_null_i,
filtered_emit_null));
+ ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, drop_i,
filtered_drop));
}
- ASSERT_EQ(filtered_i, filtered->length());
+ ASSERT_EQ(emit_null_i, filtered_emit_null->length());
+ ASSERT_EQ(drop_i, filtered_drop->length());
}
+
+ FilterOptions emit_null_, drop_;
};
class TestFilterKernelWithNull : public TestFilterKernel<NullType> {
protected:
void AssertFilter(const std::string& values, const std::string& filter,
const std::string& expected) {
- TestFilterKernel<NullType>::AssertFilter(null(), values, filter, expected);
+ TestFilterKernel<NullType>::AssertFilter(ArrayFromJSON(null(), values),
+ ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(null(), expected));
}
};
@@ -109,7 +148,9 @@ class TestFilterKernelWithBoolean : public
TestFilterKernel<BooleanType> {
protected:
void AssertFilter(const std::string& values, const std::string& filter,
const std::string& expected) {
- TestFilterKernel<BooleanType>::AssertFilter(boolean(), values, filter,
expected);
+ TestFilterKernel<BooleanType>::AssertFilter(ArrayFromJSON(boolean(),
values),
+ ArrayFromJSON(boolean(),
filter),
+ ArrayFromJSON(boolean(),
expected));
}
};
@@ -124,11 +165,6 @@ TEST_F(TestFilterKernelWithBoolean, FilterBoolean) {
template <typename ArrowType>
class TestFilterKernelWithNumeric : public TestFilterKernel<ArrowType> {
protected:
- void AssertFilter(const std::string& values, const std::string& filter,
- const std::string& expected) {
- TestFilterKernel<ArrowType>::AssertFilter(type_singleton(), values,
filter, expected);
- }
-
std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
}
@@ -136,27 +172,34 @@ class TestFilterKernelWithNumeric : public
TestFilterKernel<ArrowType> {
TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes);
TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) {
- this->AssertFilter("[]", "[]", "[]");
-
- this->AssertFilter("[9]", "[0]", "[]");
- this->AssertFilter("[9]", "[1]", "[9]");
- this->AssertFilter("[9]", "[null]", "[null]");
- this->AssertFilter("[null]", "[0]", "[]");
- this->AssertFilter("[null]", "[1]", "[null]");
- this->AssertFilter("[null]", "[null]", "[null]");
-
- this->AssertFilter("[7, 8, 9]", "[0, 1, 0]", "[8]");
- this->AssertFilter("[7, 8, 9]", "[1, 0, 1]", "[7, 9]");
- this->AssertFilter("[null, 8, 9]", "[0, 1, 0]", "[8]");
- this->AssertFilter("[7, 8, 9]", "[null, 1, 0]", "[null, 8]");
- this->AssertFilter("[7, 8, 9]", "[1, null, 1]", "[7, null, 9]");
-
- this->AssertFilterArrays(ArrayFromJSON(this->type_singleton(), "[7, 8, 9]"),
- ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0,
1]")->Slice(3, 3),
- ArrayFromJSON(this->type_singleton(), "[7, 9]"));
-
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(Invalid, this->Filter(this->type_singleton(), "[7, 8, 9]",
"[]", &arr));
+ auto type = this->type_singleton();
+ this->AssertFilter(type, "[]", "[]", "[]");
+
+ this->AssertFilter(type, "[9]", "[0]", "[]");
+ this->AssertFilter(type, "[9]", "[1]", "[9]");
+ this->AssertFilter(type, "[9]", "[null]", "[null]");
+ this->AssertFilter(type, "[null]", "[0]", "[]");
+ this->AssertFilter(type, "[null]", "[1]", "[null]");
+ this->AssertFilter(type, "[null]", "[null]", "[null]");
+
+ this->AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]");
+ this->AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]");
+ this->AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]");
+
+ this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"),
+ ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3,
3),
+ ArrayFromJSON(type, "[7, 9]"));
+
+ Datum out_datum;
+ ASSERT_RAISES(Invalid,
+ arrow::compute::Filter(&this->ctx_, ArrayFromJSON(type, "[7,
8, 9]"),
+ ArrayFromJSON(boolean(), "[]"),
this->emit_null_,
+ &out_datum));
+ ASSERT_RAISES(Invalid, arrow::compute::Filter(
+ &this->ctx_, ArrayFromJSON(type, "[7, 8, 9]"),
+ ArrayFromJSON(boolean(), "[]"), this->drop_,
&out_datum));
}
TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) {
@@ -233,11 +276,11 @@ TYPED_TEST(TestFilterKernelWithNumeric,
CompareScalarAndFilterRandomNumeric) {
CType c_fifty = 50;
auto fifty = std::make_shared<ScalarType>(c_fifty);
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
- auto options = CompareOptions(op);
Datum selection, filtered;
- ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array),
Datum(fifty), options,
- &selection));
- ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection,
&filtered));
+ ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array),
Datum(fifty),
+ CompareOptions(op), &selection));
+ ASSERT_OK(
+ arrow::compute::Filter(&this->ctx_, Datum(array), selection, {},
&filtered));
auto filtered_array = filtered.make_array();
ASSERT_OK(filtered_array->ValidateFull());
auto expected =
@@ -253,16 +296,16 @@ TYPED_TEST(TestFilterKernelWithNumeric,
CompareArrayAndFilterRandomNumeric) {
auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
- auto lhs =
- checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0,
100, 0));
- auto rhs =
- checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0,
100, 0));
+ auto lhs = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+ auto rhs = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
- auto options = CompareOptions(op);
Datum selection, filtered;
- ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(lhs), Datum(rhs),
options,
- &selection));
- ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(lhs), selection,
&filtered));
+ ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(lhs), Datum(rhs),
+ CompareOptions(op), &selection));
+ ASSERT_OK(
+ arrow::compute::Filter(&this->ctx_, Datum(lhs), selection, {},
&filtered));
auto filtered_array = filtered.make_array();
ASSERT_OK(filtered_array->ValidateFull());
auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(),
lhs->length(),
@@ -280,8 +323,8 @@ TYPED_TEST(TestFilterKernelWithNumeric,
ScalarInRangeAndFilterRandomNumeric) {
auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
- auto array =
- checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0,
100, 0));
+ auto array = checked_pointer_cast<ArrayType>(
+ rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
CType c_fifty = 50, c_hundred = 100;
auto fifty = std::make_shared<ScalarType>(c_fifty);
auto hundred = std::make_shared<ScalarType>(c_hundred);
@@ -292,7 +335,8 @@ TYPED_TEST(TestFilterKernelWithNumeric,
ScalarInRangeAndFilterRandomNumeric) {
CompareOptions(LESS),
&less_than_hundred));
ASSERT_OK(arrow::compute::And(&this->ctx_, greater_than_fifty,
less_than_hundred,
&selection));
- ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection,
&filtered));
+ ASSERT_OK(
+ arrow::compute::Filter(&this->ctx_, Datum(array), selection, {},
&filtered));
auto filtered_array = filtered.make_array();
ASSERT_OK(filtered_array->ValidateFull());
auto expected = CompareAndFilter<TypeParam>(
@@ -314,8 +358,11 @@ class TestFilterKernelWithString : public
TestFilterKernel<TypeClass> {
void AssertFilter(const std::string& values, const std::string& filter,
const std::string& expected) {
- TestFilterKernel<TypeClass>::AssertFilter(value_type(), values, filter,
expected);
+ TestFilterKernel<TypeClass>::AssertFilter(ArrayFromJSON(value_type(),
values),
+ ArrayFromJSON(boolean(), filter),
+ ArrayFromJSON(value_type(),
expected));
}
+
void AssertFilterDictionary(const std::string& dictionary_values,
const std::string& dictionary_filter,
const std::string& filter,
@@ -328,7 +375,7 @@ class TestFilterKernelWithString : public
TestFilterKernel<TypeClass> {
ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(),
expected_filter),
dict, &expected));
auto take_filter = ArrayFromJSON(boolean(), filter);
- this->AssertFilterArrays(values, take_filter, expected);
+ TestFilterKernel<TypeClass>::AssertFilter(values, take_filter, expected);
}
};
@@ -347,7 +394,9 @@ TYPED_TEST(TestFilterKernelWithString, FilterDictionary) {
this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]");
}
-class TestFilterKernelWithList : public TestFilterKernel<ListType> {};
+class TestFilterKernelWithList : public TestFilterKernel<ListType> {
+ public:
+};
TEST_F(TestFilterKernelWithList, FilterListInt32) {
std::string list_json = "[[], [1,2], null, [3]]";
@@ -482,19 +531,24 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
class TestFilterKernelWithRecordBatch : public TestFilterKernel<RecordBatch> {
public:
void AssertFilter(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
- const std::string& selection, const std::string&
expected_batch) {
+ const std::string& selection, FilterOptions options,
+ const std::string& expected_batch) {
std::shared_ptr<RecordBatch> actual;
- ASSERT_OK(this->Filter(schm, batch_json, selection, &actual));
+ ASSERT_OK(this->Filter(schm, batch_json, selection, options, &actual));
ASSERT_OK(actual->ValidateFull());
ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
}
Status Filter(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
- const std::string& selection, std::shared_ptr<RecordBatch>*
out) {
+ const std::string& selection, FilterOptions options,
+ std::shared_ptr<RecordBatch>* out) {
auto batch = RecordBatchFromJSON(schm, batch_json);
- return arrow::compute::Filter(&this->ctx_, *batch,
- *ArrayFromJSON(boolean(), selection), out);
+ Datum out_datum;
+ RETURN_NOT_OK(arrow::compute::Filter(
+ &this->ctx_, batch, ArrayFromJSON(boolean(), selection), options,
&out_datum));
+ *out = out_datum.record_batch();
+ return Status::OK();
}
};
@@ -502,23 +556,31 @@ TEST_F(TestFilterKernelWithRecordBatch,
FilterRecordBatch) {
std::vector<std::shared_ptr<Field>> fields = {field("a", int32()),
field("b", utf8())};
auto schm = schema(fields);
- auto struct_json = R"([
+ auto batch_json = R"([
{"a": null, "b": "yo"},
{"a": 1, "b": ""},
{"a": 2, "b": "hello"},
{"a": 4, "b": "eh"}
])";
- this->AssertFilter(schm, struct_json, "[0, 0, 0, 0]", "[]");
- this->AssertFilter(schm, struct_json, "[0, 1, 1, null]", R"([
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]");
+ this->AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json);
+ this->AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([
+ {"a": null, "b": "yo"},
+ {"a": 2, "b": "hello"}
+ ])");
+ }
+
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->drop_, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"}
+ ])");
+
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->emit_null_,
R"([
{"a": 1, "b": ""},
{"a": 2, "b": "hello"},
{"a": null, "b": null}
])");
- this->AssertFilter(schm, struct_json, "[1, 1, 1, 1]", struct_json);
- this->AssertFilter(schm, struct_json, "[1, 0, 1, 0]", R"([
- {"a": null, "b": "yo"},
- {"a": 2, "b": "hello"}
- ])");
}
class TestFilterKernelWithChunkedArray : public TestFilterKernel<ChunkedArray>
{
@@ -545,26 +607,34 @@ class TestFilterKernelWithChunkedArray : public
TestFilterKernel<ChunkedArray> {
Status FilterWithArray(const std::shared_ptr<DataType>& type,
const std::vector<std::string>& values,
const std::string& filter,
std::shared_ptr<ChunkedArray>* out) {
- return arrow::compute::Filter(&this->ctx_, *ChunkedArrayFromJSON(type,
values),
- *ArrayFromJSON(boolean(), filter), out);
+ Datum out_datum;
+ RETURN_NOT_OK(arrow::compute::Filter(&this->ctx_,
ChunkedArrayFromJSON(type, values),
+ ArrayFromJSON(boolean(), filter), {},
+ &out_datum));
+ *out = out_datum.chunked_array();
+ return Status::OK();
}
Status FilterWithChunkedArray(const std::shared_ptr<DataType>& type,
const std::vector<std::string>& values,
const std::vector<std::string>& filter,
std::shared_ptr<ChunkedArray>* out) {
- return arrow::compute::Filter(&this->ctx_, *ChunkedArrayFromJSON(type,
values),
- *ChunkedArrayFromJSON(boolean(), filter),
out);
+ Datum out_datum;
+ RETURN_NOT_OK(arrow::compute::Filter(&this->ctx_,
ChunkedArrayFromJSON(type, values),
+ ChunkedArrayFromJSON(boolean(),
filter), {},
+ &out_datum));
+ *out = out_datum.chunked_array();
+ return Status::OK();
}
};
TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) {
- this->AssertFilter(int8(), {"[]"}, "[]", {"[]"});
- this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {"[]"});
+ this->AssertFilter(int8(), {"[]"}, "[]", {});
+ this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {});
- this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[]", "[8]"});
- this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"},
{"[]", "[8]"});
- this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"},
{"[8]", "[]"});
+ this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[8]"});
+ this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"},
{"[8]"});
+ this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"},
{"[8]"});
std::shared_ptr<ChunkedArray> arr;
ASSERT_RAISES(
@@ -577,38 +647,49 @@ class TestFilterKernelWithTable : public
TestFilterKernel<Table> {
public:
void AssertFilter(const std::shared_ptr<Schema>& schm,
const std::vector<std::string>& table_json, const
std::string& filter,
+ FilterOptions options,
const std::vector<std::string>& expected_table) {
std::shared_ptr<Table> actual;
- ASSERT_OK(this->FilterWithArray(schm, table_json, filter, &actual));
+ ASSERT_OK(this->FilterWithArray(schm, table_json, filter, options,
&actual));
ASSERT_OK(actual->ValidateFull());
ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
}
void AssertChunkedFilter(const std::shared_ptr<Schema>& schm,
const std::vector<std::string>& table_json,
- const std::vector<std::string>& filter,
+ const std::vector<std::string>& filter,
FilterOptions options,
const std::vector<std::string>& expected_table) {
std::shared_ptr<Table> actual;
- ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, &actual));
+ ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, options,
&actual));
ASSERT_OK(actual->ValidateFull());
- ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+ AssertTablesEqual(*TableFromJSON(schm, expected_table), *actual,
+ /*same_chunk_layout=*/false);
}
Status FilterWithArray(const std::shared_ptr<Schema>& schm,
const std::vector<std::string>& values,
- const std::string& filter, std::shared_ptr<Table>*
out) {
- return arrow::compute::Filter(&this->ctx_, *TableFromJSON(schm, values),
- *ArrayFromJSON(boolean(), filter), out);
+ const std::string& filter, FilterOptions options,
+ std::shared_ptr<Table>* out) {
+ Datum out_datum;
+ RETURN_NOT_OK(arrow::compute::Filter(&this->ctx_, TableFromJSON(schm,
values),
+ ArrayFromJSON(boolean(), filter),
options,
+ &out_datum));
+ *out = out_datum.table();
+ return Status::OK();
}
Status FilterWithChunkedArray(const std::shared_ptr<Schema>& schm,
const std::vector<std::string>& values,
const std::vector<std::string>& filter,
- std::shared_ptr<Table>* out) {
- return arrow::compute::Filter(&this->ctx_, *TableFromJSON(schm, values),
- *ChunkedArrayFromJSON(boolean(), filter),
out);
+ FilterOptions options, std::shared_ptr<Table>*
out) {
+ Datum out_datum;
+ RETURN_NOT_OK(arrow::compute::Filter(&this->ctx_, TableFromJSON(schm,
values),
+ ChunkedArrayFromJSON(boolean(),
filter), options,
+ &out_datum));
+ *out = out_datum.table();
+ return Status::OK();
}
};
@@ -616,19 +697,39 @@ TEST_F(TestFilterKernelWithTable, FilterTable) {
std::vector<std::shared_ptr<Field>> fields = {field("a", int32()),
field("b", utf8())};
auto schm = schema(fields);
- std::vector<std::string> table_json = {
- "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]",
- "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"};
- this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", {"[]", "[]"});
- this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, {"[]",
"[]"});
-
- std::vector<std::string> expected2 = {
- "[{\"a\": 1, \"b\": \"\"}]",
- "[{\"a\": 2, \"b\": \"hello\"},{\"a\": null, \"b\": null}]"};
- this->AssertFilter(schm, table_json, "[0, 1, 1, null]", expected2);
- this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"},
expected2);
- this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", table_json);
- this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"},
table_json);
+ std::vector<std::string> table_json = {R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])"};
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {});
+ this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options,
{});
+ this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json);
+ this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options,
+ table_json);
+ }
+
+ std::vector<std::string> expected_emit_null = {R"([
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": null, "b": null}
+ ])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->emit_null_,
+ expected_emit_null);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"},
this->emit_null_,
+ expected_emit_null);
+
+ std::vector<std::string> expected_drop = {R"([{"a": 1, "b": ""}])",
+ R"([{"a": 2, "b": "hello"}])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->drop_,
expected_drop);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"},
this->drop_,
+ expected_drop);
}
} // namespace compute
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index ca5c3a7..7ec27c2 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -1242,10 +1242,10 @@ Result<std::shared_ptr<RecordBatch>>
TreeEvaluator::Filter(
MemoryPool* pool) const {
if (selection.is_array()) {
auto selection_array = selection.make_array();
- std::shared_ptr<RecordBatch> filtered;
+ compute::Datum filtered;
compute::FunctionContext ctx{pool};
- RETURN_NOT_OK(compute::Filter(&ctx, *batch, *selection_array, &filtered));
- return std::move(filtered);
+ RETURN_NOT_OK(compute::Filter(&ctx, batch, selection_array, {},
&filtered));
+ return filtered.record_batch();
}
if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) {
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index ced55b5..6773a70 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -1002,7 +1002,7 @@ cdef class Array(_PandasConvertible):
return wrap_datum(out)
- def filter(self, Array mask):
+ def filter(self, Array mask, null_selection_behavior='drop'):
"""
Filter the array with a boolean mask.
@@ -1010,6 +1010,12 @@ cdef class Array(_PandasConvertible):
----------
mask : Array
The boolean mask indicating which values to extract.
+ null_selection_behavior : str, default 'drop'
+ Configure the behavior on encountering a null slot in the mask.
+ Allowed values are 'drop' and 'emit_null'.
+
+ 'drop': nulls will be treated as equivalent to False.
+ 'emit_null': nulls will result in a null in the output.
Returns
-------
@@ -1025,16 +1031,33 @@ cdef class Array(_PandasConvertible):
<pyarrow.lib.StringArray object at 0x7fa826df9200>
[
"a",
+ "e"
+ ]
+ >>> arr.filter(mask, null_selection_behavior='emit_null')
+ <pyarrow.lib.StringArray object at 0x7fa826df9200>
+ [
+ "a",
null,
"e"
]
"""
cdef:
cdef CDatum out
+ CFilterOptions options
+
+ if null_selection_behavior == 'drop':
+ options.null_selection_behavior = \
+ CFilterNullSelectionBehavior_DROP
+ elif null_selection_behavior == 'emit_null':
+ options.null_selection_behavior = \
+ CFilterNullSelectionBehavior_EMIT_NULL
+ else:
+ raise ValueError('"' + null_selection_behavior + '" is not a ' +
+ 'valid null_selection_behavior')
with nogil:
check_status(FilterKernel(_context(), CDatum(self.sp_array),
- CDatum(mask.sp_array), &out))
+ CDatum(mask.sp_array), options, &out))
return wrap_datum(out)
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 7be418c..06c9f90 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1465,6 +1465,16 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
cdef cppclass CTakeOptions" arrow::compute::TakeOptions":
pass
+ enum CFilterNullSelectionBehavior \
+ "arrow::compute::FilterOptions::NullSelectionBehavior":
+ CFilterNullSelectionBehavior_DROP \
+ "arrow::compute::FilterOptions::DROP"
+ CFilterNullSelectionBehavior_EMIT_NULL \
+ "arrow::compute::FilterOptions::EMIT_NULL"
+
+ cdef cppclass CFilterOptions" arrow::compute::FilterOptions":
+ CFilterNullSelectionBehavior null_selection_behavior
+
enum DatumType" arrow::compute::Datum::type":
DatumType_NONE" arrow::compute::Datum::NONE"
DatumType_SCALAR" arrow::compute::Datum::SCALAR"
@@ -1517,7 +1527,7 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
# Filter clashes with gandiva.pyx::Filter
CStatus FilterKernel" arrow::compute::Filter"(
CFunctionContext* context, const CDatum& values,
- const CDatum& filter, CDatum* out)
+ const CDatum& filter, CFilterOptions, CDatum* out)
enum CCompareOperator "arrow::compute::CompareOperator":
CCompareOperator_EQUAL "arrow::compute::CompareOperator::EQUAL"
diff --git a/python/pyarrow/tests/test_compute.py
b/python/pyarrow/tests/test_compute.py
index b887631..07ab323 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -149,11 +149,12 @@ def test_filter(ty, values):
arr = pa.array(values, type=ty)
mask = pa.array([True, False, False, True, None])
- result = arr.filter(mask)
+ result = arr.filter(mask, null_selection_behavior='drop')
result.validate()
-
- expected = pa.array([values[0], values[3], None], type=ty)
- assert result.equals(expected)
+ assert result.equals(pa.array([values[0], values[3]], type=ty))
+ result = arr.filter(mask, null_selection_behavior='emit_null')
+ result.validate()
+ assert result.equals(pa.array([values[0], values[3], None], type=ty))
# non-boolean dtype
mask = pa.array([0, 1, 0, 1, 0])
diff --git a/r/R/array.R b/r/R/array.R
index 63f221a..ceb2994 100644
--- a/r/R/array.R
+++ b/r/R/array.R
@@ -70,7 +70,7 @@
#' until the end of the array.
#' - `$Take(i)`: return an `Array` with values at positions given by integers
#' (R vector or Array Array) `i`.
-#' - `$Filter(i)`: return an `Array` with values at positions where logical
+#' - `$Filter(i, keep_na = TRUE)`: return an `Array` with values at positions
where logical
#' vector (or Arrow boolean Array) `i` is `TRUE`.
#' - `$RangeEquals(other, start_idx, end_idx, other_start_idx)` :
#' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the
@@ -133,12 +133,12 @@ Array <- R6Class("Array",
assert_is(i, "Array")
Array$create(Array__Take(self, i))
},
- Filter = function(i) {
+ Filter = function(i, keep_na = TRUE) {
if (is.logical(i)) {
i <- Array$create(i)
}
assert_is(i, "Array")
- Array$create(Array__Filter(self, i))
+ Array$create(Array__Filter(self, i, keep_na))
},
RangeEquals = function(other, start_idx, end_idx, other_start_idx = 0L) {
assert_is(other, "Array")
@@ -243,7 +243,7 @@ is.na.Array <- function(x) {
#' @export
as.vector.Array <- function(x, mode) x$as_vector()
-filter_rows <- function(x, i, ...) {
+filter_rows <- function(x, i, keep_na = TRUE, ...) {
# General purpose function for [ row subsetting with R semantics
# Based on the input for `i`, calls x$Filter, x$Slice, or x$Take
nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
@@ -257,7 +257,7 @@ filter_rows <- function(x, i, ...) {
x
} else {
i <- rep_len(i, nrows) # For R recycling behavior; consider
vctrs::vec_recycle()
- x$Filter(i)
+ x$Filter(i, keep_na)
}
} else if (is.numeric(i)) {
if (all(i < 0)) {
@@ -275,7 +275,7 @@ filter_rows <- function(x, i, ...) {
# NOTE: this doesn't do the - 1 offset
x$Take(i)
} else if (is.Array(i, "bool")) {
- x$Filter(i)
+ x$Filter(i, keep_na)
} else {
# Unsupported cases
if (is.Array(i)) {
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index af381b2..889d1b3 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -300,28 +300,28 @@ Table__TakeChunked <- function(table, indices){
.Call(`_arrow_Table__TakeChunked` , table, indices)
}
-Array__Filter <- function(values, filter){
- .Call(`_arrow_Array__Filter` , values, filter)
+Array__Filter <- function(values, filter, keep_na){
+ .Call(`_arrow_Array__Filter` , values, filter, keep_na)
}
-RecordBatch__Filter <- function(batch, filter){
- .Call(`_arrow_RecordBatch__Filter` , batch, filter)
+RecordBatch__Filter <- function(batch, filter, keep_na){
+ .Call(`_arrow_RecordBatch__Filter` , batch, filter, keep_na)
}
-ChunkedArray__Filter <- function(values, filter){
- .Call(`_arrow_ChunkedArray__Filter` , values, filter)
+ChunkedArray__Filter <- function(values, filter, keep_na){
+ .Call(`_arrow_ChunkedArray__Filter` , values, filter, keep_na)
}
-ChunkedArray__FilterChunked <- function(values, filter){
- .Call(`_arrow_ChunkedArray__FilterChunked` , values, filter)
+ChunkedArray__FilterChunked <- function(values, filter, keep_na){
+ .Call(`_arrow_ChunkedArray__FilterChunked` , values, filter, keep_na)
}
-Table__Filter <- function(table, filter){
- .Call(`_arrow_Table__Filter` , table, filter)
+Table__Filter <- function(table, filter, keep_na){
+ .Call(`_arrow_Table__Filter` , table, filter, keep_na)
}
-Table__FilterChunked <- function(table, filter){
- .Call(`_arrow_Table__FilterChunked` , table, filter)
+Table__FilterChunked <- function(table, filter, keep_na){
+ .Call(`_arrow_Table__FilterChunked` , table, filter, keep_na)
}
csv___ReadOptions__initialize <- function(options){
diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R
index 2620d9e..f352705 100644
--- a/r/R/chunked-array.R
+++ b/r/R/chunked-array.R
@@ -39,7 +39,7 @@
#' - `$Take(i)`: return a `ChunkedArray` with values at positions given by
#' integers `i`. If `i` is an Arrow `Array` or `ChunkedArray`, it will be
#' coerced to an R vector before taking.
-#' - `$Filter(i)`: return a `ChunkedArray` with values at positions where
+#' - `$Filter(i, keep_na = TRUE)`: return a `ChunkedArray` with values at
positions where
#' logical vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`.
#' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the
#' data in the array to change its type.
@@ -81,15 +81,15 @@ ChunkedArray <- R6Class("ChunkedArray", inherit =
ArrowObject,
assert_is(i, "Array")
return(shared_ptr(ChunkedArray, ChunkedArray__Take(self, i)))
},
- Filter = function(i) {
+ Filter = function(i, keep_na = TRUE) {
if (is.logical(i)) {
i <- Array$create(i)
}
if (inherits(i, "ChunkedArray")) {
- return(shared_ptr(ChunkedArray, ChunkedArray__FilterChunked(self, i)))
+ return(shared_ptr(ChunkedArray, ChunkedArray__FilterChunked(self, i,
keep_na)))
}
assert_is(i, "Array")
- shared_ptr(ChunkedArray, ChunkedArray__Filter(self, i))
+ shared_ptr(ChunkedArray, ChunkedArray__Filter(self, i, keep_na))
},
cast = function(target_type, safe = TRUE, options = cast_options(safe)) {
assert_is(options, "CastOptions")
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index 78995bb..3f57742 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -248,7 +248,7 @@ collect.arrow_dplyr_query <- function(x, ...) {
df <- as.data.frame(scanner_builder$Finish()$ToTable())
} else {
# This is a Table/RecordBatch. See record-batch.R for the [ method
- df <- as.data.frame(x$.data[x$filtered_rows, colnames])
+ df <- as.data.frame(x$.data[x$filtered_rows, colnames, keep_na = FALSE])
}
# In case variables were renamed, apply those names
names(df) <- names(colnames)
diff --git a/r/R/record-batch.R b/r/R/record-batch.R
index 7dde3ae..8b10cb6 100644
--- a/r/R/record-batch.R
+++ b/r/R/record-batch.R
@@ -61,7 +61,7 @@
#' of the table if `NULL`, the default.
#' - `$Take(i)`: return an `RecordBatch` with rows at positions given by
#' integers (R vector or Array Array) `i`.
-#' - `$Filter(i)`: return an `RecordBatch` with rows at positions where logical
+#' - `$Filter(i, keep_na = TRUE)`: return a `RecordBatch` with rows at
positions where logical
#' vector (or Arrow boolean Array) `i` is `TRUE`.
#' - `$serialize()`: Returns a raw vector suitable for interprocess
communication
#' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter
@@ -121,12 +121,12 @@ RecordBatch <- R6Class("RecordBatch", inherit =
ArrowObject,
assert_is(i, "Array")
shared_ptr(RecordBatch, RecordBatch__Take(self, i))
},
- Filter = function(i) {
+ Filter = function(i, keep_na = TRUE) {
if (is.logical(i)) {
i <- Array$create(i)
}
assert_is(i, "Array")
- shared_ptr(RecordBatch, RecordBatch__Filter(self, i))
+ shared_ptr(RecordBatch, RecordBatch__Filter(self, i, keep_na))
},
serialize = function() ipc___SerializeRecordBatch__Raw(self),
ToString = function() ToString_tabular(self),
diff --git a/r/R/table.R b/r/R/table.R
index 25c0a23..136de36 100644
--- a/r/R/table.R
+++ b/r/R/table.R
@@ -69,7 +69,7 @@
#' - `$Take(i)`: return an `Table` with rows at positions given by
#' integers `i`. If `i` is an Arrow `Array` or `ChunkedArray`, it will be
#' coerced to an R vector before taking.
-#' - `$Filter(i)`: return an `Table` with rows at positions where logical
+#' - `$Filter(i, keep_na = TRUE)`: return a `Table` with rows at positions
where logical
#' vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`.
#' - `$serialize(output_stream, ...)`: Write the table to the given
#' [OutputStream]
@@ -150,15 +150,15 @@ Table <- R6Class("Table", inherit = ArrowObject,
assert_is(i, "Array")
shared_ptr(Table, Table__Take(self, i))
},
- Filter = function(i) {
+ Filter = function(i, keep_na = TRUE) {
if (is.logical(i)) {
i <- Array$create(i)
}
if (inherits(i, "ChunkedArray")) {
- return(shared_ptr(Table, Table__FilterChunked(self, i)))
+ return(shared_ptr(Table, Table__FilterChunked(self, i, keep_na)))
}
assert_is(i, "Array")
- shared_ptr(Table, Table__Filter(self, i))
+ shared_ptr(Table, Table__Filter(self, i, keep_na))
},
Equals = function(other, check_metadata = TRUE, ...) {
diff --git a/r/man/ChunkedArray.Rd b/r/man/ChunkedArray.Rd
index c26ff3f..533931a 100644
--- a/r/man/ChunkedArray.Rd
+++ b/r/man/ChunkedArray.Rd
@@ -36,7 +36,7 @@ until the end of the array.
\item \verb{$Take(i)}: return a \code{ChunkedArray} with values at positions
given by
integers \code{i}. If \code{i} is an Arrow \code{Array} or
\code{ChunkedArray}, it will be
coerced to an R vector before taking.
-\item \verb{$Filter(i)}: return a \code{ChunkedArray} with values at positions
where
+\item \verb{$Filter(i, keep_na = TRUE)}: return a \code{ChunkedArray} with
values at positions where
logical vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is
\code{TRUE}.
\item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}:
Alter the
data in the array to change its type.
diff --git a/r/man/RecordBatch.Rd b/r/man/RecordBatch.Rd
index af68398..db3282e 100644
--- a/r/man/RecordBatch.Rd
+++ b/r/man/RecordBatch.Rd
@@ -61,7 +61,7 @@ indicated integer offset and going for the given length, or
to the end
of the table if \code{NULL}, the default.
\item \verb{$Take(i)}: return an \code{RecordBatch} with rows at positions
given by
integers (R vector or Array Array) \code{i}.
-\item \verb{$Filter(i)}: return an \code{RecordBatch} with rows at positions
where logical
+\item \verb{$Filter(i, keep_na = TRUE)}: return a \code{RecordBatch} with
rows at positions where logical
vector (or Arrow boolean Array) \code{i} is \code{TRUE}.
\item \verb{$serialize()}: Returns a raw vector suitable for interprocess
communication
\item \verb{$cast(target_schema, safe = TRUE, options = cast_options(safe))}:
Alter
diff --git a/r/man/Table.Rd b/r/man/Table.Rd
index 4a4bddc..52bb1e2 100644
--- a/r/man/Table.Rd
+++ b/r/man/Table.Rd
@@ -60,7 +60,7 @@ of the table if \code{NULL}, the default.
\item \verb{$Take(i)}: return an \code{Table} with rows at positions given by
integers \code{i}. If \code{i} is an Arrow \code{Array} or
\code{ChunkedArray}, it will be
coerced to an R vector before taking.
-\item \verb{$Filter(i)}: return an \code{Table} with rows at positions where
logical
+\item \verb{$Filter(i, keep_na = TRUE)}: return a \code{Table} with rows at
positions where logical
vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is \code{TRUE}.
\item \verb{$serialize(output_stream, ...)}: Write the table to the given
\link{OutputStream}
diff --git a/r/man/array.Rd b/r/man/array.Rd
index 733ab4c..7f0e007 100644
--- a/r/man/array.Rd
+++ b/r/man/array.Rd
@@ -65,7 +65,7 @@ with the indicated offset and length. If length is
\code{NULL}, the slice goes
until the end of the array.
\item \verb{$Take(i)}: return an \code{Array} with values at positions given
by integers
(R vector or Array Array) \code{i}.
-\item \verb{$Filter(i)}: return an \code{Array} with values at positions where
logical
+\item \verb{$Filter(i, keep_na = TRUE)}: return an \code{Array} with values at
positions where logical
vector (or Arrow boolean Array) \code{i} is \code{TRUE}.
\item \verb{$RangeEquals(other, start_idx, end_idx, other_start_idx)} :
\item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}:
Alter the
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 8d109ee..29dc9fa 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1183,96 +1183,102 @@ RcppExport SEXP _arrow_Table__TakeChunked(SEXP
table_sexp, SEXP indices_sexp){
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::Array> Array__Filter(const
std::shared_ptr<arrow::Array>& values, const std::shared_ptr<arrow::Array>&
filter);
-RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp){
+std::shared_ptr<arrow::Array> Array__Filter(const
std::shared_ptr<arrow::Array>& values, const std::shared_ptr<arrow::Array>&
filter, bool keep_na);
+RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp, SEXP
keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type values(values_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type filter(filter_sexp);
- return Rcpp::wrap(Array__Filter(values, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(Array__Filter(values, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp){
+RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp, SEXP
keep_na_sexp){
Rf_error("Cannot call Array__Filter(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::RecordBatch> RecordBatch__Filter(const
std::shared_ptr<arrow::RecordBatch>& batch, const
std::shared_ptr<arrow::Array>& filter);
-RcppExport SEXP _arrow_RecordBatch__Filter(SEXP batch_sexp, SEXP filter_sexp){
+std::shared_ptr<arrow::RecordBatch> RecordBatch__Filter(const
std::shared_ptr<arrow::RecordBatch>& batch, const
std::shared_ptr<arrow::Array>& filter, bool keep_na);
+RcppExport SEXP _arrow_RecordBatch__Filter(SEXP batch_sexp, SEXP filter_sexp,
SEXP keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::RecordBatch>&>::type batch(batch_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type filter(filter_sexp);
- return Rcpp::wrap(RecordBatch__Filter(batch, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(RecordBatch__Filter(batch, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_RecordBatch__Filter(SEXP batch_sexp, SEXP filter_sexp){
+RcppExport SEXP _arrow_RecordBatch__Filter(SEXP batch_sexp, SEXP filter_sexp,
SEXP keep_na_sexp){
Rf_error("Cannot call RecordBatch__Filter(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::ChunkedArray> ChunkedArray__Filter(const
std::shared_ptr<arrow::ChunkedArray>& values, const
std::shared_ptr<arrow::Array>& filter);
-RcppExport SEXP _arrow_ChunkedArray__Filter(SEXP values_sexp, SEXP
filter_sexp){
+std::shared_ptr<arrow::ChunkedArray> ChunkedArray__Filter(const
std::shared_ptr<arrow::ChunkedArray>& values, const
std::shared_ptr<arrow::Array>& filter, bool keep_na);
+RcppExport SEXP _arrow_ChunkedArray__Filter(SEXP values_sexp, SEXP
filter_sexp, SEXP keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::ChunkedArray>&>::type values(values_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type filter(filter_sexp);
- return Rcpp::wrap(ChunkedArray__Filter(values, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(ChunkedArray__Filter(values, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_ChunkedArray__Filter(SEXP values_sexp, SEXP
filter_sexp){
+RcppExport SEXP _arrow_ChunkedArray__Filter(SEXP values_sexp, SEXP
filter_sexp, SEXP keep_na_sexp){
Rf_error("Cannot call ChunkedArray__Filter(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::ChunkedArray> ChunkedArray__FilterChunked(const
std::shared_ptr<arrow::ChunkedArray>& values, const
std::shared_ptr<arrow::ChunkedArray>& filter);
-RcppExport SEXP _arrow_ChunkedArray__FilterChunked(SEXP values_sexp, SEXP
filter_sexp){
+std::shared_ptr<arrow::ChunkedArray> ChunkedArray__FilterChunked(const
std::shared_ptr<arrow::ChunkedArray>& values, const
std::shared_ptr<arrow::ChunkedArray>& filter, bool keep_na);
+RcppExport SEXP _arrow_ChunkedArray__FilterChunked(SEXP values_sexp, SEXP
filter_sexp, SEXP keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::ChunkedArray>&>::type values(values_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::ChunkedArray>&>::type filter(filter_sexp);
- return Rcpp::wrap(ChunkedArray__FilterChunked(values, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(ChunkedArray__FilterChunked(values, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_ChunkedArray__FilterChunked(SEXP values_sexp, SEXP
filter_sexp){
+RcppExport SEXP _arrow_ChunkedArray__FilterChunked(SEXP values_sexp, SEXP
filter_sexp, SEXP keep_na_sexp){
Rf_error("Cannot call ChunkedArray__FilterChunked(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::Table> Table__Filter(const
std::shared_ptr<arrow::Table>& table, const std::shared_ptr<arrow::Array>&
filter);
-RcppExport SEXP _arrow_Table__Filter(SEXP table_sexp, SEXP filter_sexp){
+std::shared_ptr<arrow::Table> Table__Filter(const
std::shared_ptr<arrow::Table>& table, const std::shared_ptr<arrow::Array>&
filter, bool keep_na);
+RcppExport SEXP _arrow_Table__Filter(SEXP table_sexp, SEXP filter_sexp, SEXP
keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Table>&>::type table(table_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Array>&>::type filter(filter_sexp);
- return Rcpp::wrap(Table__Filter(table, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(Table__Filter(table, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_Table__Filter(SEXP table_sexp, SEXP filter_sexp){
+RcppExport SEXP _arrow_Table__Filter(SEXP table_sexp, SEXP filter_sexp, SEXP
keep_na_sexp){
Rf_error("Cannot call Table__Filter(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::Table> Table__FilterChunked(const
std::shared_ptr<arrow::Table>& table, const
std::shared_ptr<arrow::ChunkedArray>& filter);
-RcppExport SEXP _arrow_Table__FilterChunked(SEXP table_sexp, SEXP filter_sexp){
+std::shared_ptr<arrow::Table> Table__FilterChunked(const
std::shared_ptr<arrow::Table>& table, const
std::shared_ptr<arrow::ChunkedArray>& filter, bool keep_na);
+RcppExport SEXP _arrow_Table__FilterChunked(SEXP table_sexp, SEXP filter_sexp,
SEXP keep_na_sexp){
BEGIN_RCPP
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::Table>&>::type table(table_sexp);
Rcpp::traits::input_parameter<const
std::shared_ptr<arrow::ChunkedArray>&>::type filter(filter_sexp);
- return Rcpp::wrap(Table__FilterChunked(table, filter));
+ Rcpp::traits::input_parameter<bool>::type keep_na(keep_na_sexp);
+ return Rcpp::wrap(Table__FilterChunked(table, filter, keep_na));
END_RCPP
}
#else
-RcppExport SEXP _arrow_Table__FilterChunked(SEXP table_sexp, SEXP filter_sexp){
+RcppExport SEXP _arrow_Table__FilterChunked(SEXP table_sexp, SEXP filter_sexp,
SEXP keep_na_sexp){
Rf_error("Cannot call Table__FilterChunked(). Please use
arrow::install_arrow() to install required runtime libraries. ");
}
#endif
@@ -5817,12 +5823,12 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_ChunkedArray__TakeChunked", (DL_FUNC)
&_arrow_ChunkedArray__TakeChunked, 2},
{ "_arrow_Table__Take", (DL_FUNC) &_arrow_Table__Take, 2},
{ "_arrow_Table__TakeChunked", (DL_FUNC)
&_arrow_Table__TakeChunked, 2},
- { "_arrow_Array__Filter", (DL_FUNC) &_arrow_Array__Filter, 2},
- { "_arrow_RecordBatch__Filter", (DL_FUNC)
&_arrow_RecordBatch__Filter, 2},
- { "_arrow_ChunkedArray__Filter", (DL_FUNC)
&_arrow_ChunkedArray__Filter, 2},
- { "_arrow_ChunkedArray__FilterChunked", (DL_FUNC)
&_arrow_ChunkedArray__FilterChunked, 2},
- { "_arrow_Table__Filter", (DL_FUNC) &_arrow_Table__Filter, 2},
- { "_arrow_Table__FilterChunked", (DL_FUNC)
&_arrow_Table__FilterChunked, 2},
+ { "_arrow_Array__Filter", (DL_FUNC) &_arrow_Array__Filter, 3},
+ { "_arrow_RecordBatch__Filter", (DL_FUNC)
&_arrow_RecordBatch__Filter, 3},
+ { "_arrow_ChunkedArray__Filter", (DL_FUNC)
&_arrow_ChunkedArray__Filter, 3},
+ { "_arrow_ChunkedArray__FilterChunked", (DL_FUNC)
&_arrow_ChunkedArray__FilterChunked, 3},
+ { "_arrow_Table__Filter", (DL_FUNC) &_arrow_Table__Filter, 3},
+ { "_arrow_Table__FilterChunked", (DL_FUNC)
&_arrow_Table__FilterChunked, 3},
{ "_arrow_csv___ReadOptions__initialize", (DL_FUNC)
&_arrow_csv___ReadOptions__initialize, 1},
{ "_arrow_csv___ParseOptions__initialize", (DL_FUNC)
&_arrow_csv___ParseOptions__initialize, 1},
{ "_arrow_csv___ConvertOptions__initialize", (DL_FUNC)
&_arrow_csv___ConvertOptions__initialize, 1},
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 6f78bcb..eecbda9 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -164,59 +164,91 @@ std::shared_ptr<arrow::Table> Table__TakeChunked(
// [[arrow::export]]
std::shared_ptr<arrow::Array> Array__Filter(const
std::shared_ptr<arrow::Array>& values,
- const
std::shared_ptr<arrow::Array>& filter) {
- std::shared_ptr<arrow::Array> out;
+ const
std::shared_ptr<arrow::Array>& filter,
+ bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *values, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, values, filter, options, &out));
+ return out.make_array();
}
// [[arrow::export]]
std::shared_ptr<arrow::RecordBatch> RecordBatch__Filter(
const std::shared_ptr<arrow::RecordBatch>& batch,
- const std::shared_ptr<arrow::Array>& filter) {
- std::shared_ptr<arrow::RecordBatch> out;
+ const std::shared_ptr<arrow::Array>& filter, bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *batch, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, batch, filter, options,
&out));
+ return out.record_batch();
}
// [[arrow::export]]
std::shared_ptr<arrow::ChunkedArray> ChunkedArray__Filter(
const std::shared_ptr<arrow::ChunkedArray>& values,
- const std::shared_ptr<arrow::Array>& filter) {
- std::shared_ptr<arrow::ChunkedArray> out;
+ const std::shared_ptr<arrow::Array>& filter, bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *values, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, values, filter, options,
&out));
+ return out.chunked_array();
}
// [[arrow::export]]
std::shared_ptr<arrow::ChunkedArray> ChunkedArray__FilterChunked(
const std::shared_ptr<arrow::ChunkedArray>& values,
- const std::shared_ptr<arrow::ChunkedArray>& filter) {
- std::shared_ptr<arrow::ChunkedArray> out;
+ const std::shared_ptr<arrow::ChunkedArray>& filter, bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *values, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, values, filter, options,
&out));
+ return out.chunked_array();
}
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__Filter(const
std::shared_ptr<arrow::Table>& table,
- const
std::shared_ptr<arrow::Array>& filter) {
- std::shared_ptr<arrow::Table> out;
+ const
std::shared_ptr<arrow::Array>& filter,
+ bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *table, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, table, filter, options,
&out));
+ return out.table();
}
// [[arrow::export]]
std::shared_ptr<arrow::Table> Table__FilterChunked(
const std::shared_ptr<arrow::Table>& table,
- const std::shared_ptr<arrow::ChunkedArray>& filter) {
- std::shared_ptr<arrow::Table> out;
+ const std::shared_ptr<arrow::ChunkedArray>& filter, bool keep_na) {
arrow::compute::FunctionContext context;
- STOP_IF_NOT_OK(arrow::compute::Filter(&context, *table, *filter, &out));
- return out;
+ arrow::compute::Datum out;
+ // Use the EMIT_NULL filter option to match R's behavior in [
+ arrow::compute::FilterOptions options;
+ if (keep_na) {
+ options.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
+ }
+ STOP_IF_NOT_OK(arrow::compute::Filter(&context, table, filter, options,
&out));
+ return out.table();
}
#endif
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index 2d3f6cc..e77bfa3 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -98,7 +98,7 @@ test_that("basic select/filter/collect", {
expect_identical(collect(batch), tbl)
})
-test_that("More complex select/filter", {
+test_that("filter() on is.na()", {
expect_dplyr_equal(
input %>%
filter(is.na(lgl)) %>%
@@ -108,19 +108,28 @@ test_that("More complex select/filter", {
)
})
-# ARROW-7360
-# test_that("filtering with expression", {
-# char_sym <- "b"
-# expect_dplyr_equal(
-# input %>%
-# filter(chr == char_sym) %>%
-# select(string = chr, int) %>%
-# collect(),
-# tbl
-# )
-# })
+test_that("filter() with NAs in selection", {
+ expect_dplyr_equal(
+ input %>%
+ filter(lgl) %>%
+ select(chr, int, lgl) %>%
+ collect(),
+ tbl
+ )
+})
-test_that("filter() on is.na()", {
+test_that("filtering with expression", {
+ char_sym <- "b"
+ expect_dplyr_equal(
+ input %>%
+ filter(chr == char_sym) %>%
+ select(string = chr, int) %>%
+ collect(),
+ tbl
+ )
+})
+
+test_that("More complex select/filter", {
expect_dplyr_equal(
input %>%
filter(dbl > 2, chr == "d" | chr == "f") %>%
diff --git a/ruby/red-arrow/lib/arrow/generic-filterable.rb
b/ruby/red-arrow/lib/arrow/generic-filterable.rb
index 4fd5c87..50a7914 100644
--- a/ruby/red-arrow/lib/arrow/generic-filterable.rb
+++ b/ruby/red-arrow/lib/arrow/generic-filterable.rb
@@ -24,19 +24,19 @@ module Arrow
end
end
- def filter_generic(filter)
+ def filter_generic(filter, options=nil)
case filter
when ::Array
- filter_raw(BooleanArray.new(filter))
+ filter_raw(BooleanArray.new(filter), options)
when ChunkedArray
if respond_to?(:filter_chunked_array)
- filter_chunked_array(filter)
+ filter_chunked_array(filter, options)
else
# TODO: Implement this in C++
- filter_raw(filter.pack)
+ filter_raw(filter.pack, options)
end
else
- filter_raw(filter)
+ filter_raw(filter, options)
end
end
end
diff --git a/ruby/red-arrow/lib/arrow/table.rb
b/ruby/red-arrow/lib/arrow/table.rb
index eb37586..fbf5d3c 100644
--- a/ruby/red-arrow/lib/arrow/table.rb
+++ b/ruby/red-arrow/lib/arrow/table.rb
@@ -304,6 +304,8 @@ module Arrow
end
end
+ filter_options = Arrow::FilterOptions.new
+ filter_options.null_selection_behavior = :emit_null
sliced_tables = []
slicers.each do |slicer|
slicer = slicer.evaluate if slicer.respond_to?(:evaluate)
@@ -325,7 +327,7 @@ module Arrow
to += n_rows if to < 0
sliced_tables << slice_by_range(from, to)
when ::Array, BooleanArray, ChunkedArray
- sliced_tables << filter(slicer)
+ sliced_tables << filter(slicer, filter_options)
else
message = "slicer must be Integer, Range, (from, to), " +
"Arrow::ChunkedArray of Arrow::BooleanArray, " +
diff --git a/ruby/red-arrow/test/test-array.rb
b/ruby/red-arrow/test/test-array.rb
index c7b2213..b2c9d5f 100644
--- a/ruby/red-arrow/test/test-array.rb
+++ b/ruby/red-arrow/test/test-array.rb
@@ -76,20 +76,22 @@ class ArrayTest < Test::Unit::TestCase
def setup
values = [true, false, false, true]
@array = Arrow::BooleanArray.new(values)
+ @options = Arrow::FilterOptions.new
+ @options.null_selection_behavior = :emit_null
end
test("Array: boolean") do
filter = [nil, true, true, false]
filtered_array = Arrow::BooleanArray.new([nil, false, false])
assert_equal(filtered_array,
- @array.filter(filter))
+ @array.filter(filter, @options))
end
test("Arrow::BooleanArray") do
filter = Arrow::BooleanArray.new([nil, true, true, false])
filtered_array = Arrow::BooleanArray.new([nil, false, false])
assert_equal(filtered_array,
- @array.filter(filter))
+ @array.filter(filter, @options))
end
test("Arrow::ChunkedArray") do
@@ -100,7 +102,7 @@ class ArrayTest < Test::Unit::TestCase
filter = Arrow::ChunkedArray.new(chunks)
filtered_array = Arrow::BooleanArray.new([nil, false, false])
assert_equal(filtered_array,
- @array.filter(filter))
+ @array.filter(filter, @options))
end
end
diff --git a/ruby/red-arrow/test/test-chunked-array.rb
b/ruby/red-arrow/test/test-chunked-array.rb
index b1e273c..3785e98 100644
--- a/ruby/red-arrow/test/test-chunked-array.rb
+++ b/ruby/red-arrow/test/test-chunked-array.rb
@@ -92,6 +92,8 @@ class ChunkedArrayTest < Test::Unit::TestCase
Arrow::BooleanArray.new([false, true, false]),
]
@chunked_array = Arrow::ChunkedArray.new(arrays)
+ @options = Arrow::FilterOptions.new
+ @options.null_selection_behavior = :emit_null
end
test("Array: boolean") do
@@ -102,7 +104,7 @@ class ChunkedArrayTest < Test::Unit::TestCase
]
filtered_chunked_array = Arrow::ChunkedArray.new(chunks)
assert_equal(filtered_chunked_array,
- @chunked_array.filter(filter))
+ @chunked_array.filter(filter, @options))
end
test("Arrow::BooleanArray") do
@@ -113,7 +115,7 @@ class ChunkedArrayTest < Test::Unit::TestCase
]
filtered_chunked_array = Arrow::ChunkedArray.new(chunks)
assert_equal(filtered_chunked_array,
- @chunked_array.filter(filter))
+ @chunked_array.filter(filter, @options))
end
test("Arrow::ChunkedArray") do
@@ -128,7 +130,7 @@ class ChunkedArrayTest < Test::Unit::TestCase
]
filtered_chunked_array = Arrow::ChunkedArray.new(filtered_chunks)
assert_equal(filtered_chunked_array,
- @chunked_array.filter(filter))
+ @chunked_array.filter(filter, @options))
end
end
diff --git a/ruby/red-arrow/test/test-table.rb
b/ruby/red-arrow/test/test-table.rb
index eebccdc..ef16625 100644
--- a/ruby/red-arrow/test/test-table.rb
+++ b/ruby/red-arrow/test/test-table.rb
@@ -698,9 +698,15 @@ visible: false
end
sub_test_case("#filter") do
+ def setup
+ super
+ @options = Arrow::FilterOptions.new
+ @options.null_selection_behavior = :emit_null
+ end
+
test("Array: boolean") do
filter = [nil, true, true, false, true, false, true, true]
- assert_equal(<<-TABLE, @table.filter(filter).to_s)
+ assert_equal(<<-TABLE, @table.filter(filter, @options).to_s)
count visible
0
1 2 false
@@ -714,7 +720,7 @@ visible: false
test("Arrow::BooleanArray") do
array = [nil, true, true, false, true, false, true, true]
filter = Arrow::BooleanArray.new(array)
- assert_equal(<<-TABLE, @table.filter(filter).to_s)
+ assert_equal(<<-TABLE, @table.filter(filter, @options).to_s)
count visible
0
1 2 false
@@ -732,7 +738,7 @@ visible: false
Arrow::BooleanArray.new([true, true]),
]
filter = Arrow::ChunkedArray.new(filter_chunks)
- assert_equal(<<-TABLE, @table.filter(filter).to_s)
+ assert_equal(<<-TABLE, @table.filter(filter, @options).to_s)
count visible
0
1 2 false