calcaura commented on code in PR #20308:
URL: https://github.com/apache/datafusion/pull/20308#discussion_r2807431087
##########
datafusion/functions/src/regex/regexpextract.rs:
##########
@@ -0,0 +1,551 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Regex expressions
+use arrow::array::{Array, ArrayRef, Int32Array, StringArray, StringBuilder};
+use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
+use datafusion_common::exec_err;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use datafusion_macros::user_doc;
+use regex::Regex;
+use std::any::Any;
+use std::sync::Arc;
+
+// See https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.regexp_extract.html
+// See https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala#L863
+
+#[user_doc(
+    doc_section(label = "Regular Expression Functions"),
+    description = "Extracts the first string in `str` that matches the `regexp` expression and corresponds to the given regex group index",
+    syntax_example = "regexp_extract(str, regexp[, idx])",
+    sql_example = r#"```sql
+> SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1);
++-----+
+| 100 |
++-----+
+> SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 2);
++-----+
+| 200 |
++-----+
+```
+Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs)
+"#,
+    argument(name = "str", description = "Column or column name"),
+    argument(
+        name = "regexp",
+        description = r#"a string representing a regular expression. The regex string should be a
+    Java regular expression.<br><br>
+    Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
+    parser; see the unescaping rules at <a href="https://spark.apache.org/docs/latest/sql-ref-literals.html#string-literal">String Literal</a>.
+    For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$".<br><br>
+    There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
+    fall back to the Spark 1.6 behavior regarding string literal parsing. For example,
+    if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".<br><br>
+    It is recommended to use a raw string literal (with the `r` prefix) to avoid having to
+    escape special characters in the pattern string."#
+    ),
+    argument(
+        name = "idx",
+        description = r#"an integer expression representing the group index. The regex may contain
+    multiple groups; `idx` indicates which regex group to extract. The group index must be
+    non-negative, and the minimum value of `idx` is 0, which means matching the entire
+    regular expression. This parameter is optional; when omitted it defaults to 1, i.e.
+    the first capture group, matching Spark's behavior."#
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct RegexpExtractFunc {
+    signature: Signature,
+}
+
+impl Default for RegexpExtractFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl RegexpExtractFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    // Spark Catalyst Expression: RegExpExtract(subject, regexp, idx),
+                    // where idx defaults to 1 when omitted.
+                    // See: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+                    //
+                    // 2-arg form: regexp_extract(str, regexp); idx defaults to 1.
+                    // Matches Spark's: def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+                    TypeSignature::Exact(vec![Utf8, Utf8]),
+                    // 3-arg form: regexp_extract(str, regexp, idx)
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int32]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for RegexpExtractFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "regexp_extract"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        use DataType::*;
+        // Spark's RegExpExtract always returns StringType
+        match arg_types.len() {
+            2 | 3 => match arg_types[0] {
+                Utf8 => Ok(Utf8),
+                _ => exec_err!("regexp_extract only supports Utf8 for arg0"),
+            },
+            _ => exec_err!(
+                "regexp_extract expects 2 or 3 arguments, got {}",
+                arg_types.len()
+            ),
+        }
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: datafusion_expr::ScalarFunctionArgs,
+    ) -> Result<ColumnarValue> {
+        let args = &args.args;
+
+        if args.len() != 2 && args.len() != 3 {
+            return exec_err!("regexp_extract expects 2 or 3 arguments");
+        }
+
+        // DataFusion passes either scalars or arrays. Convert everything to arrays.
+        let len = args
+            .iter()
+            .map(|v| match v {
+                ColumnarValue::Array(a) => a.len(),
+                ColumnarValue::Scalar(_) => 1,
+            })
+            .max()
+            .unwrap_or(1);
+
+        let a0 = args[0].to_array(len)?;
+        let a1 = args[1].to_array(len)?;
+
+        // Spark Catalyst: def this(s, r) = this(s, r, Literal(1))
+        // When idx is omitted, default to group index 1.
+        let a2 = if args.len() == 3 {
+            args[2].to_array(len)?
+        } else {
+            // Default idx = 1, matching Spark's behavior
+            Arc::new(Int32Array::from(vec![1; len])) as ArrayRef
+        };
+
+        let out: ArrayRef = regexp_extract(&[a0, a1, a2])?;
+        Ok(ColumnarValue::Array(out))
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+/// Extraction kernel; also callable directly by tests and external callers.
+pub fn regexp_extract(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 3 {

Review Comment:
Here's a small omission (either 2 or 3, not exactly 3). If there's a desire to always have only 3 I can change it everywhere, but that would make it diverge slightly from Spark, where the group idx is optional and defaults to 1 when not specified.
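Something like the following (untested sketch) is what I have in mind: match on the slice arity and synthesize the default `idx` column inside the helper, the same way `invoke_with_args` already does. The extraction loop itself, which the diff cuts off above, is left as a `todo!` placeholder rather than a guess at the rest of the PR:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use datafusion_common::{exec_err, Result};

/// Accept either 2 or 3 arrays; when `idx` is omitted, synthesize an
/// Int32 column of 1s, mirroring Spark's `Literal(1)` default.
pub fn regexp_extract(args: &[ArrayRef]) -> Result<ArrayRef> {
    let (strings, patterns, indices): (ArrayRef, ArrayRef, ArrayRef) = match args {
        [s, p] => {
            // Default group index = 1 (first capture group).
            let idx = Arc::new(Int32Array::from(vec![1; s.len()])) as ArrayRef;
            (Arc::clone(s), Arc::clone(p), idx)
        }
        [s, p, i] => (Arc::clone(s), Arc::clone(p), Arc::clone(i)),
        _ => return exec_err!("regexp_extract expects 2 or 3 arguments"),
    };

    // Placeholder: the per-row extraction over (strings, patterns, indices)
    // would continue here exactly as in the rest of this PR.
    let _ = (strings, patterns, indices);
    todo!("per-row extraction as implemented later in this file")
}
```

That keeps the Spark `Literal(1)` default in one place and lets tests exercise the two-array form through the same public helper.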
