bmmeijers commented on issue #18121:
URL: https://github.com/apache/datafusion/issues/18121#issuecomment-3414714495
As the pastebin will disappear in 5 days, here is, verbatim, what I placed there...
```rust
use datafusion::arrow::array::{ArrayRef, Int64Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::datasource::memory::MemTable;
use datafusion::logical_expr::Expr;
use datafusion::prelude::SessionContext;
use std::sync::Arc;
/// Table function that expects three `Int64` literal arguments `(a, b, c)`
/// and yields a single-row table with columns `x = a + b` and `y = b * c`.
#[derive(Debug)]
pub struct TransformFunction {}
impl TableFunctionImpl for TransformFunction {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
if exprs.len() != 3 {
return plan_err!(
"Expected exactly three arguments: a, b, c, but got {}",
exprs.len()
);
}
println!("{:?}", exprs);
let extract_int64 = |expr: &Expr, arg_name: &str| -> Result<i64> {
match expr {
Expr::Literal(ScalarValue::Int64(Some(val)), _) => Ok(*val),
// Expr::Column()
_ => plan_err!("Argument {} must be an Int64 literal",
arg_name),
}
};
let a = extract_int64(&exprs[0], "a")?;
let b = extract_int64(&exprs[1], "b")?;
let c = extract_int64(&exprs[2], "c")?;
// Compute output columns: x = a + b, y = b * c
let x = a + b;
let y = b * c;
// Define output schema
let schema = Arc::new(Schema::new(vec![
Field::new("x", DataType::Int64, false),
Field::new("y", DataType::Int64, false),
]));
// Create output arrays
let x_array = Arc::new(Int64Array::from(vec![x])) as ArrayRef;
let y_array = Arc::new(Int64Array::from(vec![y])) as ArrayRef;
// Create a single RecordBatch
let batch = RecordBatch::try_new(schema.clone(), vec![x_array,
y_array])?;
// Wrap in a MemTable
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
Ok(Arc::new(provider))
}
}
// --- Usage Example ---
/// Registers the TransformFunction as a table UDF in the SessionContext.
fn register_udtf(ctx: &mut SessionContext) -> Result<()> {
    // Make the function callable from SQL as `my_transform(...)`.
    ctx.register_udtf("my_transform", Arc::new(TransformFunction {}));
    Ok(())
}
/// Creates a small in-memory table for demonstration.
/// Creates a small in-memory table for demonstration.
///
/// Registers `my_table` with columns (id: Utf8, a/b/c: Int64) holding two
/// rows: ("r1", 10, 5, 2) and ("r2", 20, 6, 3).
fn create_dummy_table(ctx: &mut SessionContext) -> Result<()> {
    let ids: ArrayRef = Arc::new(StringArray::from(vec!["r1", "r2"]));
    let a: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
    let b: ArrayRef = Arc::new(Int64Array::from(vec![5, 6]));
    let c: ArrayRef = Arc::new(Int64Array::from(vec![2, 3]));

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
        Field::new("c", DataType::Int64, false),
    ]));
    let batch = RecordBatch::try_new(schema.clone(), vec![ids, a, b, c])?;

    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("my_table", Arc::new(provider))?;
    Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();

    // Register the custom UDTF and the demo base table.
    register_udtf(&mut ctx)?;
    create_dummy_table(&mut ctx)?;

    // This variant works: the UDTF receives constant literal arguments.
    let sql = r#"
SELECT
t1.id,
t2.x AS a_plus_b,
t2.y AS b_times_c
FROM
my_table AS t1,
LATERAL my_transform(1, 2, 3) AS t2(x, y)
"#;
    // This variant is the behavior under discussion: passing column
    // references from the lateral side is rejected by the literal-only UDTF.
    // let sql = r#"
    // SELECT
    // t1.id,
    // t2.x AS a_plus_b,
    // t2.y AS b_times_c
    // FROM
    // my_table AS t1,
    // LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
    // "#;

    println!("Executing SQL:\n{}", sql);
    let df = ctx.sql(sql).await?;
    println!("\nQuery Result:");
    df.show().await?;
    Ok(())
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]