GitHub user l1t1 closed a discussion: how to write a UDTF read_pg?
I asked an AI to write it for me, but its code panics with:
```
thread 'main' panicked at
/usr/local/cargo/registry/src/mirrors.tuna.tsinghua.edu.cn-4dc01642fd091eda/postgres-0.19.10/src/config.rs:463:44:
Cannot start a runtime from within a runtime. This happens because a function
(like `block_on`) attempted to block the current thread while the thread is
being used to drive asynchronous tasks.
```
read_pg.rs
```rust
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use arrow::array::{
    Array, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
    TimestampMicrosecondArray,
};
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::ScalarValue;
use datafusion::common::{DataFusionError, Result};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use postgres_types::Type;
use tokio::runtime::Runtime;
use tokio_postgres::types::FromSql;
use tokio_postgres::Row;
use tokio_postgres::{Client, NoTls, Statement};
/// Table function that queries data from PostgreSQL.
///
/// The type carries no state; all behaviour lives in its `TableFunctionImpl`
/// implementation, which takes a connection URI and either a table name or a
/// full SQL query.
#[derive(Debug)]
pub struct ReadPgFunction;
impl ReadPgFunction {
// 将辅助函数作为结构体的关联函数
/// Returns `true` when `s` is a bare (unquoted) SQL identifier, e.g. a table name.
///
/// An identifier must be non-empty, start with a letter or `_`, and continue
/// with letters, digits, `_` or `$`. Anything else (whitespace, `;`, quotes,
/// dots, parentheses, ...) is not a simple identifier, so callers can treat
/// it as a full query instead of splicing it into `SELECT * FROM ...`.
///
/// Only the bare keyword `select` is rejected; names that merely start with
/// it (`selections`, `select_log`) are ordinary identifiers.
fn is_simple_identifier(s: &str) -> bool {
    let mut chars = s.chars();
    // An empty string would otherwise produce the invalid SQL `SELECT * FROM `.
    let Some(first) = chars.next() else {
        return false;
    };
    (first.is_alphabetic() || first == '_')
        && chars.all(|c| c.is_alphanumeric() || c == '_' || c == '$')
        && !s.eq_ignore_ascii_case("select")
}
/// Returns `true` when `s` is a dotted name with two or three components
/// (`schema.table` or `database.schema.table`), each of which is a simple
/// identifier.
///
/// Empty components are rejected explicitly, so inputs such as `"a."`,
/// `".b"` or `"a..b"` are never treated as table references.
fn is_qualified_identifier(s: &str) -> bool {
    let mut part_count = 0usize;
    for part in s.split('.') {
        if part.is_empty() || !Self::is_simple_identifier(part) {
            return false;
        }
        part_count += 1;
    }
    matches!(part_count, 2 | 3)
}
/// Returns a copy of `s` with every `"`, `'` and `;` character removed.
///
/// All other characters are kept in their original order.
fn sanitize_identifier(s: &str) -> String {
    s.chars()
        .filter(|ch| !matches!(ch, '"' | '\'' | ';'))
        .collect()
}
/// Maps a PostgreSQL column type to the Arrow type used in the table schema.
///
/// * Integer, float and boolean types map to their same-width Arrow types.
/// * Character types map to `Utf8`.
/// * `TIMESTAMP` is timezone-naive; `TIMESTAMPTZ` denotes an absolute
///   instant, so it is tagged with the fixed `+00:00` offset instead of
///   being reported as naive.
/// * Every type without an explicit mapping falls back to `Utf8`.
fn pg_type_to_arrow(pg_type: &Type) -> DataType {
    match *pg_type {
        Type::BOOL => DataType::Boolean,
        Type::INT2 => DataType::Int16,
        Type::INT4 => DataType::Int32,
        Type::INT8 => DataType::Int64,
        Type::FLOAT4 => DataType::Float32,
        Type::FLOAT8 => DataType::Float64,
        Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::NAME => DataType::Utf8,
        Type::TIMESTAMP => {
            DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None)
        }
        Type::TIMESTAMPTZ => DataType::Timestamp(
            arrow::datatypes::TimeUnit::Microsecond,
            Some("+00:00".into()),
        ),
        Type::DATE => DataType::Date32,
        Type::TIME => DataType::Time64(arrow::datatypes::TimeUnit::Microsecond),
        _ => DataType::Utf8,
    }
}
}
impl TableFunctionImpl for ReadPgFunction {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
// 创建新的运行时
let rt = Runtime::new().map_err(|e| {
DataFusionError::Execution(format!("创建Tokio运行时失败: {}", e))
})?;
rt.block_on(async {
// 检查参数数量
if exprs.len() != 2 {
return Err(DataFusionError::Execution(
"read_pg需要2个参数: 连接URI和查询语句/表名".to_string(),
));
}
// 解析连接URI
let uri = match &exprs[0] {
Expr::Literal(ScalarValue::Utf8(Some(uri)), _) => uri,
_ => {
return Err(DataFusionError::Execution(
"第一个参数必须是字符串类型的PostgreSQL连接URI".to_string(),
))
}
};
// 解析查询参数
let query = match &exprs[1] {
Expr::Literal(ScalarValue::Utf8(Some(query)), _) => query,
_ => {
return Err(DataFusionError::Execution(
"第二个参数必须是字符串类型的查询或表名".to_string(),
))
}
};
// 连接PostgreSQL (使用tokio-postgres替代postgres)
let (client, connection) = tokio_postgres::connect(uri, NoTls)
.await
.map_err(|e|
DataFusionError::Execution(format!("连接PostgreSQL失败: {}", e)))?;
// 处理连接任务
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("PostgreSQL连接错误: {}", e);
}
});
// 构建最终查询SQL
let final_query = if Self::is_simple_identifier(query) {
format!("SELECT * FROM {}", Self::sanitize_identifier(query))
} else if Self::is_qualified_identifier(query) {
format!("SELECT * FROM {}", query)
} else {
query.to_string()
};
// 准备查询以获取列信息
let stmt: Statement = client
.prepare(&final_query)
.await
.map_err(|e| DataFusionError::Execution(format!("准备查询失败: {}",
e)))?;
// 获取列信息
let columns = stmt.columns();
let mut fields = Vec::with_capacity(columns.len());
let mut col_types = Vec::with_capacity(columns.len());
for col in columns {
let field = Field::new(col.name(),
Self::pg_type_to_arrow(col.type_()), true);
col_types.push(col.type_().clone());
fields.push(field);
}
let schema = Arc::new(Schema::new(fields.clone()));
// 执行查询
let rows = client
.query(&stmt, &[])
.await
.map_err(|e| DataFusionError::Execution(format!("执行查询失败: {}",
e)))?;
if rows.is_empty() {
return Err(DataFusionError::Execution("查询结果为空".to_string()));
}
// 转换数据
let mut column_data: Vec<Vec<String>> = vec![vec![]; fields.len()];
for row in rows {
for i in 0..fields.len() {
let val: String = match &col_types[i] {
t if *t == Type::BOOL => row.get::<_,
Option<bool>>(i).map(|v| v.to_string()),
t if *t == Type::INT2 => row.get::<_,
Option<i16>>(i).map(|v| v.to_string()),
t if *t == Type::INT4 => row.get::<_,
Option<i32>>(i).map(|v| v.to_string()),
t if *t == Type::INT8 => row.get::<_,
Option<i64>>(i).map(|v| v.to_string()),
t if *t == Type::FLOAT4 => row.get::<_,
Option<f32>>(i).map(|v| v.to_string()),
t if *t == Type::FLOAT8 => row.get::<_,
Option<f64>>(i).map(|v| v.to_string()),
t if *t == Type::TEXT || *t == Type::VARCHAR =>
row.get::<_, Option<String>>(i),
_ => row.get::<_, Option<String>>(i),
}.unwrap_or_else(|| "NULL".to_string());
column_data[i].push(val);
}
}
// 创建RecordBatch
let arrow_columns = column_data
.into_iter()
.map(|col| Arc::new(StringArray::from(col)) as ArrayRef)
.collect::<Vec<_>>();
let batch = RecordBatch::try_new(schema.clone(),
arrow_columns).map_err(|e| {
DataFusionError::Execution(format!("创建RecordBatch失败: {}", e))
})?;
// 创建内存表
let provider = MemTable::try_new(schema,
vec![vec![batch]]).map_err(|e| {
DataFusionError::Execution(format!("创建内存表失败: {}", e))
})?;
// 显式转换为 dyn TableProvider
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
//Ok(Arc::new(provider))
})
}
}
impl Default for ReadPgFunction {
fn default() -> Self {
Self
}
}
```
main.rs
```rust
mod read_pg;
use std::sync::Arc;
use datafusion::execution::context::SessionContext;
use datafusion::error::Result;
use read_pg::ReadPgFunction;
/// Registers every custom table function on the given DataFusion context.
///
/// Currently this is only `read_pg`. The function always returns `Ok(())`.
async fn register_custom_functions(ctx: &SessionContext) -> Result<()> {
    let read_pg = Arc::new(ReadPgFunction::default());
    ctx.register_udtf("read_pg", read_pg);
    Ok(())
}
/// Demo entry point: registers `read_pg`, then runs two example queries in
/// order and prints each result. The first passes a table name, the second a
/// full SQL query.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    register_custom_functions(&ctx).await?;
    println!("所有自定义函数已注册成功!");

    // The literals intentionally span lines; the embedded line breaks are part
    // of the SQL text sent to DataFusion.
    let demo_queries = [
        "SELECT * FROM
read_pg('postgres://lt:xyz123@localhost/postgres', 'mytable')",
        "SELECT * FROM
read_pg('postgres://lt:xyz123@localhost/postgres', 'SELECT * FROM mytable WHERE
id > 1')",
    ];
    for sql in demo_queries {
        let frame = ctx.sql(sql).await?;
        frame.show().await?;
    }
    Ok(())
}
```
Cargo.toml
```toml
# Package metadata for the demo binary.
[package]
name = "dfdf"
version = "0.1.0"
edition = "2024"
# Crates used by the demo.
[dependencies]
datafusion = "48.0" # use the latest stable release
datafusion-expr = "48.0"
datafusion-common = "48.0"
arrow = "55.1"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4"
calamine = "0.28" # Excel file support
encoding_rs = "0.8" # encoding conversion
# NOTE(review): blocking client; the reported panic points into this crate
# (postgres-0.19.10/src/config.rs), so it presumably must not be called from
# inside a Tokio runtime. TODO confirm it is still needed next to tokio-postgres.
postgres = { version = "0.19", features = ["with-uuid-1"] } # PostgreSQL support
csv = "1.2"
postgres-types = "0.2.9"
tokio-postgres = { version = "0.7", features = ["with-uuid-1"] }
futures = "0.3"
```
GitHub link: https://github.com/apache/datafusion/discussions/16564
----
This is an automatically sent email for [email protected].
To unsubscribe, please send an email to:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]