thisisnic commented on code in PR #14361:
URL: https://github.com/apache/arrow/pull/14361#discussion_r993593900


##########
r/tests/testthat/test-dplyr-slice.R:
##########
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+library(dplyr, warn.conflicts = FALSE)
+
+tbl <- example_data
+
+test_that("slice_head/tail, ungrouped", {
+  # head/tail are not deterministic in Arrow because data is unordered
+  # so we can't assert identical to dplyr, just assert right number of rows
+  tab <- arrow_table(tbl)
+  expect_equal(
+    tab %>%
+      slice_head(n = 5) %>%
+      nrow(),
+    5
+  )
+  expect_equal(
+    tab %>%
+      slice_tail(n = 5) %>%
+      nrow(),
+    5
+  )
+
+  expect_equal(
+    tab %>%
+      slice_head(prop = .25) %>%
+      nrow(),
+    2
+  )
+  expect_equal(
+    tab %>%
+      slice_tail(prop = .25) %>%
+      nrow(),
+    2
+  )
+})
+
+test_that("slice_min/max, ungrouped", {
+  # with_ties must be FALSE
+  tab <- arrow_table(tbl)
+  expect_error(
+    tab %>% slice_max(int, n = 5),
+    "with_ties = TRUE"
+  )
+  expect_error(
+    tab %>% slice_min(int, n = 5),
+    "with_ties = TRUE"
+  )
+  compare_dplyr_binding(
+    .input %>%
+      slice_max(int, n = 4, with_ties = FALSE) %>%
+      collect(),
+    tbl
+  )
+  compare_dplyr_binding(
+    .input %>%
+      slice_min(int, n = 4, with_ties = FALSE) %>%
+      collect(),
+    tbl
+  )
+
+  compare_dplyr_binding(
+    .input %>%
+      slice_max(int, prop = .25, with_ties = FALSE) %>%
+      collect(),
+    tbl
+  )
+  compare_dplyr_binding(
+    .input %>%
+      slice_min(int, prop = .25, with_ties = FALSE) %>%
+      collect(),
+    tbl
+  )
+})
+
+test_that("slice_sample, ungrouped", {
+  tab <- arrow_table(tbl)
+  expect_error(
+    tab %>% slice_sample(replace = TRUE),
+    "Sampling with replacement"
+  )
+  expect_error(
+    tab %>% slice_sample(weight_by = dbl),
+    "weight_by"
+  )
+
+  # Because this is random (and we only have 10 rows), try several times
+  for (i in 1:10) {
+    sampled_prop <- tab %>%
+      slice_sample(prop = .2) %>%
+      collect() %>%
+      nrow()
+    if (sampled_prop == 2) break
+  }
+  expect_equal(sampled_prop, 2)

Review Comment:
   I may be under-caffeinated today, but I'm confused by this. Wouldn't we
always expect exactly 2 rows, since 0.2 * 10 = 2? Can you add a bit more detail
to the comment explaining why the number of sampled rows can vary between runs?



##########
r/R/dplyr-slice.R:
##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+slice_head.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(.data, n)
+}
+slice_head.Dataset <- slice_head.ArrowTabular <- slice_head.RecordBatchReader 
<- slice_head.arrow_dplyr_query
+
+slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  tail(.data, n)
+}
+slice_tail.Dataset <- slice_tail.ArrowTabular <- slice_tail.RecordBatchReader 
<- slice_tail.arrow_dplyr_query
+
+slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, 
with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(dplyr::arrange(.data, {{ order_by }}), n)
+}
+slice_min.Dataset <- slice_min.ArrowTabular <- slice_min.RecordBatchReader <- 
slice_min.arrow_dplyr_query
+
+slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, 
with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  sorted <- dplyr::arrange(.data, {{ order_by }})
+  # Invert the sort order of the things in ... so they're descending
+  # TODO: handle possibility that .data was already sorted and we don't want
+  # to invert those sorts? Does that matter? Or no because there's no promise
+  # of order of which TopK elements you get if there are ties?
+  sorted$arrange_desc <- !sorted$arrange_desc
+  head(sorted, n)
+}
+slice_max.Dataset <- slice_max.ArrowTabular <- slice_max.RecordBatchReader <- 
slice_max.arrow_dplyr_query
+
+slice_sample.arrow_dplyr_query <- function(.data,
+                                           ...,
+                                           n,
+                                           prop,
+                                           weight_by = NULL,
+                                           replace = FALSE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (replace) {
+    arrow_not_supported("Sampling with replacement")
+  }
+  if (!missing(weight_by)) {
+    # You could do this by multiplying the random() column * weight_by
+    # but you'd need to calculate sum(weight_by) in order to normalize
+    arrow_not_supported("weight_by")
+  }
+  rlang::check_dots_empty()
+
+  # If we want n rows sampled, we have to convert n to prop, oversample some
+  # just to make sure we get enough, then head(n)
+  sampling_n <- missing(prop)
+  if (missing(prop)) {
+    prop <- min(n_to_prop(.data, n) + .05, 1)
+  }
+  validate_prop(prop)
+
+  if (prop < 1) {
+    .data <- as_adq(.data)
+    # TODO(ARROW-17974): expr <- Expression$create("random") < prop
+    # HACK: use our UDF to generate random. It needs an input column because
+    # nullary functions don't work, and that column has to be typed. We've
+    # chosen boolean() type because it's compact and can always be created:
+    # pick any column and do is.na, that will be boolean.
+    # TODO: get an actual FieldRef because the first col could be derived
+    ref <- Expression$create("is_null", .data$selected_columns[[1]])
+    expr <- Expression$create("_random_along", ref) < prop
+    .data <- set_filters(.data, expr)
+  }
+  if (sampling_n) {
+    .data <- head(.data, n)
+  }
+
+  .data
+}
+slice_sample.Dataset <- slice_sample.ArrowTabular <- 
slice_sample.RecordBatchReader <- slice_sample.arrow_dplyr_query
+
+
+prop_to_n <- function(.data, prop) {
+  nrows <- nrow(.data)
+  if (is.na(nrows)) {
+    arrow_not_supported("Slicing with `prop` when `nrow()` requires evaluating 
the query")
+  }
+  validate_prop(prop)
+  nrows * prop
+}
+
+validate_prop <- function(prop) {
+  if (!is.numeric(prop) || length(prop) != 1 || is.na(prop) || prop < 0 || 
prop > 1) {
+    stop("`prop` must be a single numeric value in [0, 1]", call. = FALSE)
+  }
+}
+
+n_to_prop <- function(.data, n) {
+  nrows <- nrow(.data)
+  if (is.na(nrows)) {
+    arrow_not_supported("slice_sample() with `n` when `nrow()` requires 
evaluating the query")

Review Comment:
   Can we rephrase this? It's not clear to me when `nrow()` requires
evaluating the query, and I doubt a user would understand why they're getting
this error. Could the message describe the concrete, user-facing condition that
triggers it instead? For example, something like "when `nrow()` is used before
`slice()`" (I know that's not the actual condition, but you get the idea).

   We probably also want a test that triggers this error.



##########
r/R/dplyr-slice.R:
##########
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+slice_head.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(.data, n)
+}
+slice_head.Dataset <- slice_head.ArrowTabular <- slice_head.RecordBatchReader 
<- slice_head.arrow_dplyr_query
+
+slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  tail(.data, n)
+}
+slice_tail.Dataset <- slice_tail.ArrowTabular <- slice_tail.RecordBatchReader 
<- slice_tail.arrow_dplyr_query
+
+slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, 
with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  head(dplyr::arrange(.data, {{ order_by }}), n)
+}
+slice_min.Dataset <- slice_min.ArrowTabular <- slice_min.RecordBatchReader <- 
slice_min.arrow_dplyr_query
+
+slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, 
with_ties = TRUE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (with_ties) {
+    arrow_not_supported("with_ties = TRUE")
+  }
+  rlang::check_dots_empty()
+
+  if (missing(n)) {
+    n <- prop_to_n(.data, prop)
+  }
+
+  sorted <- dplyr::arrange(.data, {{ order_by }})
+  # Invert the sort order of the things in ... so they're descending
+  # TODO: handle possibility that .data was already sorted and we don't want
+  # to invert those sorts? Does that matter? Or no because there's no promise
+  # of order of which TopK elements you get if there are ties?
+  sorted$arrange_desc <- !sorted$arrange_desc
+  head(sorted, n)
+}
+slice_max.Dataset <- slice_max.ArrowTabular <- slice_max.RecordBatchReader <- 
slice_max.arrow_dplyr_query
+
+slice_sample.arrow_dplyr_query <- function(.data,
+                                           ...,
+                                           n,
+                                           prop,
+                                           weight_by = NULL,
+                                           replace = FALSE) {
+  if (length(group_vars(.data)) > 0) {
+    arrow_not_supported("Slicing grouped data")
+  }
+  if (replace) {
+    arrow_not_supported("Sampling with replacement")
+  }
+  if (!missing(weight_by)) {
+    # You could do this by multiplying the random() column * weight_by
+    # but you'd need to calculate sum(weight_by) in order to normalize
+    arrow_not_supported("weight_by")
+  }
+  rlang::check_dots_empty()
+
+  # If we want n rows sampled, we have to convert n to prop, oversample some
+  # just to make sure we get enough, then head(n)
+  sampling_n <- missing(prop)
+  if (missing(prop)) {
+    prop <- min(n_to_prop(.data, n) + .05, 1)
+  }
+  validate_prop(prop)
+
+  if (prop < 1) {
+    .data <- as_adq(.data)
+    # TODO(ARROW-17974): expr <- Expression$create("random") < prop
+    # HACK: use our UDF to generate random. It needs an input column because
+    # nullary functions don't work, and that column has to be typed. We've
+    # chosen boolean() type because it's compact and can always be created:
+    # pick any column and do is.na, that will be boolean.
+    # TODO: get an actual FieldRef because the first col could be derived
+    ref <- Expression$create("is_null", .data$selected_columns[[1]])
+    expr <- Expression$create("_random_along", ref) < prop
+    .data <- set_filters(.data, expr)
+  }
+  if (sampling_n) {
+    .data <- head(.data, n)
+  }
+
+  .data
+}
+slice_sample.Dataset <- slice_sample.ArrowTabular <- 
slice_sample.RecordBatchReader <- slice_sample.arrow_dplyr_query
+
+
+prop_to_n <- function(.data, prop) {
+  nrows <- nrow(.data)
+  if (is.na(nrows)) {
+    arrow_not_supported("Slicing with `prop` when `nrow()` requires evaluating 
the query")
+  }
+  validate_prop(prop)
+  nrows * prop
+}
+
+validate_prop <- function(prop) {
+  if (!is.numeric(prop) || length(prop) != 1 || is.na(prop) || prop < 0 || 
prop > 1) {
+    stop("`prop` must be a single numeric value in [0, 1]", call. = FALSE)

Review Comment:
   ```suggestion
       stop("`prop` must be a single numeric value between 0 and 1", call. = 
FALSE)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to