This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e46924d80 feat: add `arrow_cast` function to support supports
arbitrary arrow types (#5166)
e46924d80 is described below
commit e46924d80fddbed0faf35edc85b3bba6f050b344
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Mar 8 23:32:59 2023 +0100
feat: add `arrow_cast` function to support supports arbitrary arrow types
(#5166)
* Add `arrow_cast` function
* prettier
* Update datafusion/sql/src/expr/arrow_cast.rs
Co-authored-by: Wei-Ting Kuo <[email protected]>
* Apply suggestions from code review
Co-authored-by: Wei-Ting Kuo <[email protected]>
* Clarify intent of tests
* Add more error tests
* More tests
* fix test
* reuse buffer to avoid an allocation per word
* add ticket link
* allow trailing whitespace, add tests for whitespace
---------
Co-authored-by: Wei-Ting Kuo <[email protected]>
---
.../sqllogictests/test_files/arrow_typeof.slt | 267 +++++++-
datafusion/proto/src/logical_plan/mod.rs | 5 +-
datafusion/sql/src/expr/arrow_cast.rs | 719 +++++++++++++++++++++
datafusion/sql/src/expr/function.rs | 39 +-
datafusion/sql/src/expr/mod.rs | 1 +
datafusion/sql/src/lib.rs | 1 +
datafusion/sql/tests/integration_test.rs | 11 +-
docs/source/user-guide/sql/data_types.md | 66 +-
8 files changed, 1056 insertions(+), 53 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
index 8f1c00651..fee24740a 100644
--- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
@@ -52,31 +52,242 @@ SELECT arrow_typeof(1.0::float)
Float32
# arrow_typeof_decimal
-# query T
-# SELECT arrow_typeof(1::Decimal)
-# ----
-# Decimal128(38, 10)
-
-# # arrow_typeof_timestamp
-# query T
-# SELECT arrow_typeof(now()::timestamp)
-# ----
-# Timestamp(Nanosecond, None)
-
-# # arrow_typeof_timestamp_utc
-# query T
-# SELECT arrow_typeof(now())
-# ----
-# Timestamp(Nanosecond, Some(\"+00:00\"))
-
-# # arrow_typeof_timestamp_date32(
-# query T
-# SELECT arrow_typeof(now()::date)
-# ----
-# Date32
-
-# # arrow_typeof_utf8
-# query T
-# SELECT arrow_typeof('1')
-# ----
-# Utf8
+query T
+SELECT arrow_typeof(1::Decimal)
+----
+Decimal128(38, 10)
+
+# arrow_typeof_timestamp
+query T
+SELECT arrow_typeof(now()::timestamp)
+----
+Timestamp(Nanosecond, None)
+
+# arrow_typeof_timestamp_utc
+query T
+SELECT arrow_typeof(now())
+----
+Timestamp(Nanosecond, Some("+00:00"))
+
+# arrow_typeof_timestamp_date32(
+query T
+SELECT arrow_typeof(now()::date)
+----
+Date32
+
+# arrow_typeof_utf8
+query T
+SELECT arrow_typeof('1')
+----
+Utf8
+
+
+#### arrow_cast (in some ways opposite of arrow_typeof)
+
+# Basic tests
+
+query I
+SELECT arrow_cast('1', 'Int16')
+----
+1
+
+# Basic error test
+query error Error during planning: arrow_cast needs 2 arguments, 1 provided
+SELECT arrow_cast('1')
+
+query error Error during planning: arrow_cast requires its second argument to
be a constant string, got Int64\(43\)
+SELECT arrow_cast('1', 43)
+
+query error Error unrecognized word: unknown
+SELECT arrow_cast('1', 'unknown')
+
+# Round Trip tests:
+query TTTTTTTTTTTTTTTTTTT
+SELECT
+ arrow_typeof(arrow_cast(1, 'Int8')) as col_i8,
+ arrow_typeof(arrow_cast(1, 'Int16')) as col_i16,
+ arrow_typeof(arrow_cast(1, 'Int32')) as col_i32,
+ arrow_typeof(arrow_cast(1, 'Int64')) as col_i64,
+ arrow_typeof(arrow_cast(1, 'UInt8')) as col_u8,
+ arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16,
+ arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32,
+ arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64,
+ -- can't seem to cast to Float16 for some reason
+ -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
+ arrow_typeof(arrow_cast(1, 'Float32')) as col_f32,
+ arrow_typeof(arrow_cast(1, 'Float64')) as col_f64,
+ arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8,
+ arrow_typeof(arrow_cast('foo', 'LargeUtf8')) as col_large_utf8,
+ arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary,
+ arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary,
+ arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Second, None)')) as col_ts_s,
+ arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Millisecond, None)')) as col_ts_ms,
+ arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Microsecond, None)')) as col_ts_us,
+ arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Nanosecond, None)')) as col_ts_ns,
+ arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict
+----
+Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8
LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond,
None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None)
Dictionary(Int32, Utf8)
+
+
+
+
+## Basic Types: Create a table
+
+statement ok
+create table foo as select
+ arrow_cast(1, 'Int8') as col_i8,
+ arrow_cast(1, 'Int16') as col_i16,
+ arrow_cast(1, 'Int32') as col_i32,
+ arrow_cast(1, 'Int64') as col_i64,
+ arrow_cast(1, 'UInt8') as col_u8,
+ arrow_cast(1, 'UInt16') as col_u16,
+ arrow_cast(1, 'UInt32') as col_u32,
+ arrow_cast(1, 'UInt64') as col_u64,
+ -- can't seem to cast to Float16 for some reason
+ -- arrow_cast(1.0, 'Float16') as col_f16,
+ arrow_cast(1.0, 'Float32') as col_f32,
+ arrow_cast(1.0, 'Float64') as col_f64
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTTTTTTTT
+SELECT
+ arrow_typeof(col_i8),
+ arrow_typeof(col_i16),
+ arrow_typeof(col_i32),
+ arrow_typeof(col_i64),
+ arrow_typeof(col_u8),
+ arrow_typeof(col_u16),
+ arrow_typeof(col_u32),
+ arrow_typeof(col_u64),
+ -- arrow_typeof(col_f16),
+ arrow_typeof(col_f32),
+ arrow_typeof(col_f64)
+ FROM foo;
+----
+Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64
+
+
+statement ok
+drop table foo
+
+## Decimals: Create a table
+
+statement ok
+create table foo as select
+ arrow_cast(100, 'Decimal128(3,2)') as col_d128
+ -- Can't make a Decimal256:
+ -- This feature is not implemented: Can't create a scalar from array of type
"Decimal256(3, 2)"
+ --arrow_cast(100, 'Decimal256(3,2)') as col_d256
+;
+
+
+## Ensure each column in the table has the expected type
+
+query T
+SELECT
+ arrow_typeof(col_d128)
+ -- arrow_typeof(col_d256),
+ FROM foo;
+----
+Decimal128(3, 2)
+
+
+statement ok
+drop table foo
+
+## Strings, Binary: Create a table
+
+statement ok
+create table foo as select
+ arrow_cast('foo', 'Utf8') as col_utf8,
+ arrow_cast('foo', 'LargeUtf8') as col_large_utf8,
+ arrow_cast('foo', 'Binary') as col_binary,
+ arrow_cast('foo', 'LargeBinary') as col_large_binary
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTT
+SELECT
+ arrow_typeof(col_utf8),
+ arrow_typeof(col_large_utf8),
+ arrow_typeof(col_binary),
+ arrow_typeof(col_large_binary)
+ FROM foo;
+----
+Utf8 LargeUtf8 Binary LargeBinary
+
+
+statement ok
+drop table foo
+
+
+## Timestamps: Create a table
+
+statement ok
+create table foo as select
+ arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Second, None)') as col_ts_s,
+ arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Millisecond, None)') as col_ts_ms,
+ arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Microsecond, None)') as col_ts_us,
+ arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'),
'Timestamp(Nanosecond, None)') as col_ts_ns
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTT
+SELECT
+ arrow_typeof(col_ts_s),
+ arrow_typeof(col_ts_ms),
+ arrow_typeof(col_ts_us),
+ arrow_typeof(col_ts_ns)
+ FROM foo;
+----
+Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond,
None) Timestamp(Nanosecond, None)
+
+
+statement ok
+drop table foo
+
+## Dictionaries
+
+statement ok
+create table foo as select
+ arrow_cast('foo', 'Dictionary(Int32, Utf8)') as col_dict_int32_utf8,
+ arrow_cast('foo', 'Dictionary(Int8, LargeUtf8)') as col_dict_int8_largeutf8
+;
+
+## Ensure each column in the table has the expected type
+
+query TT
+SELECT
+ arrow_typeof(col_dict_int32_utf8),
+ arrow_typeof(col_dict_int8_largeutf8)
+ FROM foo;
+----
+Dictionary(Int32, Utf8) Dictionary(Int8, LargeUtf8)
+
+
+statement ok
+drop table foo
+
+
+## Intervals:
+
+query error Cannot automatically convert Interval\(DayTime\) to
Interval\(MonthDayNano\)
+---
+select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)');
+
+query error DataFusion error: Error during planning: Cannot automatically
convert Utf8 to Interval\(MonthDayNano\)
+select arrow_cast('30 minutes', 'Interval(MonthDayNano)');
+
+
+## Duration
+
+query error Cannot automatically convert Interval\(DayTime\) to
Duration\(Second\)
+---
+select arrow_cast(interval '30 minutes', 'Duration(Second)');
+
+query error DataFusion error: Error during planning: Cannot automatically
convert Utf8 to Duration\(Second\)
+select arrow_cast('30 minutes', 'Duration(Second)');
diff --git a/datafusion/proto/src/logical_plan/mod.rs
b/datafusion/proto/src/logical_plan/mod.rs
index 706128259..802242b3e 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -2120,8 +2120,11 @@ mod roundtrip_tests {
DataType::Float16,
DataType::Float32,
DataType::Float64,
- // Add more timestamp tests
+ DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
+ DataType::Timestamp(TimeUnit::Microsecond, None),
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
DataType::Date32,
DataType::Date64,
DataType::Time32(TimeUnit::Second),
diff --git a/datafusion/sql/src/expr/arrow_cast.rs
b/datafusion/sql/src/expr/arrow_cast.rs
new file mode 100644
index 000000000..bc1313e2c
--- /dev/null
+++ b/datafusion/sql/src/expr/arrow_cast.rs
@@ -0,0 +1,719 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Implementation of the `arrow_cast` function that allows
+//! casting to arbitrary arrow types (rather than SQL types)
+
+use std::{fmt::Display, iter::Peekable, str::Chars};
+
+use arrow_schema::{DataType, IntervalUnit, TimeUnit};
+use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
+
+use datafusion_expr::{Expr, ExprSchemable};
+
+pub const ARROW_CAST_NAME: &str = "arrow_cast";
+
+/// Create an [`Expr`] that evaluates the `arrow_cast` function
+///
+/// This function is not a [`BuiltInScalarFunction`] because the
+/// return type of [`BuiltInScalarFunction`] depends only on the
+/// *types* of the arguments. However, the type of `arrow_cast` depends on
+/// the *value* of its second argument.
+///
+/// Use the `cast` function to cast to SQL type (which is then mapped
+/// to the corresponding arrow type). For example to cast to `int`
+/// (which is then mapped to the arrow type `Int32`)
+///
+/// ```sql
+/// select cast(column_x as int) ...
+/// ```
+///
+/// Use the `arrow_cast` function to cast to a specific arrow type
+///
+/// For example
+/// ```sql
+/// select arrow_cast(column_x, 'Float64')
+/// ```
+pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) ->
Result<Expr> {
+ if args.len() != 2 {
+ return Err(DataFusionError::Plan(format!(
+ "arrow_cast needs 2 arguments, {} provided",
+ args.len()
+ )));
+ }
+ let arg1 = args.pop().unwrap();
+ let arg0 = args.pop().unwrap();
+
+ // arg1 must be a string
+ let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) =
arg1 {
+ v
+ } else {
+ return Err(DataFusionError::Plan(format!(
+ "arrow_cast requires its second argument to be a constant string,
got {arg1}"
+ )));
+ };
+
+ // do the actual lookup to the appropriate data type
+ let data_type = parse_data_type(&data_type_string)?;
+
+ arg0.cast_to(&data_type, schema)
+}
+
+/// Parses `str` into a `DataType`.
+///
+/// `parse_data_type` is the reverse of [`DataType`]'s `Display`
+/// impl, and maintains the invariant that
+/// `parse_data_type(data_type.to_string()) == data_type`
+///
+/// Example:
+/// ```
+/// # use datafusion_sql::parse_data_type;
+/// # use arrow_schema::DataType;
+/// let display_value = "Int32";
+///
+/// // "Int32" is the Display value of `DataType`
+/// assert_eq!(display_value, &format!("{}", DataType::Int32));
+///
+/// // parse_data_type converts "Int32" back to `DataType`:
+/// let data_type = parse_data_type(display_value).unwrap();
+/// assert_eq!(data_type, DataType::Int32);
+/// ```
+///
+/// Remove if added to arrow: https://github.com/apache/arrow-rs/issues/3821
+pub fn parse_data_type(val: &str) -> Result<DataType> {
+ Parser::new(val).parse()
+}
+
+fn make_error(val: &str, msg: &str) -> DataFusionError {
+ DataFusionError::Plan(
+ format!("Unsupported type '{val}'. Must be a supported arrow type name
such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" )
+ )
+}
+
+fn make_error_expected(val: &str, expected: &Token, actual: &Token) ->
DataFusionError {
+ make_error(val, &format!("Expected '{expected}', got '{actual}'"))
+}
+
+#[derive(Debug)]
+/// Implementation of `parse_data_type`, modeled after
<https://github.com/sqlparser-rs/sqlparser-rs>
+struct Parser<'a> {
+ val: &'a str,
+ tokenizer: Tokenizer<'a>,
+}
+
+impl<'a> Parser<'a> {
+ fn new(val: &'a str) -> Self {
+ Self {
+ val,
+ tokenizer: Tokenizer::new(val),
+ }
+ }
+
+ fn parse(mut self) -> Result<DataType> {
+ let data_type = self.parse_next_type()?;
+ // ensure that there is no trailing content
+ if self.tokenizer.next().is_some() {
+ return Err(make_error(
+ self.val,
+ &format!("checking trailing content after parsing
'{data_type}'"),
+ ));
+ } else {
+ Ok(data_type)
+ }
+ }
+
+ /// parses the next full DataType
+ fn parse_next_type(&mut self) -> Result<DataType> {
+ match self.next_token()? {
+ Token::SimpleType(data_type) => Ok(data_type),
+ Token::Timestamp => self.parse_timestamp(),
+ Token::Time32 => self.parse_time32(),
+ Token::Time64 => self.parse_time64(),
+ Token::Duration => self.parse_duration(),
+ Token::Interval => self.parse_interval(),
+ Token::FixedSizeBinary => self.parse_fixed_size_binary(),
+ Token::Decimal128 => self.parse_decimal_128(),
+ Token::Decimal256 => self.parse_decimal_256(),
+ Token::Dictionary => self.parse_dictionary(),
+ tok => Err(make_error(
+ self.val,
+ &format!("finding next type, got unexpected '{tok}'"),
+ )),
+ }
+ }
+
+ /// Parses the next timeunit
+ fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
+ match self.next_token()? {
+ Token::TimeUnit(time_unit) => Ok(time_unit),
+ tok => Err(make_error(
+ self.val,
+ &format!("finding TimeUnit for {context}, got {tok}"),
+ )),
+ }
+ }
+
+ /// Parses the next integer value
+ fn parse_i64(&mut self, context: &str) -> Result<i64> {
+ match self.next_token()? {
+ Token::Integer(v) => Ok(v),
+ tok => Err(make_error(
+ self.val,
+ &format!("finding i64 for {context}, got '{tok}'"),
+ )),
+ }
+ }
+
+ /// Parses the next i32 integer value
+ fn parse_i32(&mut self, context: &str) -> Result<i32> {
+ let length = self.parse_i64(context)?;
+ length.try_into().map_err(|e| {
+ make_error(
+ self.val,
+ &format!("converting {length} into i32 for {context}: {e}"),
+ )
+ })
+ }
+
+ /// Parses the next i8 integer value
+ fn parse_i8(&mut self, context: &str) -> Result<i8> {
+ let length = self.parse_i64(context)?;
+ length.try_into().map_err(|e| {
+ make_error(
+ self.val,
+ &format!("converting {length} into i8 for {context}: {e}"),
+ )
+ })
+ }
+
+ /// Parses the next u8 integer value
+ fn parse_u8(&mut self, context: &str) -> Result<u8> {
+ let length = self.parse_i64(context)?;
+ length.try_into().map_err(|e| {
+ make_error(
+ self.val,
+ &format!("converting {length} into u8 for {context}: {e}"),
+ )
+ })
+ }
+
+ /// Parses the next timestamp (called after `Timestamp` has been consumed)
+ fn parse_timestamp(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let time_unit = self.parse_time_unit("Timestamp")?;
+ self.expect_token(Token::Comma)?;
+ // TODO Support timezones other than None
+ self.expect_token(Token::None)?;
+ let timezone = None;
+
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Timestamp(time_unit, timezone))
+ }
+
+ /// Parses the next Time32 (called after `Time32` has been consumed)
+ fn parse_time32(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let time_unit = self.parse_time_unit("Time32")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Time32(time_unit))
+ }
+
+ /// Parses the next Time64 (called after `Time64` has been consumed)
+ fn parse_time64(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let time_unit = self.parse_time_unit("Time64")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Time64(time_unit))
+ }
+
+ /// Parses the next Duration (called after `Duration` has been consumed)
+ fn parse_duration(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let time_unit = self.parse_time_unit("Duration")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Duration(time_unit))
+ }
+
+ /// Parses the next Interval (called after `Interval` has been consumed)
+ fn parse_interval(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let interval_unit = match self.next_token()? {
+ Token::IntervalUnit(interval_unit) => interval_unit,
+ tok => {
+ return Err(make_error(
+ self.val,
+ &format!("finding IntervalUnit for Interval, got {tok}"),
+ ))
+ }
+ };
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Interval(interval_unit))
+ }
+
+ /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has
been consumed)
+ fn parse_fixed_size_binary(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let length = self.parse_i32("FixedSizeBinary")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::FixedSizeBinary(length))
+ }
+
+ /// Parses the next Decimal128 (called after `Decimal128` has been
consumed)
+ fn parse_decimal_128(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let precision = self.parse_u8("Decimal128")?;
+ self.expect_token(Token::Comma)?;
+ let scale = self.parse_i8("Decimal128")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Decimal128(precision, scale))
+ }
+
+ /// Parses the next Decimal256 (called after `Decimal256` has been
consumed)
+ fn parse_decimal_256(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let precision = self.parse_u8("Decimal256")?;
+ self.expect_token(Token::Comma)?;
+ let scale = self.parse_i8("Decimal256")?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Decimal256(precision, scale))
+ }
+
+ /// Parses the next Dictionary (called after `Dictionary` has been
consumed)
+ fn parse_dictionary(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let key_type = self.parse_next_type()?;
+ self.expect_token(Token::Comma)?;
+ let value_type = self.parse_next_type()?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::Dictionary(
+ Box::new(key_type),
+ Box::new(value_type),
+ ))
+ }
+
+ /// return the next token, or an error if there are none left
+ fn next_token(&mut self) -> Result<Token> {
+ match self.tokenizer.next() {
+ None => Err(make_error(self.val, "finding next token")),
+ Some(token) => token,
+ }
+ }
+
+ /// consume the next token, returning Ok(()) if it matches tok, and Err if
not
+ fn expect_token(&mut self, tok: Token) -> Result<()> {
+ let next_token = self.next_token()?;
+ if next_token == tok {
+ Ok(())
+ } else {
+ Err(make_error_expected(self.val, &tok, &next_token))
+ }
+ }
+}
+
+/// returns true if this character is a separator
+fn is_separator(c: char) -> bool {
+ c == '(' || c == ')' || c == ',' || c == ' '
+}
+
+#[derive(Debug)]
+/// Splits a string like Dictionary(Int32, Int64) into tokens suitable for
parsing
+///
+/// For example the string "Timestamp(Nanosecond, None)" would be parsed into:
+///
+/// * Token::Timestamp
+/// * Token::LParen
+/// * Token::TimeUnit(TimeUnit::Nanosecond)
+/// * Token::Comma,
+/// * Token::None,
+/// * Token::RParen,
+struct Tokenizer<'a> {
+ val: &'a str,
+ chars: Peekable<Chars<'a>>,
+ // temporary buffer for parsing words
+ word: String,
+}
+
+impl<'a> Tokenizer<'a> {
+ fn new(val: &'a str) -> Self {
+ Self {
+ val,
+ chars: val.chars().peekable(),
+ word: String::new(),
+ }
+ }
+
+ /// returns the next char, without consuming it
+ fn peek_next_char(&mut self) -> Option<char> {
+ self.chars.peek().copied()
+ }
+
+ /// returns the next char, consuming it
+ fn next_char(&mut self) -> Option<char> {
+ self.chars.next()
+ }
+
+ /// parse the characters in val starting at the current position, until
+ /// the next separator (`,`, `(`, `)` or a space) or the end of the input
+ fn parse_word(&mut self) -> Result<Token> {
+ // reset temp space
+ self.word.clear();
+ loop {
+ match self.peek_next_char() {
+ None => break,
+ Some(c) if is_separator(c) => break,
+ Some(c) => {
+ self.next_char();
+ self.word.push(c);
+ }
+ }
+ }
+
+ // if it starts with `-` or a numeric character, try parsing it as an integer
+ if let Some(c) = self.word.chars().next() {
+ if c == '-' || c.is_numeric() {
+ let val: i64 = self.word.parse().map_err(|e| {
+ make_error(
+ self.val,
+ &format!("parsing {} as integer: {e}", self.word),
+ )
+ })?;
+ return Ok(Token::Integer(val));
+ }
+ }
+
+ // figure out what the word was
+ let token = match self.word.as_str() {
+ "Null" => Token::SimpleType(DataType::Null),
+ "Boolean" => Token::SimpleType(DataType::Boolean),
+
+ "Int8" => Token::SimpleType(DataType::Int8),
+ "Int16" => Token::SimpleType(DataType::Int16),
+ "Int32" => Token::SimpleType(DataType::Int32),
+ "Int64" => Token::SimpleType(DataType::Int64),
+
+ "UInt8" => Token::SimpleType(DataType::UInt8),
+ "UInt16" => Token::SimpleType(DataType::UInt16),
+ "UInt32" => Token::SimpleType(DataType::UInt32),
+ "UInt64" => Token::SimpleType(DataType::UInt64),
+
+ "Utf8" => Token::SimpleType(DataType::Utf8),
+ "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8),
+ "Binary" => Token::SimpleType(DataType::Binary),
+ "LargeBinary" => Token::SimpleType(DataType::LargeBinary),
+
+ "Float16" => Token::SimpleType(DataType::Float16),
+ "Float32" => Token::SimpleType(DataType::Float32),
+ "Float64" => Token::SimpleType(DataType::Float64),
+
+ "Date32" => Token::SimpleType(DataType::Date32),
+ "Date64" => Token::SimpleType(DataType::Date64),
+
+ "Second" => Token::TimeUnit(TimeUnit::Second),
+ "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
+ "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
+ "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond),
+
+ "Timestamp" => Token::Timestamp,
+ "Time32" => Token::Time32,
+ "Time64" => Token::Time64,
+ "Duration" => Token::Duration,
+ "Interval" => Token::Interval,
+ "Dictionary" => Token::Dictionary,
+
+ "FixedSizeBinary" => Token::FixedSizeBinary,
+ "Decimal128" => Token::Decimal128,
+ "Decimal256" => Token::Decimal256,
+
+ "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth),
+ "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime),
+ "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano),
+
+ "None" => Token::None,
+
+ _ => {
+ return Err(make_error(
+ self.val,
+ &format!("unrecognized word: {}", self.word),
+ ))
+ }
+ };
+ Ok(token)
+ }
+}
+
+impl<'a> Iterator for Tokenizer<'a> {
+ type Item = Result<Token>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ loop {
+ match self.peek_next_char()? {
+ ' ' => {
+ // skip whitespace
+ self.next_char();
+ continue;
+ }
+ '(' => {
+ self.next_char();
+ return Some(Ok(Token::LParen));
+ }
+ ')' => {
+ self.next_char();
+ return Some(Ok(Token::RParen));
+ }
+ ',' => {
+ self.next_char();
+ return Some(Ok(Token::Comma));
+ }
+ _ => return Some(self.parse_word()),
+ }
+ }
+ }
+}
+
+/// Tokens produced by `Tokenizer` and consumed by `Parser` when parsing
+/// the string form of a `DataType`
+#[derive(Debug, PartialEq)]
+enum Token {
+ // Null, or Int32
+ SimpleType(DataType),
+ Timestamp,
+ Time32,
+ Time64,
+ Duration,
+ Interval,
+ FixedSizeBinary,
+ Decimal128,
+ Decimal256,
+ Dictionary,
+ TimeUnit(TimeUnit),
+ IntervalUnit(IntervalUnit),
+ LParen,
+ RParen,
+ Comma,
+ None,
+ Integer(i64),
+}
+
+impl Display for Token {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Token::SimpleType(t) => write!(f, "{t}"),
+ Token::Timestamp => write!(f, "Timestamp"),
+ Token::Time32 => write!(f, "Time32"),
+ Token::Time64 => write!(f, "Time64"),
+ Token::Duration => write!(f, "Duration"),
+ Token::Interval => write!(f, "Interval"),
+ Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"),
+ Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"),
+ Token::LParen => write!(f, "("),
+ Token::RParen => write!(f, ")"),
+ Token::Comma => write!(f, ","),
+ Token::None => write!(f, "None"),
+ Token::FixedSizeBinary => write!(f, "FixedSizeBinary"),
+ Token::Decimal128 => write!(f, "Decimal128"),
+ Token::Decimal256 => write!(f, "Decimal256"),
+ Token::Dictionary => write!(f, "Dictionary"),
+ Token::Integer(v) => write!(f, "Integer({v})"),
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use arrow_schema::{IntervalUnit, TimeUnit};
+
+ use super::*;
+
+ #[test]
+ fn test_parse_data_type() {
+ // this ensures types can be parsed correctly from their string
representations
+ for dt in list_datatypes() {
+ round_trip(dt)
+ }
+ }
+
+ /// convert data_type to a string, and then parse it as a type
+ /// verifying it is the same
+ fn round_trip(data_type: DataType) {
+ let data_type_string = data_type.to_string();
+ println!("Input '{data_type_string}' ({data_type:?})");
+ let parsed_type = parse_data_type(&data_type_string).unwrap();
+ assert_eq!(
+ data_type, parsed_type,
+ "Mismatch parsing {data_type_string}"
+ );
+ }
+
+ fn list_datatypes() -> Vec<DataType> {
+ vec![
+ // ---------
+ // Non Nested types
+ // ---------
+ DataType::Null,
+ DataType::Boolean,
+ DataType::Int8,
+ DataType::Int16,
+ DataType::Int32,
+ DataType::Int64,
+ DataType::UInt8,
+ DataType::UInt16,
+ DataType::UInt32,
+ DataType::UInt64,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::Float64,
+ DataType::Timestamp(TimeUnit::Second, None),
+ DataType::Timestamp(TimeUnit::Millisecond, None),
+ DataType::Timestamp(TimeUnit::Microsecond, None),
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ // TODO support timezones
+ //DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
+ DataType::Date32,
+ DataType::Date64,
+ DataType::Time32(TimeUnit::Second),
+ DataType::Time32(TimeUnit::Millisecond),
+ DataType::Time32(TimeUnit::Microsecond),
+ DataType::Time32(TimeUnit::Nanosecond),
+ DataType::Time64(TimeUnit::Second),
+ DataType::Time64(TimeUnit::Millisecond),
+ DataType::Time64(TimeUnit::Microsecond),
+ DataType::Time64(TimeUnit::Nanosecond),
+ DataType::Duration(TimeUnit::Second),
+ DataType::Duration(TimeUnit::Millisecond),
+ DataType::Duration(TimeUnit::Microsecond),
+ DataType::Duration(TimeUnit::Nanosecond),
+ DataType::Interval(IntervalUnit::YearMonth),
+ DataType::Interval(IntervalUnit::DayTime),
+ DataType::Interval(IntervalUnit::MonthDayNano),
+ DataType::Binary,
+ DataType::FixedSizeBinary(0),
+ DataType::FixedSizeBinary(1234),
+ DataType::FixedSizeBinary(-432),
+ DataType::LargeBinary,
+ DataType::Utf8,
+ DataType::LargeUtf8,
+ DataType::Decimal128(7, 12),
+ DataType::Decimal256(6, 13),
+ // ---------
+ // Nested types
+ // ---------
+ DataType::Dictionary(Box::new(DataType::Int32),
Box::new(DataType::Utf8)),
+ DataType::Dictionary(Box::new(DataType::Int8),
Box::new(DataType::Utf8)),
+ DataType::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)),
+ ),
+ DataType::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(DataType::FixedSizeBinary(23)),
+ ),
+ DataType::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(
+ // nested dictionaries are probably a bad idea but they
are possible
+ DataType::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(DataType::Utf8),
+ ),
+ ),
+ ),
+ // TODO support more structured types (List, LargeList, Struct,
Union, Map, RunEndEncoded, etc)
+ ]
+ }
+
+ #[test]
+ fn test_parse_data_type_whitespace_tolerance() {
+ // (string to parse, expected DataType)
+ let cases = [
+ ("Int8", DataType::Int8),
+ (
+ "Timestamp (Nanosecond, None)",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ),
+ (
+ "Timestamp (Nanosecond, None) ",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ),
+ (
+ " Timestamp (Nanosecond, None
)",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ),
+ (
+ "Timestamp (Nanosecond, None ) ",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ),
+ ];
+
+ for (data_type_string, expected_data_type) in cases {
+ println!("Parsing '{data_type_string}', expecting
'{expected_data_type:?}'");
+ let parsed_data_type = parse_data_type(data_type_string).unwrap();
+ assert_eq!(parsed_data_type, expected_data_type);
+ }
+ }
+
+ #[test]
+ fn parse_data_type_errors() {
+ // (string to parse, expected error message)
+ let cases = [
+ ("", "Unsupported type ''"),
+ ("", "Error finding next token"),
+ ("null", "Unsupported type 'null'"),
+ ("Nu", "Unsupported type 'Nu'"),
+ // TODO support timezones
+ (
+ r#"Timestamp(Nanosecond, Some("UTC"))"#,
+ "Error unrecognized word: Some",
+ ),
+ ("Timestamp(Nanosecond, ", "Error finding next token"),
+ (
+ "Float32 Float32",
+ "trailing content after parsing 'Float32'",
+ ),
+ ("Int32, ", "trailing content after parsing 'Int32'"),
+ ("Int32(3), ", "trailing content after parsing 'Int32'"),
+ ("FixedSizeBinary(Int32), ", "Error finding i64 for
FixedSizeBinary, got 'Int32'"),
+ ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid
digit found in string"),
+ // too large for i32
+ ("FixedSizeBinary(4000000000), ", "Error converting 4000000000
into i32 for FixedSizeBinary: out of range integral type conversion attempted"),
+ // can't have negative precision
+ ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128:
out of range integral type conversion attempted"),
+ ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256:
out of range integral type conversion attempted"),
+ ("Decimal128(3, 500)", "Error converting 500 into i8 for
Decimal128: out of range integral type conversion attempted"),
+ ("Decimal256(3, 500)", "Error converting 500 into i8 for
Decimal256: out of range integral type conversion attempted"),
+
+ ];
+
+ for (data_type_string, expected_message) in cases {
+ print!("Parsing '{data_type_string}', expecting
'{expected_message}'");
+ match parse_data_type(data_type_string) {
+ Ok(d) => panic!(
+ "Expected error while parsing '{data_type_string}', but
got '{d}'"
+ ),
+ Err(e) => {
+ let message = e.to_string();
+ assert!(
+ message.contains(expected_message),
+ "\n\ndid not find expected in actual.\n\nexpected:
{expected_message}\nactual:{message}\n"
+ );
+ // errors should also contain a help message
+ assert!(message.contains("Must be a supported arrow type
name such as 'Int32' or 'Timestamp(Nanosecond, None)'"));
+ }
+ }
+ println!(" Ok");
+ }
+ }
+}
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index c5f23213a..68a5df054 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -29,6 +29,8 @@ use sqlparser::ast::{
};
use std::str::FromStr;
+use super::arrow_cast::ARROW_CAST_NAME;
+
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(super) fn sql_function_to_expr(
&self,
@@ -110,24 +112,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
};
// finally, user-defined functions (UDF) and UDAF
- match self.schema_provider.get_function_meta(&name) {
- Some(fm) => {
- let args = self.function_args_to_expr(function.args, schema)?;
+ if let Some(fm) = self.schema_provider.get_function_meta(&name) {
+ let args = self.function_args_to_expr(function.args, schema)?;
+ return Ok(Expr::ScalarUDF { fun: fm, args });
+ }
- Ok(Expr::ScalarUDF { fun: fm, args })
- }
- None => match self.schema_provider.get_aggregate_meta(&name) {
- Some(fm) => {
- let args = self.function_args_to_expr(function.args,
schema)?;
- Ok(Expr::AggregateUDF {
- fun: fm,
- args,
- filter: None,
- })
- }
- _ => Err(DataFusionError::Plan(format!("Invalid function
'{name}'"))),
- },
+ // User defined aggregate functions
+ if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
+ let args = self.function_args_to_expr(function.args, schema)?;
+ return Ok(Expr::AggregateUDF {
+ fun: fm,
+ args,
+ filter: None,
+ });
}
+
+ // Special case arrow_cast (as its type is dependent on its argument
value)
+ if name == ARROW_CAST_NAME {
+ let args = self.function_args_to_expr(function.args, schema)?;
+ return super::arrow_cast::create_arrow_cast(args, schema);
+ }
+
+ // Could not find the relevant function, so return an error
+ Err(DataFusionError::Plan(format!("Invalid function '{name}'")))
}
pub(super) fn sql_named_function_to_expr(
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index f22692451..ad05fbcc1 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+pub(crate) mod arrow_cast;
mod binary_op;
mod function;
mod grouping_set;
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index efe239f45..c0c1a4ac9 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -30,4 +30,5 @@ pub mod utils;
mod values;
pub use datafusion_common::{ResolvedTableReference, TableReference};
+pub use expr::arrow_cast::parse_data_type;
pub use sqlparser;
diff --git a/datafusion/sql/tests/integration_test.rs
b/datafusion/sql/tests/integration_test.rs
index 71f5bf05e..660959907 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -2311,6 +2311,15 @@ fn approx_median_window() {
quick_test(sql, expected);
}
+#[test]
+fn select_arrow_cast() {
+ let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo',
'LargeUtf8')";
+ let expected = "\
+ Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\
+ \n EmptyRelation";
+ quick_test(sql, expected);
+}
+
#[test]
fn select_typed_date_string() {
let sql = "SELECT date '2020-12-10' AS date";
@@ -2534,7 +2543,7 @@ impl ContextProvider for MockContextProvider {
}
fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
- unimplemented!()
+ None
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
diff --git a/docs/source/user-guide/sql/data_types.md
b/docs/source/user-guide/sql/data_types.md
index 968dcda53..9f0ca8f89 100644
--- a/docs/source/user-guide/sql/data_types.md
+++ b/docs/source/user-guide/sql/data_types.md
@@ -37,6 +37,18 @@ the `arrow_typeof` function. For example:
+-------------------------------------+
```
+You can cast a SQL expression to a specific Arrow type using the `arrow_cast`
function.
+For example, to cast the output of `now()` to a `Timestamp` with second
precision:
+
+```sql
+❯ select arrow_cast(now(), 'Timestamp(Second, None)');
++---------------------+
+| now() |
++---------------------+
+| 2023-03-03T17:19:21 |
++---------------------+
+```
+
## Character Types
| SQL DataType | Arrow DataType |
@@ -65,12 +77,12 @@ the `arrow_typeof` function. For example:
## Date/Time Types
-| SQL DataType | Arrow DataType
|
-| ------------ |
:----------------------------------------------------------------------- |
-| `DATE` | `Date32`
|
-| `TIME` | `Time64(TimeUnit::Nanosecond)`
|
-| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond, None)`
|
-| `INTERVAL` | `Interval(IntervalUnit::YearMonth)` or
`Interval(IntervalUnit::DayTime)` |
+| SQL DataType | Arrow DataType |
+| ------------ | :---------------------------------------------- |
+| `DATE` | `Date32` |
+| `TIME` | `Time64(Nanosecond)` |
+| `TIMESTAMP` | `Timestamp(Nanosecond, None)` |
+| `INTERVAL` | `Interval(YearMonth)` or `Interval(DayTime)` |
## Boolean Types
@@ -84,7 +96,7 @@ the `arrow_typeof` function. For example:
| ------------ | :------------- |
| `BYTEA` | `Binary` |
-## Unsupported Types
+## Unsupported SQL Types
| SQL Data Type | Arrow DataType |
| ------------- | :------------------ |
@@ -100,3 +112,43 @@ the `arrow_typeof` function. For example:
| `ENUM` | _Not yet supported_ |
| `SET` | _Not yet supported_ |
| `DATETIME` | _Not yet supported_ |
+
+## Supported Arrow Types
+
+The following types are supported by the `arrow_cast` function:
+
+| Arrow Type |
+| ----------------------------------------------------------- |
+| `Null` |
+| `Boolean` |
+| `Int8` |
+| `Int16` |
+| `Int32` |
+| `Int64` |
+| `UInt8` |
+| `UInt16` |
+| `UInt32` |
+| `UInt64` |
+| `Float16` |
+| `Float32` |
+| `Float64` |
+| `Utf8` |
+| `LargeUtf8` |
+| `Binary` |
+| `Timestamp(Second, None)` |
+| `Timestamp(Millisecond, None)` |
+| `Timestamp(Microsecond, None)` |
+| `Timestamp(Nanosecond, None)` |
+| `Time32` |
+| `Time64` |
+| `Duration(Second)` |
+| `Duration(Millisecond)` |
+| `Duration(Microsecond)` |
+| `Duration(Nanosecond)` |
+| `Interval(YearMonth)` |
+| `Interval(DayTime)` |
+| `Interval(MonthDayNano)` |
+| `Interval(MonthDayNano)` |
+| `FixedSizeBinary(<len>)` (e.g. `FixedSizeBinary(16)`) |
+| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(3, 10)` |
+| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(3, 10)` |