thisisnic commented on code in PR #14361:
URL: https://github.com/apache/arrow/pull/14361#discussion_r993708753


##########
r/R/dplyr-slice.R:
##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+slice_head.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(.data, n)
+}
+slice_head.Dataset <- slice_head.ArrowTabular <- slice_head.RecordBatchReader <- slice_head.arrow_dplyr_query
+
+slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  tail(.data, n)
+}
+slice_tail.Dataset <- slice_tail.ArrowTabular <- slice_tail.RecordBatchReader <- slice_tail.arrow_dplyr_query
+
+slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(dplyr::arrange(.data, {{ order_by }}), n)
+}
+slice_min.Dataset <- slice_min.ArrowTabular <- slice_min.RecordBatchReader <- slice_min.arrow_dplyr_query
+
+slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  sorted <- dplyr::arrange(.data, {{ order_by }})
+  # Invert the sort order of the things in ... so they're descending
+  # TODO: handle possibility that .data was already sorted and we don't want
+  # to invert those sorts? Does that matter? Or no because there's no promise
+  # of order of which TopK elements you get if there are ties?
+  sorted$arrange_desc <- !sorted$arrange_desc
+  head(sorted, n)
+}
+slice_max.Dataset <- slice_max.ArrowTabular <- slice_max.RecordBatchReader <- slice_max.arrow_dplyr_query
+
+slice_sample.arrow_dplyr_query <- function(.data,
+                                           ...,
+                                           n,
+                                           prop,
+                                           weight_by = NULL,
+                                           replace = FALSE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (replace) {
+    arrow_not_supported("Sampling with replacement")
+  }
+  if (!missing(weight_by)) {
+    # You could do this by multiplying the random() column * weight_by
+    # but you'd need to calculate sum(weight_by) in order to normalize
+    arrow_not_supported("weight_by")
+  }
+  rlang::check_dots_empty()
+
+  # If we want n rows sampled, we have to convert n to prop, oversample some
+  # just to make sure we get enough, then head(n)
+  sampling_n <- missing(prop)
+  if (missing(prop)) {
+    prop <- min(n_to_prop(.data, n) + .05, 1)
+  }
+  validate_prop(prop)
+
+  if (prop < 1) {
+    .data <- as_adq(.data)
+    # TODO(ARROW-17974): expr <- Expression$create("random") < prop
+    # HACK: use our UDF to generate random. It needs an input column because
+    # nullary functions don't work, and that column has to be typed. We've
+    # chosen boolean() type because it's compact and can always be created:
+    # pick any column and do is.na, that will be boolean.
+    # TODO: get an actual FieldRef because the first col could be derived
+    ref <- Expression$create("is_null", .data$selected_columns[[1]])
+    expr <- Expression$create("_random_along", ref) < prop
+    .data <- set_filters(.data, expr)
+  }
+  if (sampling_n) {
+    .data <- head(.data, n)
+  }
+
+  .data
+}
+slice_sample.Dataset <- slice_sample.ArrowTabular <- slice_sample.RecordBatchReader <- slice_sample.arrow_dplyr_query
+
+
+prop_to_n <- function(.data, prop) {
+  nrows <- nrow(.data)
+  if (is.na(nrows)) {
+    arrow_not_supported("Slicing with `prop` when `nrow()` requires evaluating 
the query")
+  }
+  validate_prop(prop)
+  nrows * prop
+}
+
+validate_prop <- function(prop) {
+  if (!is.numeric(prop) || length(prop) != 1 || is.na(prop) || prop < 0 || prop > 1) {
+    stop("`prop` must be a single numeric value in [0, 1]", call. = FALSE)
+  }
+}
+
+n_to_prop <- function(.data, n) {
+  nrows <- nrow(.data)
+  if (is.na(nrows)) {
+    arrow_not_supported("slice_sample() with `n` when `nrow()` requires 
evaluating the query")

Review Comment:
   Out of interest, is this only the case for joins, or are there other circumstances in which we may not know `nrow()`? Like, I guess, aggregations too?
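   
   For context, a rough sketch of the kind of thing I'm picturing (this assumes `nrow()` on an unevaluated `arrow_dplyr_query` returns NA whenever the row count can't be known without running the query; the tables here are just made-up examples, not from this PR):
   
   ```r
   library(arrow)
   library(dplyr)
   
   tbl <- arrow_table(x = 1:10, y = rep(c("a", "b"), 5))
   
   # A bare projection keeps the source row count
   nrow(tbl %>% select(x))                       # 10
   
   # But once the result size depends on evaluating the query,
   # nrow() can presumably only report NA, e.g. after a filter ...
   nrow(tbl %>% filter(x > 5))                   # NA
   
   # ... a join ...
   nrow(tbl %>% left_join(arrow_table(y = "a", z = 1), by = "y"))  # NA
   
   # ... or a grouped aggregation, where the number of groups is unknown
   nrow(tbl %>% group_by(y) %>% summarise(n = n()))                # NA
   ```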


