nealrichardson commented on code in PR #13397: URL: https://github.com/apache/arrow/pull/13397#discussion_r905374682
########## r/tests/testthat/test-compute.R: ##########
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+test_that("list_compute_functions() works", {
+  expect_type(list_compute_functions(), "character")
+  expect_true(all(!grepl("^hash_", list_compute_functions())))
+})
+
+
+test_that("arrow_base_scalar_function() works", {
+  # check in/out type as schema/data type
+  fun <- arrow_base_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]])

Review Comment:
   The function has 2 args but the schema you provided only has 1, shouldn't this error?
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
   )
   modifyList(opts, list(...))
 }
+
+#' Register user-defined functions
+#'
+#' These functions support calling R code from query engine execution
+#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
+#' Use [arrow_scalar_function()] to define an R function that accepts and
+#' returns R objects; use [arrow_base_scalar_function()] to define a
+#' lower-level function that operates directly on Arrow objects.
+#'
+#' @param name The function name to be used in the dplyr bindings
+#' @param scalar_function An object created with [arrow_scalar_function()]
+#' or [arrow_base_scalar_function()].
+#' @param registry_name The function name to be used in the Arrow C++
+#' compute function registry. This may be different from `name`.
+#' @param in_type A [DataType] of the input type or a [schema()]
+#' for functions with more than one argument. This signature will be used
+#' to determine if this function is appropriate for a given set of arguments.
+#' If this function is appropriate for more than one signature, pass a
+#' `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting
+#' a single argument (`types`), which is a `list()` of [DataType]s. If a
+#' function it must return a [DataType].
+#' @param fun An R function or rlang-style lambda expression. This function
+#' will be called with R objects as arguments and must return an object
+#' that can be converted to an [Array] using [as_arrow_array()]. Function
+#' authors must take care to return an array castable to the output data
+#' type specified by `out_type`.
+#' @param base_fun An R function or rlang-style lambda expression. This
+#' function will be called with exactly two arguments: `kernel_context`,
+#' which is a `list()` of objects giving information about the
+#' execution context and `args`, which is a list of [Array] or [Scalar]
+#' objects corresponding to the input arguments.
+#'
+#' @return
+#' - `register_scalar_function()`: `NULL`, invisibly
+#' - `arrow_scalar_function()`: returns an object of class
+#'   "arrow_base_scalar_function" that can be passed to
+#'   `register_scalar_function()`.
+#' @export
+#'
+#' @examplesIf .Machine$sizeof.pointer >= 8
+#' fun_wrapper <- arrow_scalar_function(
+#'   schema(x = float64(), y = float64(), z = float64()),
+#'   float64(),
+#'   function(x, y, z) x + y + z
+#' )
+#' register_scalar_function("example_add3", fun_wrapper)
+#'
+#' call_function(
+#'   "example_add3",
+#'   Scalar$create(1),
+#'   Scalar$create(2),
+#'   Array$create(3)
+#' )
+#'
+#' # use arrow_base_scalar_function() for a lower-level interface
+#' base_fun_wrapper <- arrow_base_scalar_function(
+#'   schema(x = float64(), y = float64(), z = float64()),
+#'   float64(),
+#'   function(context, args) {
+#'     args[[1]] + args[[2]] + args[[3]]
+#'   }
+#' )
+#' register_scalar_function("example_add3", base_fun_wrapper)
+#'
+#' call_function(
+#'   "example_add3",
+#'   Scalar$create(1),
+#'   Scalar$create(2),
+#'   Array$create(3)
+#' )
+#'
+register_scalar_function <- function(name, scalar_function, registry_name = name) {
+  assert_that(
+    is.string(name),
+    is.string(registry_name),
+    inherits(scalar_function, "arrow_base_scalar_function")
+  )
+
+  # register with Arrow C++
+  RegisterScalarUDF(registry_name, scalar_function)
+
+  # register with dplyr bindings
+  register_binding(
+    name,
+    function(...) build_expr(registry_name, ...)
+  )
+
+  # recreate dplyr binding cache
+  create_binding_cache()
+
+  invisible(NULL)
+}
+
+#' @rdname register_scalar_function
+#' @export
+arrow_scalar_function <- function(in_type, out_type, fun) {
+  fun <- rlang::as_function(fun)
+  base_fun <- function(context, args) {
+    args <- lapply(args, as.vector)
+    result <- do.call(fun, args)
+    as_arrow_array(result, type = context$output_type)
+  }
+
+  arrow_base_scalar_function(in_type, out_type, base_fun)
+}
+
+#' @rdname register_scalar_function
+#' @export
+arrow_base_scalar_function <- function(in_type, out_type, base_fun) {

Review Comment:
   What does "base" mean here? I'm not sure this is the most evocative name
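The docs quoted above say these UDFs are meant to be callable from dplyr verbs, but the roxygen examples only exercise call_function(). As an illustrative aside (not part of the diff), usage from a query would look roughly like the following, assuming the `example_add3` registration from the example block:

library(arrow)
library(dplyr)

# sketch only: relies on example_add3 having been registered as above
record_batch(x = c(1, 2), y = c(3, 4), z = c(5, 6)) %>%
  mutate(total = example_add3(x, y, z)) %>%
  collect()
# expected: a tibble with total = c(9, 12)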
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
+register_scalar_function <- function(name, scalar_function, registry_name = name) {
+  assert_that(
+    is.string(name),
+    is.string(registry_name),
+    inherits(scalar_function, "arrow_base_scalar_function")
+  )
+
+  # register with Arrow C++
+  RegisterScalarUDF(registry_name, scalar_function)
+
+  # register with dplyr bindings
+  register_binding(
+    name,
+    function(...) build_expr(registry_name, ...)
+  )
+
+  # recreate dplyr binding cache

Review Comment:
   Why do you have to recreate it? You just registered it.
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
+register_scalar_function <- function(name, scalar_function, registry_name = name) {

Review Comment:
   What's the use case where name and registry_name should be different?

########## r/R/query-engine.R: ##########
@@ -190,7 +190,7 @@ ExecPlan <- R6Class("ExecPlan",
       }
       node
     },
-    Run = function(node) {
+    Run = function(node, as_table = FALSE) {

Review Comment:
   Why do we need this argument? You can always consume a RBR into a Table.
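On the name/registry_name question, one hypothetical scenario (not from the PR) is namespacing the entry in the C++ compute registry while keeping a short dplyr-facing name:

# hypothetical sketch: the dplyr bindings see "times_two", while the Arrow C++
# registry entry gets a prefix to avoid clashing with built-in kernels
wrapper <- arrow_scalar_function(float64(), float64(), function(x) x * 2)
register_scalar_function("times_two", wrapper, registry_name = "r.times_two")

# callable via the registry name...
call_function("r.times_two", Array$create(c(1, 2, 3)))
# ...while times_two(x) would be the name used inside dplyr::mutate()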
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
+#' @rdname register_scalar_function
+#' @export
+arrow_scalar_function <- function(in_type, out_type, fun) {
+  fun <- rlang::as_function(fun)

Review Comment:
   nit: why `rlang::` here? We generally importFrom the namespace and don't use `::`
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
+arrow_scalar_function <- function(in_type, out_type, fun) {
+  fun <- rlang::as_function(fun)
+  base_fun <- function(context, args) {
+    args <- lapply(args, as.vector)

Review Comment:
   Could you annotate these functions with some code comments where it's not obvious what's happening? It took me a while to figure out why you were calling as.vector on the args.
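For illustration only, here is the wrapper quoted above with the kind of comments the review asks for; the annotations are guesses at intent, not text from the PR:

arrow_scalar_function <- function(in_type, out_type, fun) {
  fun <- rlang::as_function(fun)
  base_fun <- function(context, args) {
    # the kernel hands us Array/Scalar objects; convert them to plain R
    # vectors so that `fun` can be an ordinary function of R objects
    args <- lapply(args, as.vector)
    result <- do.call(fun, args)
    # convert the R result back into an Array of the declared output type
    as_arrow_array(result, type = context$output_type)
  }

  arrow_base_scalar_function(in_type, out_type, base_fun)
}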
########## r/tests/testthat/test-compute.R: ##########
@@ -0,0 +1,107 @@
+test_that("arrow_scalar_function() returns a base scalar function", {
+  base_fun <- arrow_scalar_function(
+    list(float64(), float64()),
+    float64(),
+    function(x, y) {
+      x + y

Review Comment:
   Can we have a test function be something that is clearly not in Arrow?
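As an illustrative aside (hypothetical, not from the PR), a test function with no Arrow compute equivalent could look something like this:

# format() with nsmall is R-specific behaviour that Arrow's kernels don't
# provide, so a passing test can't be explained by an existing kernel
base_fun <- arrow_scalar_function(
  float64(),
  utf8(),
  function(x) format(x, nsmall = 2)
)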
########## r/R/compute.R: ##########
@@ -307,3 +307,158 @@ cast_options <- function(safe = TRUE, ...) {
+register_scalar_function <- function(name, scalar_function, registry_name = name) {
+  assert_that(
+    is.string(name),
+    is.string(registry_name),
+    inherits(scalar_function, "arrow_base_scalar_function")
+  )
+
+  # register with Arrow C++

Review Comment:
   Can you note why you have to register twice?
########## r/tests/testthat/test-compute.R: ##########
@@ -0,0 +1,107 @@
+test_that("arrow_scalar_function() returns a base scalar function", {
+  base_fun <- arrow_scalar_function(
+    list(float64(), float64()),
+    float64(),
+    function(x, y) {
+      x + y
+    }
+  )
+
+  expect_s3_class(base_fun, "arrow_base_scalar_function")
+  expect_equal(
+    base_fun(list(), list(Scalar$create(2), Array$create(3))),

Review Comment:
   Why is the first argument to `base_fun` `list()`? Would a more interesting test be to have an array with more than one element, to demonstrate the vectorization?
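An illustrative version of the vectorized check the comment suggests (not from the PR; it assumes `base_fun` as defined in the quoted test):

# a length-3 array input should produce a length-3 array output
expect_equal(
  base_fun(list(), list(Scalar$create(2), Array$create(c(3, 4, 5)))),
  Array$create(c(5, 6, 7))
)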
########## r/tests/testthat/test-compute.R: ##########
@@ -0,0 +1,107 @@
+test_that("register_scalar_function() adds a compute function to the registry", {
+  skip_if_not_available("dataset")
+
+  fun <- arrow_base_scalar_function(
+    int32(), int64(),
+    function(context, args) {
+      args[[1]] + 1L
+    }
+  )
+
+  register_scalar_function("my_test_scalar_function", fun)
+
+  expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions))
+  expect_true("my_test_scalar_function" %in% list_compute_functions())
+
+  expect_equal(
+    call_function("my_test_scalar_function", Array$create(1L, int32())),
+    Array$create(2L, int64())
+  )
+
+  expect_equal(
+    call_function("my_test_scalar_function", Scalar$create(1L, int32())),
+    Scalar$create(2L, int64())
+  )
+
+  expect_identical(
+    record_batch(a = 1L) %>%
+      dplyr::mutate(b = my_test_scalar_function(a)) %>%
+      dplyr::collect(),
+    tibble::tibble(a = 1L, b = 2L)
+  )
+})

Review Comment:
   Would it be good to do a test with a dataset with multiple files? Particularly if we're concerned about R thread safety, we might need more than 1 row in a RecordBatch to confirm it's behaving.
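A rough sketch of the kind of test the last comment asks for; this is hypothetical, assumes the `my_test_scalar_function` registration from the quoted test, and the expectations are illustrative rather than verified:

test_that("registered UDFs work on multi-file, multi-row datasets", {
  skip_if_not_available("dataset")

  tf <- tempfile()
  on.exit(unlink(tf, recursive = TRUE))

  # two partitions -> at least two files, and enough rows to span batches
  df <- data.frame(part = rep(c("a", "b"), each = 100), a = 1:200)
  write_dataset(df, tf, partitioning = "part")

  result <- open_dataset(tf) %>%
    dplyr::mutate(b = my_test_scalar_function(a)) %>%
    dplyr::collect()

  expect_equal(result$b, result$a + 1L)
})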
