viirya commented on code in PR #8578:
URL: https://github.com/apache/arrow-datafusion/pull/8578#discussion_r1431998790


##########
datafusion-examples/examples/advanced_udf.rs:
##########
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{
+    arrow::{
+        array::{ArrayRef, Float32Array, Float64Array},
+        datatypes::DataType,
+        record_batch::RecordBatch,
+    },
+    logical_expr::Volatility,
+};
+use std::any::Any;
+
+use arrow::array::{new_null_array, Array, AsArray};
+use arrow::compute;
+use arrow::datatypes::Float64Type;
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{internal_err, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
+use std::sync::Arc;
+
+/// This example shows how to use the full ScalarUDFImpl API to implement a user
+/// defined function. As in the `simple_udf.rs` example, this struct implements
+/// a function that takes two arguments and returns the first argument raised to
+/// the power of the second argument `a^b`.
+///
+/// To do so, we must implement the `ScalarUDFImpl` trait.
+struct PowUdf {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl PowUdf {
+    /// Create a new instance of the `PowUdf` struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take two arguments of type f64
+                vec![DataType::Float64, DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+            // we will also add an alias of "my_pow"
+            aliases: vec!["my_pow".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for PowUdf {
+    /// We implement as_any so that we can downcast the ScalarUDFImpl trait object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "pow"
+    }
+
+    /// Return the "signature" of this function -- namely what types of arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function? In
+    /// this case it will always be a constant value, but it could also be a
+    /// function of the input types.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the function that actually calculates the results.
+    ///
+    /// This is the same way that functions built into DataFusion are invoked,
+    /// which permits important special cases when one or both of the arguments
+    /// are single values (constants). For example `pow(a, 2)`
+    ///
+    /// However, it also means the implementation is more complex than when
+    /// using `create_udf`.
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // DataFusion has arranged for the correct inputs to be passed to this
+        // function, but we check again to make sure
+        assert_eq!(args.len(), 2);
+        let (base, exp) = (&args[0], &args[1]);
+        assert_eq!(base.data_type(), DataType::Float64);
+        assert_eq!(exp.data_type(), DataType::Float64);
+
+        match (base, exp) {
+            // For demonstration purposes we also implement the scalar / scalar
+            // case here, but it is not typically required for high performance.
+            //
+            // For performance it is most important to optimize cases where at
+            // least one argument is an array. If all arguments are constants,
+            // the DataFusion expression simplification logic will often invoke
+            // this path once during planning, and simply use the result during
+            // execution.
+            (
+                ColumnarValue::Scalar(ScalarValue::Float64(base)),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                // compute the output. Note DataFusion treats `None` as NULL.
+                let res = match (base, exp) {
+                    (Some(base), Some(exp)) => Some(base.powf(*exp)),
+                    // one or both arguments were NULL
+                    _ => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
+            }
+            // special case if the exponent is a constant
+            (
+                ColumnarValue::Array(base_array),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                let result_array = match exp {
+                    // a ^ null = null
+                    None => new_null_array(base_array.data_type(), base_array.len()),
+                    // a ^ exp
+                    Some(exp) => {
+                        // DataFusion has ensured both arguments are Float64:
+                        let base_array = base_array.as_primitive::<Float64Type>();
+                        // calculate the result for every row. The `unary` very
+                        // fast,  "vectorized" code and handles things like null
+                        // values for us.

Review Comment:
   Not sure if I read it correctly:
   ```suggestion
                           // calculate the result for every row. The `unary` is very
                           // fast "vectorized" code and handles things like null
                           // values for us.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to