bmmeijers commented on issue #18121:
URL: https://github.com/apache/datafusion/issues/18121#issuecomment-3414714495

   As the pastebin will disappear in 5 days, here is a verbatim copy of what I placed there:
   
   ```rust
   use datafusion::arrow::array::{ArrayRef, Int64Array, StringArray};
   use datafusion::arrow::datatypes::{DataType, Field, Schema};
   use datafusion::arrow::record_batch::RecordBatch;
   use datafusion::catalog::{TableFunctionImpl, TableProvider};
   use datafusion::common::{Result, ScalarValue, plan_err};
   use datafusion::datasource::memory::MemTable;
   use datafusion::logical_expr::Expr;
   use datafusion::prelude::SessionContext;
   use std::sync::Arc;
   
   /// Stateless table function used by this example.
   ///
   /// It carries no data; all behaviour lives in its `TableFunctionImpl`
   /// implementation.
   #[derive(Debug)]
   pub struct TransformFunction;
   
   impl TableFunctionImpl for TransformFunction {
       /// Builds a one-row, two-column table from three `Int64` literal arguments.
       ///
       /// The arguments are read positionally as `a`, `b`, `c`; the resulting table
       /// has the non-nullable `Int64` columns `x = a + b` and `y = b * c`.
       ///
       /// # Errors
       ///
       /// Returns an error built with `plan_err!` when
       /// * the number of arguments is not exactly three,
       /// * an argument is not a non-null `Int64` literal, or
       /// * `a + b` or `b * c` does not fit in an `i64`.
       ///
       /// Errors from `RecordBatch::try_new` and `MemTable::try_new` are propagated.
       fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
           if exprs.len() != 3 {
               return plan_err!(
                   "Expected exactly three arguments: a, b, c, but got {}",
                   exprs.len()
               );
           }

           // Diagnostic output: shows the expressions exactly as they reach the function.
           println!("{:?}", exprs);

           let extract_int64 = |expr: &Expr, arg_name: &str| -> Result<i64> {
               match expr {
                   Expr::Literal(ScalarValue::Int64(Some(val)), _) => Ok(*val),
                   // Every other expression (e.g. a column reference, a NULL literal
                   // or a literal of another type) is rejected.
                   _ => plan_err!("Argument {} must be an Int64 literal", arg_name),
               }
           };

           let a = extract_int64(&exprs[0], "a")?;
           let b = extract_int64(&exprs[1], "b")?;
           let c = extract_int64(&exprs[2], "c")?;

           // Compute output columns: x = a + b, y = b * c.
           // Checked arithmetic turns an overflow into an error instead of a panic
           // (debug builds) or a silently wrapped value (release builds).
           let x = match a.checked_add(b) {
               Some(sum) => sum,
               None => {
                   return plan_err!("Overflow while computing x = a + b ({} + {})", a, b);
               }
           };
           let y = match b.checked_mul(c) {
               Some(product) => product,
               None => {
                   return plan_err!("Overflow while computing y = b * c ({} * {})", b, c);
               }
           };

           // Define output schema
           let schema = Arc::new(Schema::new(vec![
               Field::new("x", DataType::Int64, false),
               Field::new("y", DataType::Int64, false),
           ]));

           // Create output arrays (one value each, so the table has a single row)
           let x_array = Arc::new(Int64Array::from(vec![x])) as ArrayRef;
           let y_array = Arc::new(Int64Array::from(vec![y])) as ArrayRef;

           // Create a single RecordBatch
           let batch = RecordBatch::try_new(schema.clone(), vec![x_array, y_array])?;

           // Wrap in a MemTable with one partition holding that batch
           let provider = MemTable::try_new(schema, vec![vec![batch]])?;

           Ok(Arc::new(provider))
       }
   }
   
   // --- Usage Example ---
   
   /// Registers a [`TransformFunction`] in `ctx` as the table function
   /// `my_transform`.
   fn register_udtf(ctx: &mut SessionContext) -> Result<()> {
       ctx.register_udtf("my_transform", Arc::new(TransformFunction {}));
       Ok(())
   }
   
   /// Registers the in-memory demo table `my_table` in `ctx`.
   ///
   /// The table has a `Utf8` column `id` and three `Int64` columns `a`, `b`, `c`
   /// (all non-nullable) and contains two rows:
   /// `("r1", 10, 5, 2)` and `("r2", 20, 6, 3)`.
   fn create_dummy_table(ctx: &mut SessionContext) -> Result<()> {
       let fields = vec![
           Field::new("id", DataType::Utf8, false),
           Field::new("a", DataType::Int64, false),
           Field::new("b", DataType::Int64, false),
           Field::new("c", DataType::Int64, false),
       ];
       let schema = Arc::new(Schema::new(fields));

       // One array per column, listed in schema order.
       let id_col: ArrayRef = Arc::new(StringArray::from(vec!["r1", "r2"]));
       let a_col: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
       let b_col: ArrayRef = Arc::new(Int64Array::from(vec![5, 6]));
       let c_col: ArrayRef = Arc::new(Int64Array::from(vec![2, 3]));

       let rows = RecordBatch::try_new(Arc::clone(&schema), vec![id_col, a_col, b_col, c_col])?;

       // A single partition holding the single batch.
       let table = MemTable::try_new(schema, vec![vec![rows]])?;
       ctx.register_table("my_table", Arc::new(table))?;

       Ok(())
   }
   
   /// Sets up a session with the table function and the demo table, then runs a
   /// single `LATERAL` query against them and prints the query and its result.
   #[tokio::main]
   async fn main() -> Result<()> {
       let mut ctx = SessionContext::new();

       // Setup: the table function first, then the table the query reads from.
       register_udtf(&mut ctx)?;
       create_dummy_table(&mut ctx)?;

       // Query that is executed: the table function receives literal arguments.
       let query = r#"
           SELECT 
               t1.id, 
               t2.x AS a_plus_b, 
               t2.y AS b_times_c
           FROM 
               my_table AS t1,
               LATERAL my_transform(1, 2, 3) AS t2(x, y)
       "#;

       // Alternative query, currently disabled: passes columns of `t1` to the
       // table function instead of literals.
       //
       // let query = r#"
       //     SELECT
       //         t1.id,
       //         t2.x AS a_plus_b,
       //         t2.y AS b_times_c
       //     FROM
       //         my_table AS t1,
       //         LATERAL my_transform(t1.a, t1.b, t1.c) AS t2(x, y)
       // "#;

       println!("Executing SQL:\n{query}");

       let frame = ctx.sql(query).await?;

       println!("\nQuery Result:");
       frame.show().await?;

       Ok(())
   }
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to