This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new b5b50ba5c Support `insert into` statement in sqllogictest (#4496)
b5b50ba5c is described below

commit b5b50ba5ca8942aa9494f0c6a46f1ea9fd6bff5d
Author: xudong.w <[email protected]>
AuthorDate: Sun Dec 4 21:17:51 2022 +0800

    Support `insert into` statement in sqllogictest (#4496)
    
    * Support insert statement in sqllogictest
    
    * fix clippy
---
 datafusion/core/Cargo.toml                         |  1 +
 datafusion/core/src/datasource/memory.rs           | 13 ++-
 datafusion/core/src/execution/context.rs           | 16 ++++
 datafusion/core/tests/sqllogictests/src/error.rs   | 83 +++++++++++++++++++
 .../core/tests/sqllogictests/src/insert/mod.rs     | 96 ++++++++++++++++++++++
 .../core/tests/sqllogictests/src/insert/util.rs    | 50 +++++++++++
 datafusion/core/tests/sqllogictests/src/main.rs    | 17 +++-
 .../core/tests/sqllogictests/test_files/insert.slt | 50 +++++++++++
 8 files changed, 319 insertions(+), 7 deletions(-)

diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index f3834c4cd..8f40bac70 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -110,6 +110,7 @@ doc-comment = "0.3"
 env_logger = "0.10"
 parquet-test-utils = { path = "../../parquet-test-utils" }
 rstest = "0.16.0"
+sqlparser = "0.27"
 test-utils = { path = "../../test-utils" }
 
 [[bench]]
diff --git a/datafusion/core/src/datasource/memory.rs 
b/datafusion/core/src/datasource/memory.rs
index 632ef8d28..a80c4b94d 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -26,6 +26,7 @@ use std::sync::Arc;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
+use parking_lot::RwLock;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
@@ -40,7 +41,7 @@ use crate::physical_plan::{repartition::RepartitionExec, 
Partitioning};
 #[derive(Debug)]
 pub struct MemTable {
     schema: SchemaRef,
-    batches: Vec<Vec<RecordBatch>>,
+    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
 }
 
 impl MemTable {
@@ -53,7 +54,7 @@ impl MemTable {
         {
             Ok(Self {
                 schema,
-                batches: partitions,
+                batches: Arc::new(RwLock::new(partitions)),
             })
         } else {
             Err(DataFusionError::Plan(
@@ -117,6 +118,11 @@ impl MemTable {
         }
         MemTable::try_new(schema.clone(), data)
     }
+
+    /// Get record batches in MemTable
+    pub fn get_batches(&self) -> Arc<RwLock<Vec<Vec<RecordBatch>>>> {
+        self.batches.clone()
+    }
 }
 
 #[async_trait]
@@ -140,8 +146,9 @@ impl TableProvider for MemTable {
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let batches = self.batches.read();
         Ok(Arc::new(MemoryExec::try_new(
-            &self.batches.clone(),
+            &(*batches).clone(),
             self.schema(),
             projection.cloned(),
         )?))
diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index 6b6a4fb1f..9115b2efb 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -958,6 +958,22 @@ impl SessionContext {
         }
     }
 
+    /// Return a [`TableProvider`] for the specified table.
+    pub fn table_provider<'a>(
+        &self,
+        table_ref: impl Into<TableReference<'a>>,
+    ) -> Result<Arc<dyn TableProvider>> {
+        let table_ref = table_ref.into();
+        let schema = self.state.read().schema_for_ref(table_ref)?;
+        match schema.table(table_ref.table()) {
+            Some(ref provider) => Ok(Arc::clone(provider)),
+            _ => Err(DataFusionError::Plan(format!(
+                "No table named '{}'",
+                table_ref.table()
+            ))),
+        }
+    }
+
     /// Returns the set of available tables in the default catalog and
     /// schema.
     ///
diff --git a/datafusion/core/tests/sqllogictests/src/error.rs 
b/datafusion/core/tests/sqllogictests/src/error.rs
new file mode 100644
index 000000000..8ac482141
--- /dev/null
+++ b/datafusion/core/tests/sqllogictests/src/error.rs
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::DataFusionError;
+use sqllogictest::TestError;
+use sqlparser::parser::ParserError;
+use std::error;
+use std::fmt::{Display, Formatter};
+
+pub type Result<T> = std::result::Result<T, DFSqlLogicTestError>;
+
+/// DataFusion sqllogictest error
+#[derive(Debug)]
+pub enum DFSqlLogicTestError {
+    /// Error from sqllogictest-rs
+    SqlLogicTest(TestError),
+    /// Error from datafusion
+    DataFusion(DataFusionError),
+    /// Error returned when SQL is syntactically incorrect.
+    Sql(ParserError),
+    /// Error returned on a branch that we know is possible
+    /// but for which we still have no implementation.
+    /// Often, these errors are tracked in our issue tracker.
+    NotImplemented(String),
+    /// Error raised internally by the DataFusion sqllogictest runner
+    Internal(String),
+}
+
+impl From<TestError> for DFSqlLogicTestError {
+    fn from(value: TestError) -> Self {
+        DFSqlLogicTestError::SqlLogicTest(value)
+    }
+}
+
+impl From<DataFusionError> for DFSqlLogicTestError {
+    fn from(value: DataFusionError) -> Self {
+        DFSqlLogicTestError::DataFusion(value)
+    }
+}
+
+impl From<ParserError> for DFSqlLogicTestError {
+    fn from(value: ParserError) -> Self {
+        DFSqlLogicTestError::Sql(value)
+    }
+}
+
+impl Display for DFSqlLogicTestError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            DFSqlLogicTestError::SqlLogicTest(error) => write!(
+                f,
+                "SqlLogicTest error(from sqllogictest-rs crate): {}",
+                error
+            ),
+            DFSqlLogicTestError::DataFusion(error) => {
+                write!(f, "DataFusion error: {}", error)
+            }
+            DFSqlLogicTestError::Sql(error) => write!(f, "SQL Parser error: 
{}", error),
+            DFSqlLogicTestError::NotImplemented(error) => {
+                write!(f, "This feature is not implemented yet: {}", error)
+            }
+            DFSqlLogicTestError::Internal(error) => {
+                write!(f, "Internal error: {}", error)
+            }
+        }
+    }
+}
+
+impl error::Error for DFSqlLogicTestError {}
diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs 
b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
new file mode 100644
index 000000000..100fa1184
--- /dev/null
+++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod util;
+
+use crate::error::{DFSqlLogicTestError, Result};
+use crate::insert::util::LogicTestContextProvider;
+use datafusion::datasource::MemTable;
+use datafusion::prelude::SessionContext;
+use datafusion_common::{DFSchema, DataFusionError};
+use datafusion_expr::Expr as DFExpr;
+use datafusion_sql::parser::{DFParser, Statement};
+use datafusion_sql::planner::SqlToRel;
+use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
+use std::collections::HashMap;
+
+pub async fn insert(ctx: &SessionContext, sql: String) -> Result<String> {
+    // First, use sqlparser to get table name and insert values
+    let mut table_name = "".to_string();
+    let mut insert_values: Vec<Vec<Expr>> = vec![];
+    if let Statement::Statement(statement) = &DFParser::parse_sql(&sql)?[0] {
+        if let SQLStatement::Insert {
+            table_name: name,
+            source,
+            ..
+        } = &**statement
+        {
+            // TODO: check that the columns match the table schema
+            table_name = name.to_string();
+            match &*source.body {
+                SetExpr::Values(values) => {
+                    insert_values = values.0.clone();
+                }
+                _ => {
+                    return Err(DFSqlLogicTestError::NotImplemented(
+                        "Only support insert values".to_string(),
+                    ));
+                }
+            }
+        }
+    } else {
+        return Err(DFSqlLogicTestError::Internal(format!(
+            "{:?} not an insert statement",
+            sql
+        )));
+    }
+
+    // Second, look up the table by its name.
+    // Here we assume the table is an in-memory table (`MemTable`).
+    let table_provider = ctx.table_provider(table_name.as_str())?;
+    let table_batches = table_provider
+        .as_any()
+        .downcast_ref::<MemTable>()
+        .ok_or_else(|| {
+            DFSqlLogicTestError::NotImplemented(
+                "only support use memory table in logictest".to_string(),
+            )
+        })?
+        .get_batches();
+
+    // Third, convert the insert values to `RecordBatch`es.
+    // Note: schema info can be ignored (insert values don't contain schema info).
+    let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {});
+    let mut insert_batches = Vec::with_capacity(insert_values.len());
+    for row in insert_values.into_iter() {
+        let logical_exprs = row
+            .into_iter()
+            .map(|expr| {
+                sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut 
HashMap::new())
+            })
+            .collect::<std::result::Result<Vec<DFExpr>, DataFusionError>>()?;
+        // Directly use `select` to get `RecordBatch`
+        let dataframe = ctx.read_empty()?;
+        insert_batches.push(dataframe.select(logical_exprs)?.collect().await?)
+    }
+
+    // Finally, append the `RecordBatch` to memtable's batches
+    let mut table_batches = table_batches.write();
+    table_batches.extend(insert_batches);
+
+    Ok("".to_string())
+}
diff --git a/datafusion/core/tests/sqllogictests/src/insert/util.rs 
b/datafusion/core/tests/sqllogictests/src/insert/util.rs
new file mode 100644
index 000000000..03dbb7299
--- /dev/null
+++ b/datafusion/core/tests/sqllogictests/src/insert/util.rs
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+use datafusion_common::{ScalarValue, TableReference};
+use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource};
+use datafusion_sql::planner::ContextProvider;
+use std::sync::Arc;
+
+pub struct LogicTestContextProvider {}
+
+// Only a mock: these methods don't need real implementations
+impl ContextProvider for LogicTestContextProvider {
+    fn get_table_provider(
+        &self,
+        _name: TableReference,
+    ) -> datafusion_common::Result<Arc<dyn TableSource>> {
+        todo!()
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
+        todo!()
+    }
+
+    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
+        todo!()
+    }
+
+    fn get_variable_type(&self, _variable_names: &[String]) -> 
Option<DataType> {
+        todo!()
+    }
+
+    fn get_config_option(&self, _variable: &str) -> Option<ScalarValue> {
+        todo!()
+    }
+}
diff --git a/datafusion/core/tests/sqllogictests/src/main.rs 
b/datafusion/core/tests/sqllogictests/src/main.rs
index fc27773c0..0efac489f 100644
--- a/datafusion/core/tests/sqllogictests/src/main.rs
+++ b/datafusion/core/tests/sqllogictests/src/main.rs
@@ -22,9 +22,11 @@ use datafusion::prelude::{SessionConfig, SessionContext};
 use std::path::Path;
 use std::time::Duration;
 
-use sqllogictest::TestError;
-pub type Result<T> = std::result::Result<T, TestError>;
+use crate::error::{DFSqlLogicTestError, Result};
+use crate::insert::insert;
 
+mod error;
+mod insert;
 mod setup;
 mod utils;
 
@@ -37,7 +39,7 @@ pub struct DataFusion {
 
 #[async_trait]
 impl sqllogictest::AsyncDB for DataFusion {
-    type Error = TestError;
+    type Error = DFSqlLogicTestError;
 
     async fn run(&mut self, sql: &str) -> Result<String> {
         println!("[{}] Running query: \"{}\"", self.file_name, sql);
@@ -138,7 +140,14 @@ fn format_batches(batches: &[RecordBatch]) -> 
Result<String> {
 }
 
 async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> 
Result<String> {
-    let df = ctx.sql(&sql.into()).await.unwrap();
+    let sql = sql.into();
+    // Check if the sql is `insert`
+    if sql.trim_start().to_lowercase().starts_with("insert") {
+        // Process the insert statement
+        insert(ctx, sql).await?;
+        return Ok("".to_string());
+    }
+    let df = ctx.sql(sql.as_str()).await.unwrap();
     let results: Vec<RecordBatch> = df.collect().await.unwrap();
     let formatted_batches = format_batches(&results)?;
     Ok(formatted_batches)
diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt 
b/datafusion/core/tests/sqllogictests/test_files/insert.slt
new file mode 100644
index 000000000..0927b3777
--- /dev/null
+++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+statement ok
+CREATE TABLE users AS VALUES(1,2),(2,3);
+
+query II rowsort
+select * from users;
+----
+1 2
+2 3
+
+statement ok
+insert into users values(2, 4);
+
+query II rowsort
+select * from users;
+----
+1 2
+2 3
+2 4
+
+statement ok
+insert into users values(1 + 10, 20);
+
+query II rowsort
+select * from users;
+----
+1 2
+2 3
+2 4
+11 20
+
+# Test insert into an undefined table
+statement error
+insert into user values(1, 20);

Reply via email to