GitHub user l1t1 closed a discussion: How to write a UDTF `read_pg`?

I asked an AI to write it for me; running its code shows:
```
thread 'main' panicked at 
/usr/local/cargo/registry/src/mirrors.tuna.tsinghua.edu.cn-4dc01642fd091eda/postgres-0.19.10/src/config.rs:463:44:
Cannot start a runtime from within a runtime. This happens because a function 
(like `block_on`) attempted to block the current thread while the thread is 
being used to drive asynchronous tasks.
```
read_pg.rs
```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{DataFusionError, Result};
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use datafusion::common::ScalarValue;
use tokio_postgres::{Client, NoTls, Statement};
use postgres_types::Type;
use tokio::runtime::Runtime;

/// Table function `read_pg(uri, query_or_table)` that loads data from PostgreSQL.
///
/// The type carries no state, so it is freely copyable and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReadPgFunction;

// Helper routines, kept as associated functions of the struct.
impl ReadPgFunction {
    /// Returns `true` when `s` should be treated as a bare table reference rather than SQL.
    ///
    /// A string qualifies when it contains no whitespace, no `;`, and does not start with the
    /// `SELECT` keyword (see `starts_with_select_keyword`).
    fn is_simple_identifier(s: &str) -> bool {
        !s.contains(char::is_whitespace)
            && !s.contains(';')
            && !Self::starts_with_select_keyword(s)
    }

    /// Returns `true` when `s` begins with the keyword `select` (ASCII case-insensitive).
    ///
    /// The prefix only counts as the keyword when it is NOT directly followed by another
    /// identifier character, so table names such as `selected_items` or `Select_log` are not
    /// mistaken for queries (a plain `starts_with("select")` test rejected them).
    fn starts_with_select_keyword(s: &str) -> bool {
        const KEYWORD: &str = "select";
        match s.get(..KEYWORD.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(KEYWORD) => {
                // `get` succeeded, so `KEYWORD.len()` is a char boundary and slicing is safe.
                let rest = &s[KEYWORD.len()..];
                !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_' || c == '$')
            }
            _ => false,
        }
    }

    /// Returns `true` for a `schema.table` or `catalog.schema.table` reference whose every
    /// dot-separated part satisfies `is_simple_identifier`.
    fn is_qualified_identifier(s: &str) -> bool {
        let parts: Vec<&str> = s.split('.').collect();
        matches!(parts.len(), 2 | 3) && parts.iter().all(|part| Self::is_simple_identifier(part))
    }

    /// Removes double quotes, single quotes and semicolons from `s`.
    ///
    /// This only strips those characters; it is not a general SQL escaping routine.
    fn sanitize_identifier(s: &str) -> String {
        s.chars().filter(|c| !matches!(c, '"' | '\'' | ';')).collect()
    }

    /// Maps a PostgreSQL column type to the Arrow type used in the table schema.
    ///
    /// Types without a dedicated mapping fall back to `Utf8`.
    fn pg_type_to_arrow(pg_type: &Type) -> DataType {
        use arrow::datatypes::TimeUnit;

        match *pg_type {
            Type::BOOL => DataType::Boolean,
            Type::INT2 => DataType::Int16,
            Type::INT4 => DataType::Int32,
            Type::INT8 => DataType::Int64,
            Type::FLOAT4 => DataType::Float32,
            Type::FLOAT8 => DataType::Float64,
            Type::TEXT | Type::VARCHAR => DataType::Utf8,
            Type::TIMESTAMP | Type::TIMESTAMPTZ => {
                DataType::Timestamp(TimeUnit::Microsecond, None)
            }
            Type::DATE => DataType::Date32,
            Type::TIME => DataType::Time64(TimeUnit::Microsecond),
            _ => DataType::Utf8,
        }
    }
}

impl TableFunctionImpl for ReadPgFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // 创建新的运行时
        let rt = Runtime::new().map_err(|e| {
            DataFusionError::Execution(format!("创建Tokio运行时失败: {}", e))
        })?;

        rt.block_on(async {
            // 检查参数数量
            if exprs.len() != 2 {
                return Err(DataFusionError::Execution(
                    "read_pg需要2个参数: 连接URI和查询语句/表名".to_string(),
                ));
            }

            // 解析连接URI
            let uri = match &exprs[0] {
                Expr::Literal(ScalarValue::Utf8(Some(uri)), _) => uri,
                _ => {
                    return Err(DataFusionError::Execution(
                        "第一个参数必须是字符串类型的PostgreSQL连接URI".to_string(),
                    ))
                }
            };

            // 解析查询参数
            let query = match &exprs[1] {
                Expr::Literal(ScalarValue::Utf8(Some(query)), _) => query,
                _ => {
                    return Err(DataFusionError::Execution(
                        "第二个参数必须是字符串类型的查询或表名".to_string(),
                    ))
                }
            };

            // 连接PostgreSQL (使用tokio-postgres替代postgres)
            let (client, connection) = tokio_postgres::connect(uri, NoTls)
                .await
                .map_err(|e| 
DataFusionError::Execution(format!("连接PostgreSQL失败: {}", e)))?;

            // 处理连接任务
            tokio::spawn(async move {
                if let Err(e) = connection.await {
                    eprintln!("PostgreSQL连接错误: {}", e);
                }
            });

            // 构建最终查询SQL
            let final_query = if Self::is_simple_identifier(query) {
                format!("SELECT * FROM {}", Self::sanitize_identifier(query))
            } else if Self::is_qualified_identifier(query) {
                format!("SELECT * FROM {}", query)
            } else {
                query.to_string()
            };

            // 准备查询以获取列信息
            let stmt: Statement = client
                .prepare(&final_query)
                .await
                .map_err(|e| DataFusionError::Execution(format!("准备查询失败: {}", 
e)))?;

            // 获取列信息
            let columns = stmt.columns();
            let mut fields = Vec::with_capacity(columns.len());
            let mut col_types = Vec::with_capacity(columns.len());

            for col in columns {
                let field = Field::new(col.name(), 
Self::pg_type_to_arrow(col.type_()), true);
                col_types.push(col.type_().clone());
                fields.push(field);
            }

            let schema = Arc::new(Schema::new(fields.clone()));

            // 执行查询
            let rows = client
                .query(&stmt, &[])
                .await
                .map_err(|e| DataFusionError::Execution(format!("执行查询失败: {}", 
e)))?;

            if rows.is_empty() {
                return Err(DataFusionError::Execution("查询结果为空".to_string()));
            }

            // 转换数据
            let mut column_data: Vec<Vec<String>> = vec![vec![]; fields.len()];
            
            for row in rows {
                for i in 0..fields.len() {
                    let val: String = match &col_types[i] {
                        t if *t == Type::BOOL => row.get::<_, 
Option<bool>>(i).map(|v| v.to_string()),
                        t if *t == Type::INT2 => row.get::<_, 
Option<i16>>(i).map(|v| v.to_string()),
                        t if *t == Type::INT4 => row.get::<_, 
Option<i32>>(i).map(|v| v.to_string()),
                        t if *t == Type::INT8 => row.get::<_, 
Option<i64>>(i).map(|v| v.to_string()),
                        t if *t == Type::FLOAT4 => row.get::<_, 
Option<f32>>(i).map(|v| v.to_string()),
                        t if *t == Type::FLOAT8 => row.get::<_, 
Option<f64>>(i).map(|v| v.to_string()),
                        t if *t == Type::TEXT || *t == Type::VARCHAR => 
row.get::<_, Option<String>>(i),
                        _ => row.get::<_, Option<String>>(i),
                    }.unwrap_or_else(|| "NULL".to_string());
                    
                    column_data[i].push(val);
                }
            }

            // 创建RecordBatch
            let arrow_columns = column_data
                .into_iter()
                .map(|col| Arc::new(StringArray::from(col)) as ArrayRef)
                .collect::<Vec<_>>();

            let batch = RecordBatch::try_new(schema.clone(), 
arrow_columns).map_err(|e| {
                DataFusionError::Execution(format!("创建RecordBatch失败: {}", e))
            })?;


            // 创建内存表
            let provider = MemTable::try_new(schema, 
vec![vec![batch]]).map_err(|e| {
                DataFusionError::Execution(format!("创建内存表失败: {}", e))
            })?;

            // 显式转换为 dyn TableProvider
            Ok(Arc::new(provider) as Arc<dyn TableProvider>)
            //Ok(Arc::new(provider))
 
 
        })
    }
}

impl Default for ReadPgFunction {
    fn default() -> Self {
        Self
    }
}
```
main.rs
```rust

mod read_pg;

use std::sync::Arc;
use datafusion::execution::context::SessionContext;
use datafusion::error::Result;


use read_pg::ReadPgFunction;

/// Registers every custom table function with the given DataFusion session context.
///
/// Currently this registers only `read_pg`.
async fn register_custom_functions(ctx: &SessionContext) -> Result<()> {
    let read_pg = Arc::new(ReadPgFunction::default());
    ctx.register_udtf("read_pg", read_pg);
    Ok(())
}

/// Example entry point: registers `read_pg` and runs two sample queries through it.
///
/// NOTE: planning these queries invokes `ReadPgFunction::call` from inside this
/// `#[tokio::main]` runtime, so `call` must not start a nested runtime on this thread.
/// The connection URI below is a hard-coded example credential.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    
    // Register the custom table functions.
    register_custom_functions(&ctx).await?;
    
    println!("所有自定义函数已注册成功!");

    // Usage examples: a bare table name, then a full SQL query.
    let df4 = ctx.sql("SELECT * FROM 
read_pg('postgres://lt:xyz123@localhost/postgres', 'mytable')").await?;
    df4.show().await?;
    let df5 = ctx.sql("SELECT * FROM 
read_pg('postgres://lt:xyz123@localhost/postgres', 'SELECT * FROM mytable WHERE 
id > 1')").await?;
    df5.show().await?;    

    Ok(())
}
```
Cargo.toml
```toml
[package]
name = "dfdf"
version = "0.1.0"
edition = "2024"

[dependencies]
datafusion = "48.0"  # 使用最新稳定版
datafusion-expr = "48.0"
datafusion-common = "48.0"
arrow = "55.1"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4"
calamine = "0.28"    # Excel 文件支持
encoding_rs = "0.8"  # 编码转换
postgres = { version = "0.19", features = ["with-uuid-1"] }  # PostgreSQL 支持
csv = "1.2"
postgres-types = "0.2.9"

tokio-postgres = { version = "0.7", features = ["with-uuid-1"] }

futures = "0.3"
```

GitHub link: https://github.com/apache/datafusion/discussions/16564

----
This is an automatically sent email for [email protected].
To unsubscribe, please send an email to: 
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to