This is an automated email from the ASF dual-hosted git repository.
houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 9e75ff5 Introduce JIT code generation (#1849)
9e75ff5 is described below
commit 9e75ff5fade635c6962064e5afc59d714184208c
Author: Yijie Shen <[email protected]>
AuthorDate: Thu Feb 24 13:17:03 2022 +0800
Introduce JIT code generation (#1849)
---
Cargo.toml | 1 +
datafusion-common/Cargo.toml | 2 +
datafusion-common/src/error.rs | 23 +
{datafusion-common => datafusion-jit}/Cargo.toml | 21 +-
datafusion-jit/src/api.rs | 630 +++++++++++++++++++++
datafusion-jit/src/ast.rs | 359 ++++++++++++
datafusion-jit/src/jit.rs | 676 +++++++++++++++++++++++
datafusion-jit/src/lib.rs | 110 ++++
datafusion/Cargo.toml | 8 +
datafusion/benches/data_utils/mod.rs | 10 +-
datafusion/benches/jit.rs | 58 ++
datafusion/src/lib.rs | 2 +-
datafusion/src/row/mod.rs | 87 ++-
datafusion/src/row/reader.rs | 147 ++++-
datafusion/src/row/writer.rs | 392 ++++++++++---
15 files changed, 2426 insertions(+), 100 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index beaa22d..65dd722 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,7 @@ members = [
"datafusion",
"datafusion-common",
"datafusion-expr",
+ "datafusion-jit",
"datafusion-physical-expr",
"datafusion-cli",
"datafusion-examples",
diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml
index a557313..30831b8 100644
--- a/datafusion-common/Cargo.toml
+++ b/datafusion-common/Cargo.toml
@@ -35,6 +35,7 @@ path = "src/lib.rs"
[features]
avro = ["avro-rs"]
pyarrow = ["pyo3"]
+jit = ["cranelift-module"]
[dependencies]
arrow = { version = "9.0.0", features = ["prettyprint"] }
@@ -43,3 +44,4 @@ avro-rs = { version = "0.13", features = ["snappy"], optional
= true }
pyo3 = { version = "0.15", optional = true }
sqlparser = "0.14"
ordered-float = "2.10"
+cranelift-module = { version = "0.81.1", optional = true }
diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs
index 93978db..ec59a8a 100644
--- a/datafusion-common/src/error.rs
+++ b/datafusion-common/src/error.rs
@@ -25,6 +25,8 @@ use std::result;
use arrow::error::ArrowError;
#[cfg(feature = "avro")]
use avro_rs::Error as AvroError;
+#[cfg(feature = "jit")]
+use cranelift_module::ModuleError;
use parquet::errors::ParquetError;
use sqlparser::parser::ParserError;
@@ -69,6 +71,9 @@ pub enum DataFusionError {
/// Errors originating from outside DataFusion's core codebase.
/// For example, a custom S3Error from the crate datafusion-objectstore-s3
External(GenericError),
+ #[cfg(feature = "jit")]
+ /// Error that occurred during JIT code generation
+ JITError(ModuleError),
}
impl From<io::Error> for DataFusionError {
@@ -112,6 +117,13 @@ impl From<ParserError> for DataFusionError {
}
}
+#[cfg(feature = "jit")]
+impl From<ModuleError> for DataFusionError {
+ fn from(e: ModuleError) -> Self {
+ DataFusionError::JITError(e)
+ }
+}
+
impl From<GenericError> for DataFusionError {
fn from(err: GenericError) -> Self {
DataFusionError::External(err)
@@ -152,6 +164,10 @@ impl Display for DataFusionError {
DataFusionError::External(ref desc) => {
write!(f, "External error: {}", desc)
}
+ #[cfg(feature = "jit")]
+ DataFusionError::JITError(ref desc) => {
+ write!(f, "JIT error: {}", desc)
+ }
}
}
}
@@ -196,3 +212,10 @@ mod test {
Ok(())
}
}
+
+#[macro_export]
+macro_rules! internal_err {
+ ($($arg:tt)*) => {
+ Err(DataFusionError::Internal(format!($($arg)*)))
+ };
+}
diff --git a/datafusion-common/Cargo.toml b/datafusion-jit/Cargo.toml
similarity index 76%
copy from datafusion-common/Cargo.toml
copy to datafusion-jit/Cargo.toml
index a557313..aaca90a 100644
--- a/datafusion-common/Cargo.toml
+++ b/datafusion-jit/Cargo.toml
@@ -16,12 +16,12 @@
# under the License.
[package]
-name = "datafusion-common"
+name = "datafusion-jit"
description = "DataFusion is an in-memory query engine that uses Apache Arrow
as the memory model"
version = "7.0.0"
homepage = "https://github.com/apache/arrow-datafusion"
repository = "https://github.com/apache/arrow-datafusion"
-readme = "README.md"
+readme = "../README.md"
authors = ["Apache Arrow <[email protected]>"]
license = "Apache-2.0"
keywords = [ "arrow", "query", "sql" ]
@@ -29,17 +29,16 @@ edition = "2021"
rust-version = "1.58"
[lib]
-name = "datafusion_common"
+name = "datafusion_jit"
path = "src/lib.rs"
[features]
-avro = ["avro-rs"]
-pyarrow = ["pyo3"]
+jit = []
[dependencies]
-arrow = { version = "9.0.0", features = ["prettyprint"] }
-parquet = { version = "9.0.0", features = ["arrow"] }
-avro-rs = { version = "0.13", features = ["snappy"], optional = true }
-pyo3 = { version = "0.15", optional = true }
-sqlparser = "0.14"
-ordered-float = "2.10"
+datafusion-common = { path = "../datafusion-common", version = "7.0.0",
features = ["jit"] }
+cranelift = "0.81.1"
+cranelift-module = "0.81.1"
+cranelift-jit = "0.81.1"
+cranelift-native = "0.81.1"
+parking_lot = "0.12"
diff --git a/datafusion-jit/src/api.rs b/datafusion-jit/src/api.rs
new file mode 100644
index 0000000..d95f9cc
--- /dev/null
+++ b/datafusion-jit/src/api.rs
@@ -0,0 +1,630 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Constructing a function AST at runtime.
+
+use crate::ast::*;
+use crate::jit::JIT;
+use datafusion_common::internal_err;
+use datafusion_common::{DataFusionError, Result};
+use parking_lot::Mutex;
+use std::collections::HashMap;
+use std::collections::VecDeque;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::Arc;
+
+/// External Function signature
+struct ExternFuncSignature {
+ name: String,
+ /// pointer to the function
+ code: *const u8,
+ params: Vec<JITType>,
+ returns: Option<JITType>,
+}
+
+#[derive(Clone, Debug)]
+/// A function consisting of AST nodes that JIT can compile.
+pub struct GeneratedFunction {
+ pub(crate) name: String,
+ pub(crate) params: Vec<(String, JITType)>,
+ pub(crate) body: Vec<Stmt>,
+ pub(crate) ret: Option<(String, JITType)>,
+}
+
+#[derive(Default)]
+/// State of the Assembler. Keeps track of generated function names
+/// and registered external functions.
+pub struct AssemblerState {
+ name_next_id: HashMap<String, u8>,
+ extern_funcs: HashMap<String, ExternFuncSignature>,
+}
+
+impl AssemblerState {
+ /// Create a fresh function name with prefix `name`.
+ pub fn fresh_name(&mut self, name: impl Into<String>) -> String {
+ let name = name.into();
+ if !self.name_next_id.contains_key(&name) {
+ self.name_next_id.insert(name.clone(), 0);
+ }
+
+ let id = self.name_next_id.get_mut(&name).unwrap();
+ let name = format!("{}_{}", &name, id);
+ *id += 1;
+ name
+ }
+}
+
+/// The very first step for constructing a function at runtime.
+pub struct Assembler {
+ state: Arc<Mutex<AssemblerState>>,
+}
+
+impl Default for Assembler {
+ fn default() -> Self {
+ Self {
+ state: Arc::new(Default::default()),
+ }
+ }
+}
+
+impl Assembler {
+ /// Register an external Rust function to make it accessible by
+ /// runtime-generated functions.
+ /// Parameter and return types are used to enforce type safety while
+ /// constructing an AST.
+ pub fn register_extern_fn(
+ &self,
+ name: impl Into<String>,
+ ptr: *const u8,
+ params: Vec<JITType>,
+ returns: Option<JITType>,
+ ) -> Result<()> {
+ let extern_funcs = &mut self.state.lock().extern_funcs;
+ let fn_name = name.into();
+ let old = extern_funcs.insert(
+ fn_name.clone(),
+ ExternFuncSignature {
+ name: fn_name,
+ code: ptr,
+ params,
+ returns,
+ },
+ );
+
+ match old {
+ None => Ok(()),
+ Some(old) => internal_err!("Extern function {} already exists",
old.name),
+ }
+ }
+
+ /// Create a new FunctionBuilder with `name` prefix
+ pub fn new_func_builder(&self, name: impl Into<String>) -> FunctionBuilder
{
+ let name = self.state.lock().fresh_name(name);
+ FunctionBuilder::new(name, self.state.clone())
+ }
+
+ /// Create a JIT environment in which the AST of a constructed function
+ /// can be compiled into runnable code.
+ pub fn create_jit(&self) -> JIT {
+ let symbols = self
+ .state
+ .lock()
+ .extern_funcs
+ .values()
+ .map(|s| (s.name.clone(), s.code))
+ .collect::<Vec<_>>();
+ JIT::new(symbols)
+ }
+}
+
+/// Function builder API. Stores the state while
+/// we are constructing an AST for a function.
+pub struct FunctionBuilder {
+ name: String,
+ params: Vec<(String, JITType)>,
+ ret: Option<(String, JITType)>,
+ fields: VecDeque<HashMap<String, JITType>>,
+ assembler_state: Arc<Mutex<AssemblerState>>,
+}
+
+impl FunctionBuilder {
+ fn new(name: impl Into<String>, assembler_state:
Arc<Mutex<AssemblerState>>) -> Self {
+ let mut fields = VecDeque::new();
+ fields.push_back(HashMap::new());
+ Self {
+ name: name.into(),
+ params: Vec::new(),
+ ret: None,
+ fields,
+ assembler_state,
+ }
+ }
+
+ /// Add one more parameter to the function.
+ pub fn param(mut self, name: impl Into<String>, ty: JITType) -> Self {
+ let name = name.into();
+ assert!(!self.fields.back().unwrap().contains_key(&name));
+ self.params.push((name.clone(), ty));
+ self.fields.back_mut().unwrap().insert(name, ty);
+ self
+ }
+
+ /// Set the return type for the function.
+ /// Functions are of `void` type by default if
+ /// you do not set the return type.
+ pub fn ret(mut self, name: impl Into<String>, ty: JITType) -> Self {
+ let name = name.into();
+ assert!(!self.fields.back().unwrap().contains_key(&name));
+ self.ret = Some((name.clone(), ty));
+ self.fields.back_mut().unwrap().insert(name, ty);
+ self
+ }
+
+ /// Enter the function body and start building it.
+ pub fn enter_block(&mut self) -> CodeBlock {
+ self.fields.push_back(HashMap::new());
+ CodeBlock {
+ fields: &mut self.fields,
+ state: &self.assembler_state,
+ stmts: vec![],
+ while_state: None,
+ if_state: None,
+ fn_state: Some(GeneratedFunction {
+ name: self.name.clone(),
+ params: self.params.clone(),
+ body: vec![],
+ ret: self.ret.clone(),
+ }),
+ }
+ }
+}
+
+/// Keep `while` condition expr as we are constructing while loop body.
+struct WhileState {
+ condition: Expr,
+}
+
+/// Keep `if-then-else` state, including the condition expr and the already built
+/// `then` statements (if we are currently building the else block).
+struct IfElseState {
+ condition: Expr,
+ then_stmts: Vec<Stmt>,
+ in_then: bool,
+}
+
+impl IfElseState {
+ /// Store all current statements as the `then` block and move on to the
+ /// `else` block.
+ fn enter_else(&mut self, then_stmts: Vec<Stmt>) {
+ self.then_stmts = then_stmts;
+ self.in_then = false;
+ }
+}
+
+/// A code block consists of statements and acts as an anonymous namespace
+/// scope for items and variable declarations.
+pub struct CodeBlock<'a> {
+ /// A stack containing all variables defined so far.
+ /// The variables defined in the current block
+ /// are in the top stack frame.
+ /// Fields provide shadowing semantics for a name
+ /// that is also defined in an outer block, and are
+ /// used to guarantee type safety while constructing the AST.
+ fields: &'a mut VecDeque<HashMap<String, JITType>>,
+ /// The state of Assembler, used for type checking function calls.
+ state: &'a Arc<Mutex<AssemblerState>>,
+ /// Holding all statements for the current code block.
+ stmts: Vec<Stmt>,
+ while_state: Option<WhileState>,
+ if_state: Option<IfElseState>,
+ /// Keep track of function params and return types; only valid for
+ /// the function's main block.
+ fn_state: Option<GeneratedFunction>,
+}
+
+impl<'a> CodeBlock<'a> {
+ pub fn build(&mut self) -> GeneratedFunction {
+ assert!(
+ self.fn_state.is_some(),
+ "Can only call build on outermost function block"
+ );
+ let mut gen = self.fn_state.take().unwrap();
+ gen.body = self.stmts.drain(..).collect::<Vec<_>>();
+ gen
+ }
+
+ /// Leave the current block and returns the statements constructed.
+ fn leave(&mut self) -> Result<Stmt> {
+ self.fields.pop_back();
+ if let Some(ref mut while_state) = self.while_state {
+ let WhileState { condition } = while_state;
+ let stmts = self.stmts.drain(..).collect::<Vec<_>>();
+ return Ok(Stmt::WhileLoop(Box::new(condition.clone()), stmts));
+ }
+
+ if let Some(ref mut if_state) = self.if_state {
+ let IfElseState {
+ condition,
+ then_stmts,
+ in_then,
+ } = if_state;
+ return if *in_then {
+ assert!(then_stmts.is_empty());
+ let stmts = self.stmts.drain(..).collect::<Vec<_>>();
+ Ok(Stmt::IfElse(Box::new(condition.clone()), stmts,
Vec::new()))
+ } else {
+ assert!(!then_stmts.is_empty());
+ let then_stmts = then_stmts.drain(..).collect::<Vec<_>>();
+ let else_stmts = self.stmts.drain(..).collect::<Vec<_>>();
+ Ok(Stmt::IfElse(
+ Box::new(condition.clone()),
+ then_stmts,
+ else_stmts,
+ ))
+ };
+ }
+ unreachable!()
+ }
+
+ /// Enter the else block. Try [if_block] first, which is much easier to use.
+ fn enter_else(&mut self) {
+ self.fields.pop_back();
+ self.fields.push_back(HashMap::new());
+ assert!(self.if_state.is_some() &&
self.if_state.as_ref().unwrap().in_then);
+ let new_then = self.stmts.drain(..).collect::<Vec<_>>();
+ if let Some(s) = self.if_state.iter_mut().next() {
+ s.enter_else(new_then)
+ }
+ }
+
+ /// Declare variable `name` of a type.
+ pub fn declare(&mut self, name: impl Into<String>, ty: JITType) ->
Result<()> {
+ let name = name.into();
+ let typ = self.fields.back().unwrap().get(&name);
+ match typ {
+ Some(typ) => internal_err!(
+ "Variable {} of {} already exists in the current scope",
+ name,
+ typ
+ ),
+ None => {
+ self.fields.back_mut().unwrap().insert(name.clone(), ty);
+ self.stmts.push(Stmt::Declare(name, ty));
+ Ok(())
+ }
+ }
+ }
+
+ fn find_type(&self, name: impl Into<String>) -> Option<JITType> {
+ let name = name.into();
+ for scope in self.fields.iter().rev() {
+ let typ = scope.get(&name);
+ if let Some(typ) = typ {
+ return Some(*typ);
+ }
+ }
+ None
+ }
+
+ /// Assignment statement. Assign an expression value to a variable.
+ pub fn assign(&mut self, name: impl Into<String>, expr: Expr) ->
Result<()> {
+ let name = name.into();
+ let typ = self.find_type(&name);
+ match typ {
+ Some(typ) => {
+ if typ != expr.get_type() {
+ internal_err!(
+ "Variable {} of {} cannot be assigned to {}",
+ name,
+ typ,
+ expr.get_type()
+ )
+ } else {
+ self.stmts.push(Stmt::Assign(name, Box::new(expr)));
+ Ok(())
+ }
+ }
+ None => internal_err!("unknown identifier: {}", name),
+ }
+ }
+
+ /// Declare variable with initialization.
+ pub fn declare_as(&mut self, name: impl Into<String>, expr: Expr) ->
Result<()> {
+ let name = name.into();
+ let typ = self.fields.back().unwrap().get(&name);
+ match typ {
+ Some(typ) => {
+ internal_err!(
+ "Variable {} of {} already exists in the current scope",
+ name,
+ typ
+ )
+ }
+ None => {
+ self.fields
+ .back_mut()
+ .unwrap()
+ .insert(name.clone(), expr.get_type());
+ self.stmts
+ .push(Stmt::Declare(name.clone(), expr.get_type()));
+ self.stmts.push(Stmt::Assign(name, Box::new(expr)));
+ Ok(())
+ }
+ }
+ }
+
+ /// Call external function for side effect only.
+ pub fn call_stmt(&mut self, name: impl Into<String>, args: Vec<Expr>) ->
Result<()> {
+ self.stmts.push(Stmt::Call(name.into(), args));
+ Ok(())
+ }
+
+ /// Enter `while` loop block. Try [while_block] first, which is much
+ /// easier to use.
+ fn while_loop(&mut self, cond: Expr) -> Result<CodeBlock> {
+ if cond.get_type() != BOOL {
+ internal_err!("while condition must be bool")
+ } else {
+ self.fields.push_back(HashMap::new());
+ Ok(CodeBlock {
+ fields: self.fields,
+ state: self.state,
+ stmts: vec![],
+ while_state: Some(WhileState { condition: cond }),
+ if_state: None,
+ fn_state: None,
+ })
+ }
+ }
+
+ /// Enter the `then` block of an `if-then-else`. Try [if_block] first,
+ /// which is much easier to use.
+ fn if_else(&mut self, cond: Expr) -> Result<CodeBlock> {
+ if cond.get_type() != BOOL {
+ internal_err!("if condition must be bool")
+ } else {
+ self.fields.push_back(HashMap::new());
+ Ok(CodeBlock {
+ fields: self.fields,
+ state: self.state,
+ stmts: vec![],
+ while_state: None,
+ if_state: Some(IfElseState {
+ condition: cond,
+ then_stmts: vec![],
+ in_then: true,
+ }),
+ fn_state: None,
+ })
+ }
+ }
+
+ /// Construct an `if-then-else` block with each part provided.
+ ///
+ /// E.g. if n == 0 { r = 0 } else { r = 1 } could be written as:
+ /// x.if_block(
+ /// |cond| cond.eq(cond.id("n")?, cond.lit_i(0)),
+ /// |t| {
+ /// t.assign("r", t.lit_i(0))?;
+ /// Ok(())
+ /// },
+ /// |e| {
+ /// e.assign("r", e.lit_i(1))?;
+ /// Ok(())
+ /// },
+ /// )?;
+ pub fn if_block<C, T, E>(
+ &mut self,
+ mut cond: C,
+ mut then_blk: T,
+ mut else_blk: E,
+ ) -> Result<()>
+ where
+ C: FnMut(&mut CodeBlock) -> Result<Expr>,
+ T: FnMut(&mut CodeBlock) -> Result<()>,
+ E: FnMut(&mut CodeBlock) -> Result<()>,
+ {
+ let cond = cond(self)?;
+ let mut body = self.if_else(cond)?;
+ then_blk(&mut body)?;
+ body.enter_else();
+ else_blk(&mut body)?;
+ let if_else = body.leave()?;
+ self.stmts.push(if_else);
+ Ok(())
+ }
+
+ /// Construct a `while` block with each part provided.
+ ///
+ /// E.g. while n != 0 { n = n - 1; } could be written as:
+ /// x.while_block(
+ /// |cond| cond.ne(cond.id("n")?, cond.lit_i(0)),
+ /// |w| {
+ /// w.assign("n", w.sub(w.id("n")?, w.lit_i(1))?)?;
+ /// Ok(())
+ /// },
+ /// )?;
+ pub fn while_block<C, B>(&mut self, mut cond: C, mut body_blk: B) ->
Result<()>
+ where
+ C: FnMut(&mut CodeBlock) -> Result<Expr>,
+ B: FnMut(&mut CodeBlock) -> Result<()>,
+ {
+ let cond = cond(self)?;
+ let mut body = self.while_loop(cond)?;
+ body_blk(&mut body)?;
+ let while_stmt = body.leave()?;
+ self.stmts.push(while_stmt);
+ Ok(())
+ }
+
+ /// Create a literal `val` of `ty` type.
+ pub fn lit(&self, val: impl Into<String>, ty: JITType) -> Expr {
+ Expr::Literal(Literal::Parsing(val.into(), ty))
+ }
+
+ /// Shorthand to create i64 literal
+ pub fn lit_i(&self, val: impl Into<i64>) -> Expr {
+ Expr::Literal(Literal::Typed(TypedLit::Int(val.into())))
+ }
+
+ /// Shorthand to create f32 literal
+ pub fn lit_f(&self, val: f32) -> Expr {
+ Expr::Literal(Literal::Typed(TypedLit::Float(val)))
+ }
+
+ /// Shorthand to create f64 literal
+ pub fn lit_d(&self, val: f64) -> Expr {
+ Expr::Literal(Literal::Typed(TypedLit::Double(val)))
+ }
+
+ /// Shorthand to create boolean literal
+ pub fn lit_b(&self, val: bool) -> Expr {
+ Expr::Literal(Literal::Typed(TypedLit::Bool(val)))
+ }
+
+ /// Create a reference to an already defined variable.
+ pub fn id(&self, name: impl Into<String>) -> Result<Expr> {
+ let name = name.into();
+ match self.find_type(&name) {
+ None => internal_err!("unknown identifier: {}", name),
+ Some(typ) => Ok(Expr::Identifier(name, typ)),
+ }
+ }
+
+ /// Binary comparison expression: lhs == rhs
+ pub fn eq(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Eq(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary comparison expression: lhs != rhs
+ pub fn ne(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Ne(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary comparison expression: lhs < rhs
+ pub fn lt(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Lt(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary comparison expression: lhs <= rhs
+ pub fn le(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Le(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary comparison expression: lhs > rhs
+ pub fn gt(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Gt(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary comparison expression: lhs >= rhs
+ pub fn ge(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot compare {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Ge(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary arithmetic expression: lhs + rhs
+ pub fn add(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot add {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Add(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary arithmetic expression: lhs - rhs
+ pub fn sub(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot subtract {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Sub(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary arithmetic expression: lhs * rhs
+ pub fn mul(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot multiply {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Mul(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Binary arithmetic expression: lhs / rhs
+ pub fn div(&self, lhs: Expr, rhs: Expr) -> Result<Expr> {
+ if lhs.get_type() != rhs.get_type() {
+ internal_err!("cannot divide {} and {}", lhs.get_type(),
rhs.get_type())
+ } else {
+ Ok(Expr::Binary(BinaryExpr::Div(Box::new(lhs), Box::new(rhs))))
+ }
+ }
+
+ /// Call external function `name` with parameters
+ pub fn call(&self, name: impl Into<String>, params: Vec<Expr>) ->
Result<Expr> {
+ let fn_name = name.into();
+ if let Some(func) = self.state.lock().extern_funcs.get(&fn_name) {
+ for ((i, t1), t2) in
params.iter().enumerate().zip(func.params.iter()) {
+ if t1.get_type() != *t2 {
+ return internal_err!(
+ "Func {} need {} as arg{}, get {}",
+ &fn_name,
+ t2,
+ i,
+ t1.get_type()
+ );
+ }
+ }
+ Ok(Expr::Call(fn_name, params, func.returns.unwrap_or(NIL)))
+ } else {
+ internal_err!("No func with the name {} exist", fn_name)
+ }
+ }
+}
+
+impl Display for GeneratedFunction {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "fn {}(", self.name)?;
+ for (i, (name, ty)) in self.params.iter().enumerate() {
+ if i != 0 {
+ write!(f, ", ")?;
+ }
+ write!(f, "{}: {}", name, ty)?;
+ }
+ write!(f, ") -> ")?;
+ if let Some((name, ty)) = &self.ret {
+ write!(f, "{}: {}", name, ty)?;
+ } else {
+ write!(f, "()")?;
+ }
+ writeln!(f, " {{")?;
+ for stmt in &self.body {
+ stmt.fmt_ident(4, f)?;
+ }
+ write!(f, "}}")
+ }
+}
diff --git a/datafusion-jit/src/ast.rs b/datafusion-jit/src/ast.rs
new file mode 100644
index 0000000..5d0e3bc
--- /dev/null
+++ b/datafusion-jit/src/ast.rs
@@ -0,0 +1,359 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use cranelift::codegen::ir;
+use std::fmt::{Display, Formatter};
+
+#[derive(Clone, Debug)]
+/// Statement
+pub enum Stmt {
+ /// if-then-else
+ IfElse(Box<Expr>, Vec<Stmt>, Vec<Stmt>),
+ /// while
+ WhileLoop(Box<Expr>, Vec<Stmt>),
+ /// assignment
+ Assign(String, Box<Expr>),
+ /// call function for side effect
+ Call(String, Vec<Expr>),
+ /// declare a new variable of type
+ Declare(String, JITType),
+}
+
+#[derive(Copy, Clone, Debug)]
+/// Shorthand typed literals
+pub enum TypedLit {
+ Bool(bool),
+ Int(i64),
+ Float(f32),
+ Double(f64),
+}
+
+#[derive(Clone, Debug)]
+/// Expression
+pub enum Expr {
+ /// literal
+ Literal(Literal),
+ /// variable
+ Identifier(String, JITType),
+ /// binary expression
+ Binary(BinaryExpr),
+ /// call function expression
+ Call(String, Vec<Expr>, JITType),
+}
+
+impl Expr {
+ pub fn get_type(&self) -> JITType {
+ match self {
+ Expr::Literal(lit) => lit.get_type(),
+ Expr::Identifier(_, ty) => *ty,
+ Expr::Binary(bin) => bin.get_type(),
+ Expr::Call(_, _, ty) => *ty,
+ }
+ }
+}
+
+impl Literal {
+ fn get_type(&self) -> JITType {
+ match self {
+ Literal::Parsing(_, ty) => *ty,
+ Literal::Typed(tl) => tl.get_type(),
+ }
+ }
+}
+
+impl TypedLit {
+ fn get_type(&self) -> JITType {
+ match self {
+ TypedLit::Bool(_) => BOOL,
+ TypedLit::Int(_) => I64,
+ TypedLit::Float(_) => F32,
+ TypedLit::Double(_) => F64,
+ }
+ }
+}
+
+impl BinaryExpr {
+ fn get_type(&self) -> JITType {
+ match self {
+ BinaryExpr::Eq(_, _) => BOOL,
+ BinaryExpr::Ne(_, _) => BOOL,
+ BinaryExpr::Lt(_, _) => BOOL,
+ BinaryExpr::Le(_, _) => BOOL,
+ BinaryExpr::Gt(_, _) => BOOL,
+ BinaryExpr::Ge(_, _) => BOOL,
+ BinaryExpr::Add(lhs, _) => lhs.get_type(),
+ BinaryExpr::Sub(lhs, _) => lhs.get_type(),
+ BinaryExpr::Mul(lhs, _) => lhs.get_type(),
+ BinaryExpr::Div(lhs, _) => lhs.get_type(),
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+/// Binary expression
+pub enum BinaryExpr {
+ /// ==
+ Eq(Box<Expr>, Box<Expr>),
+ /// !=
+ Ne(Box<Expr>, Box<Expr>),
+ /// <
+ Lt(Box<Expr>, Box<Expr>),
+ /// <=
+ Le(Box<Expr>, Box<Expr>),
+ /// >
+ Gt(Box<Expr>, Box<Expr>),
+ /// >=
+ Ge(Box<Expr>, Box<Expr>),
+ /// add
+ Add(Box<Expr>, Box<Expr>),
+ /// subtract
+ Sub(Box<Expr>, Box<Expr>),
+ /// multiply
+ Mul(Box<Expr>, Box<Expr>),
+ /// divide
+ Div(Box<Expr>, Box<Expr>),
+}
+
+#[derive(Clone, Debug)]
+/// Literal
+pub enum Literal {
+ /// Parsable literal with type
+ Parsing(String, JITType),
+ /// Shorthand literals of common types
+ Typed(TypedLit),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Hash)]
+/// Type to be used in JIT
+pub struct JITType {
+ /// The cranelift type
+ pub(crate) native: ir::Type,
+ /// re-expose inner field of `ir::Type` out for easier pattern matching
+ pub(crate) code: u8,
+}
+
+/// null type as placeholder
+pub const NIL: JITType = JITType {
+ native: ir::types::INVALID,
+ code: 0,
+};
+/// bool
+pub const BOOL: JITType = JITType {
+ native: ir::types::B1,
+ code: 0x70,
+};
+/// integer of 1 byte
+pub const I8: JITType = JITType {
+ native: ir::types::I8,
+ code: 0x76,
+};
+/// integer of 2 bytes
+pub const I16: JITType = JITType {
+ native: ir::types::I16,
+ code: 0x77,
+};
+/// integer of 4 bytes
+pub const I32: JITType = JITType {
+ native: ir::types::I32,
+ code: 0x78,
+};
+/// integer of 8 bytes
+pub const I64: JITType = JITType {
+ native: ir::types::I64,
+ code: 0x79,
+};
+/// IEEE float of 32 bits
+pub const F32: JITType = JITType {
+ native: ir::types::F32,
+ code: 0x7b,
+};
+/// IEEE float of 64 bits
+pub const F64: JITType = JITType {
+ native: ir::types::F64,
+ code: 0x7c,
+};
+/// Pointer type of 32 bits
+pub const R32: JITType = JITType {
+ native: ir::types::R32,
+ code: 0x7e,
+};
+/// Pointer type of 64 bits
+pub const R64: JITType = JITType {
+ native: ir::types::R64,
+ code: 0x7f,
+};
+/// The pointer type to use, based on the current target.
+pub const PTR: JITType = if std::mem::size_of::<usize>() == 8 {
+ R64
+} else {
+ R32
+};
+
+impl Stmt {
+ /// print the statement with indentation
+ pub fn fmt_ident(&self, ident: usize, f: &mut Formatter) ->
std::fmt::Result {
+ let mut ident_str = String::new();
+ for _ in 0..ident {
+ ident_str.push(' ');
+ }
+ match self {
+ Stmt::IfElse(cond, then_stmts, else_stmts) => {
+ writeln!(f, "{}if {} {{", ident_str, cond)?;
+ for stmt in then_stmts {
+ stmt.fmt_ident(ident + 4, f)?;
+ }
+ writeln!(f, "{}}} else {{", ident_str)?;
+ for stmt in else_stmts {
+ stmt.fmt_ident(ident + 4, f)?;
+ }
+ writeln!(f, "{}}}", ident_str)
+ }
+ Stmt::WhileLoop(cond, stmts) => {
+ writeln!(f, "{}while {} {{", ident_str, cond)?;
+ for stmt in stmts {
+ stmt.fmt_ident(ident + 4, f)?;
+ }
+ writeln!(f, "{}}}", ident_str)
+ }
+ Stmt::Assign(name, expr) => {
+ writeln!(f, "{}{} = {};", ident_str, name, expr)
+ }
+ Stmt::Call(name, args) => {
+ writeln!(
+ f,
+ "{}{}({});",
+ ident_str,
+ name,
+ args.iter()
+ .map(|e| e.to_string())
+ .collect::<Vec<_>>()
+ .join(", ")
+ )
+ }
+ Stmt::Declare(name, ty) => {
+ writeln!(f, "{}let {}: {};", ident_str, name, ty)
+ }
+ }
+ }
+}
+
+impl Display for Stmt {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ self.fmt_ident(0, f)?;
+ Ok(())
+ }
+}
+
+impl Display for Expr {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Expr::Literal(l) => write!(f, "{}", l),
+ Expr::Identifier(name, _) => write!(f, "{}", name),
+ Expr::Binary(be) => write!(f, "{}", be),
+ Expr::Call(name, exprs, _) => {
+ write!(
+ f,
+ "{}({})",
+ name,
+ exprs
+ .iter()
+ .map(|e| e.to_string())
+ .collect::<Vec<_>>()
+ .join(", ")
+ )
+ }
+ }
+ }
+}
+
+impl Display for Literal {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Literal::Parsing(str, _) => write!(f, "{}", str),
+ Literal::Typed(tl) => write!(f, "{}", tl),
+ }
+ }
+}
+
+impl Display for TypedLit {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TypedLit::Bool(b) => write!(f, "{}", b),
+ TypedLit::Int(i) => write!(f, "{}", i),
+ TypedLit::Float(fl) => write!(f, "{}", fl),
+ TypedLit::Double(d) => write!(f, "{}", d),
+ }
+ }
+}
+
+impl Display for BinaryExpr {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ BinaryExpr::Eq(lhs, rhs) => write!(f, "{} == {}", lhs, rhs),
+ BinaryExpr::Ne(lhs, rhs) => write!(f, "{} != {}", lhs, rhs),
+ BinaryExpr::Lt(lhs, rhs) => write!(f, "{} < {}", lhs, rhs),
+ BinaryExpr::Le(lhs, rhs) => write!(f, "{} <= {}", lhs, rhs),
+ BinaryExpr::Gt(lhs, rhs) => write!(f, "{} > {}", lhs, rhs),
+ BinaryExpr::Ge(lhs, rhs) => write!(f, "{} >= {}", lhs, rhs),
+ BinaryExpr::Add(lhs, rhs) => write!(f, "{} + {}", lhs, rhs),
+ BinaryExpr::Sub(lhs, rhs) => write!(f, "{} - {}", lhs, rhs),
+ BinaryExpr::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs),
+ BinaryExpr::Div(lhs, rhs) => write!(f, "{} / {}", lhs, rhs),
+ }
+ }
+}
+
+impl std::fmt::Display for JITType {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl std::fmt::Debug for JITType {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.code {
+ 0 => write!(f, "nil"),
+ 0x70 => write!(f, "bool"),
+ 0x76 => write!(f, "i8"),
+ 0x77 => write!(f, "i16"),
+ 0x78 => write!(f, "i32"),
+ 0x79 => write!(f, "i64"),
+ 0x7b => write!(f, "f32"),
+ 0x7c => write!(f, "f64"),
+ 0x7e => write!(f, "small_ptr"),
+ 0x7f => write!(f, "ptr"),
+ _ => write!(f, "unknown"),
+ }
+ }
+}
+
+impl From<&str> for JITType {
+ fn from(x: &str) -> Self {
+ match x {
+ "bool" => BOOL,
+ "i8" => I8,
+ "i16" => I16,
+ "i32" => I32,
+ "i64" => I64,
+ "f32" => F32,
+ "f64" => F64,
+ "small_ptr" => R32,
+ "ptr" => R64,
+ _ => panic!("unknown type: {}", x),
+ }
+ }
+}
diff --git a/datafusion-jit/src/jit.rs b/datafusion-jit/src/jit.rs
new file mode 100644
index 0000000..225366b
--- /dev/null
+++ b/datafusion-jit/src/jit.rs
@@ -0,0 +1,676 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::api::GeneratedFunction;
+use crate::ast::{BinaryExpr, Expr, JITType, Literal, Stmt, TypedLit, BOOL,
I64, NIL};
+use cranelift::prelude::*;
+use cranelift_jit::{JITBuilder, JITModule};
+use cranelift_module::{Linkage, Module};
+use datafusion_common::internal_err;
+use datafusion_common::{DataFusionError, Result};
+use std::collections::HashMap;
+
+/// The basic JIT class.
+///
+/// Bundles the Cranelift state needed to turn a `GeneratedFunction` into
+/// machine code (see `compile`).
+#[allow(clippy::upper_case_acronyms)]
+pub struct JIT {
+ /// The function builder context, which is reused across multiple
+ /// FunctionBuilder instances.
+ builder_context: FunctionBuilderContext,
+
+ /// The main Cranelift context, which holds the state for codegen. Cranelift
+ /// separates this from `Module` to allow for parallel compilation, with a
+ /// context per thread, though this is not the case now.
+ ctx: codegen::Context,
+
+ /// The module, with the jit backend, which manages the JIT'd
+ /// functions.
+ module: JITModule,
+}
+
+impl Default for JIT {
+    /// Builds a JIT on top of a `JITBuilder` created with the default
+    /// libcall names; no external symbols are registered.
+    fn default() -> Self {
+        let module = JITModule::new(JITBuilder::new(
+            cranelift_module::default_libcall_names(),
+        ));
+        let ctx = module.make_context();
+        Self {
+            builder_context: FunctionBuilderContext::new(),
+            ctx,
+            module,
+        }
+    }
+}
+
+impl JIT {
+ /// Creates a JIT for the host machine and registers `symbols`
+ /// (name / address pairs) with the underlying `JITBuilder`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if a codegen flag is rejected or if the host machine is not
+ /// supported.
+ pub fn new<It, K>(symbols: It) -> Self
+ where
+     It: IntoIterator<Item = (K, *const u8)>,
+     K: Into<String>,
+ {
+     // Codegen flags, applied in this order.
+     let mut flag_builder = settings::builder();
+     for (flag, value) in [
+         ("use_colocated_libcalls", "false"),
+         ("is_pic", "true"),
+         ("opt_level", "speed"),
+         ("enable_simd", "true"),
+     ] {
+         flag_builder.set(flag, value).unwrap();
+     }
+
+     // Target the machine we are running on.
+     let isa = cranelift_native::builder()
+         .unwrap_or_else(|msg| panic!("host machine is not supported: {}", msg))
+         .finish(settings::Flags::new(flag_builder));
+
+     let mut jit_builder =
+         JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
+     jit_builder.symbols(symbols);
+
+     let module = JITModule::new(jit_builder);
+     let ctx = module.make_context();
+     Self {
+         builder_context: FunctionBuilderContext::new(),
+         ctx,
+         module,
+     }
+ }
+
+ /// Compiles `func` into machine code and returns a pointer to it.
+ pub fn compile(&mut self, func: GeneratedFunction) -> Result<*const u8> {
+     let GeneratedFunction {
+         name,
+         params,
+         body,
+         ret,
+     } = func;
+
+     // Lower the AST into Cranelift IR held by `self.ctx`.
+     self.translate(params, ret, body)?;
+
+     // A function has to be declared before it can be defined or called.
+     let signature = &self.ctx.func.signature;
+     let func_id = self
+         .module
+         .declare_function(&name, Linkage::Export, signature)?;
+
+     // Defining the function finishes its compilation; relocations may
+     // still be outstanding until the definitions are finalized below.
+     self.module.define_function(func_id, &mut self.ctx)?;
+
+     // Compilation is done, so the context state can be cleared.
+     self.module.clear_context(&mut self.ctx);
+
+     // Resolve outstanding relocations now that addresses are available.
+     self.module.finalize_definitions();
+
+     // The machine code can now be looked up.
+     Ok(self.module.get_finalized_function(func_id))
+ }
+
+ /// Translates a function into Cranelift IR stored in `self.ctx.func`.
+ ///
+ /// `params` and `the_return` are `(name, type)` pairs describing the
+ /// signature and `stmts` is the function body. When `the_return` is
+ /// `Some`, the final value of that variable is returned; otherwise the
+ /// function returns nothing.
+ fn translate(
+     &mut self,
+     params: Vec<(String, JITType)>,
+     the_return: Option<(String, JITType)>,
+     stmts: Vec<Stmt>,
+ ) -> Result<()> {
+     let signature = &mut self.ctx.func.signature;
+     for (_, ty) in &params {
+         signature.params.push(AbiParam::new(ty.native));
+     }
+
+     // We currently only support one return value, though Cranelift is
+     // designed to support more.
+     if let Some((_, ty)) = &the_return {
+         signature.returns.push(AbiParam::new(ty.native));
+     }
+
+     // Create the builder to build a function.
+     let mut builder =
+         FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
+
+     // Create the entry block and give it block parameters matching the
+     // function's parameters.
+     let entry_block = builder.create_block();
+     builder.append_block_params_for_function_params(entry_block);
+
+     // Emit code into the entry block. It is the entry block, so it will
+     // have no predecessors and can be sealed right away.
+     builder.switch_to_block(entry_block);
+     builder.seal_block(entry_block);
+
+     // Walk the AST and declare all variables.
+     let variables =
+         declare_variables(&mut builder, &params, &the_return, &stmts, entry_block);
+
+     // Now translate the statements of the function body.
+     let mut trans = FunctionTranslator {
+         builder,
+         variables,
+         module: &mut self.module,
+     };
+     for stmt in stmts {
+         trans.translate_stmt(stmt)?;
+     }
+
+     match &the_return {
+         Some((name, _)) => {
+             // The return variable was declared by `declare_variables`
+             // above; returning is just a use of that variable.
+             let return_variable = *trans
+                 .variables
+                 .get(name)
+                 .expect("return variable is declared by declare_variables");
+             let return_value = trans.builder.use_var(return_variable);
+             trans.builder.ins().return_(&[return_value]);
+         }
+         None => {
+             trans.builder.ins().return_(&[]);
+         }
+     }
+
+     // Tell the builder we're done with this function.
+     trans.builder.finalize();
+     Ok(())
+ }
+}
+
+/// A collection of state used for translating from AST nodes
+/// into Cranelift IR.
+struct FunctionTranslator<'a> {
+ /// Builder used to create blocks and emit instructions.
+ builder: FunctionBuilder<'a>,
+ /// Maps variable names to the Cranelift variables declared for them.
+ variables: HashMap<String, Variable>,
+ /// Module in which called functions are declared.
+ module: &'a mut JITModule,
+}
+
+impl<'a> FunctionTranslator<'a> {
+ /// Emits IR for a single statement, dispatching on its kind.
+ fn translate_stmt(&mut self, stmt: Stmt) -> Result<()> {
+     match stmt {
+         // Declarations emit no code here.
+         Stmt::Declare(_, _) => Ok(()),
+         Stmt::Assign(name, expr) => self.translate_assign(name, *expr),
+         // A call statement is declared without a return value (`NIL`).
+         Stmt::Call(name, args) => self.translate_call_stmt(name, args, NIL),
+         Stmt::IfElse(condition, then_body, else_body) => {
+             self.translate_if_else(*condition, then_body, else_body)
+         }
+         Stmt::WhileLoop(condition, loop_body) => {
+             self.translate_while_loop(*condition, loop_body)
+         }
+     }
+ }
+
+ /// Emits the constant instruction for an already typed literal.
+ fn translate_typed_lit(&mut self, tl: TypedLit) -> Value {
+     let ins = self.builder.ins();
+     match tl {
+         TypedLit::Bool(value) => ins.bconst(BOOL.native, value),
+         TypedLit::Int(value) => ins.iconst(I64.native, value),
+         TypedLit::Float(value) => ins.f32const(value),
+         TypedLit::Double(value) => ins.f64const(value),
+     }
+ }
+
+ /// Emits IR for an expression and returns the `Value` holding its
+ /// result, which can then be used as an operand of other instructions.
+ fn translate_expr(&mut self, expr: Expr) -> Result<Value> {
+     match expr {
+         Expr::Literal(literal) => self.translate_literal(literal),
+         Expr::Binary(binary) => self.translate_binary_expr(binary),
+         Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret),
+         Expr::Identifier(name, _) => match self.variables.get(&name) {
+             // `use_var` reads the value of a variable.
+             Some(variable) => Ok(self.builder.use_var(*variable)),
+             None => Err(DataFusionError::Internal(
+                 "variable not defined".to_owned(),
+             )),
+         },
+     }
+ }
+
+ /// Emits IR for a literal; a literal that is still textual is parsed
+ /// according to its declared type first.
+ fn translate_literal(&mut self, expr: Literal) -> Result<Value> {
+     match expr {
+         Literal::Typed(typed) => Ok(self.translate_typed_lit(typed)),
+         Literal::Parsing(text, ty) => self.translate_string_lit(text, ty),
+     }
+ }
+
+ /// Emits IR for a binary expression and returns the resulting `Value`.
+ ///
+ /// The instruction is selected from the type of the left operand: type
+ /// codes `0x76..=0x79` (`i8`..`i64`) use integer instructions, `0x7b`
+ /// and `0x7c` (`f32`, `f64`) use floating point instructions, and any
+ /// other type yields an internal error.
+ ///
+ /// Comparisons report an unsupported type before translating their
+ /// operands; arithmetic translates both operands first.
+ fn translate_binary_expr(&mut self, expr: BinaryExpr) -> Result<Value> {
+     match expr {
+         BinaryExpr::Eq(lhs, rhs) => {
+             self.translate_comparison(IntCC::Equal, FloatCC::Equal, "equal", *lhs, *rhs)
+         }
+         BinaryExpr::Ne(lhs, rhs) => self.translate_comparison(
+             IntCC::NotEqual,
+             FloatCC::NotEqual,
+             "not equal",
+             *lhs,
+             *rhs,
+         ),
+         BinaryExpr::Lt(lhs, rhs) => self.translate_comparison(
+             IntCC::SignedLessThan,
+             FloatCC::LessThan,
+             "less than",
+             *lhs,
+             *rhs,
+         ),
+         BinaryExpr::Le(lhs, rhs) => self.translate_comparison(
+             IntCC::SignedLessThanOrEqual,
+             FloatCC::LessThanOrEqual,
+             "less than or equal",
+             *lhs,
+             *rhs,
+         ),
+         BinaryExpr::Gt(lhs, rhs) => self.translate_comparison(
+             IntCC::SignedGreaterThan,
+             FloatCC::GreaterThan,
+             "greater than",
+             *lhs,
+             *rhs,
+         ),
+         BinaryExpr::Ge(lhs, rhs) => self.translate_comparison(
+             IntCC::SignedGreaterThanOrEqual,
+             FloatCC::GreaterThanOrEqual,
+             "greater than or equal",
+             *lhs,
+             *rhs,
+         ),
+         BinaryExpr::Add(lhs, rhs) => {
+             let ty = lhs.get_type();
+             let lhs = self.translate_expr(*lhs)?;
+             let rhs = self.translate_expr(*rhs)?;
+             match ty.code {
+                 0x76..=0x79 => Ok(self.builder.ins().iadd(lhs, rhs)),
+                 0x7b | 0x7c => Ok(self.builder.ins().fadd(lhs, rhs)),
+                 _ => internal_err!("Unsupported type {} for add", ty),
+             }
+         }
+         BinaryExpr::Sub(lhs, rhs) => {
+             let ty = lhs.get_type();
+             let lhs = self.translate_expr(*lhs)?;
+             let rhs = self.translate_expr(*rhs)?;
+             match ty.code {
+                 0x76..=0x79 => Ok(self.builder.ins().isub(lhs, rhs)),
+                 0x7b | 0x7c => Ok(self.builder.ins().fsub(lhs, rhs)),
+                 _ => internal_err!("Unsupported type {} for sub", ty),
+             }
+         }
+         BinaryExpr::Mul(lhs, rhs) => {
+             let ty = lhs.get_type();
+             let lhs = self.translate_expr(*lhs)?;
+             let rhs = self.translate_expr(*rhs)?;
+             match ty.code {
+                 0x76..=0x79 => Ok(self.builder.ins().imul(lhs, rhs)),
+                 0x7b | 0x7c => Ok(self.builder.ins().fmul(lhs, rhs)),
+                 _ => internal_err!("Unsupported type {} for mul", ty),
+             }
+         }
+         BinaryExpr::Div(lhs, rhs) => {
+             let ty = lhs.get_type();
+             let lhs = self.translate_expr(*lhs)?;
+             let rhs = self.translate_expr(*rhs)?;
+             match ty.code {
+                 // Integers are treated as signed everywhere else (signed
+                 // condition codes, `i8`..`i64` literals), so the division
+                 // must be signed too; `udiv` gives wrong results for
+                 // negative operands.
+                 0x76..=0x79 => Ok(self.builder.ins().sdiv(lhs, rhs)),
+                 0x7b | 0x7c => Ok(self.builder.ins().fdiv(lhs, rhs)),
+                 _ => internal_err!("Unsupported type {} for div", ty),
+             }
+         }
+     }
+ }
+
+ /// Dispatches a comparison to `translate_icmp` or `translate_fcmp`
+ /// depending on the type of the left operand. `what` names the
+ /// comparison in the error returned for unsupported types.
+ fn translate_comparison(
+     &mut self,
+     int_cc: IntCC,
+     float_cc: FloatCC,
+     what: &str,
+     lhs: Expr,
+     rhs: Expr,
+ ) -> Result<Value> {
+     let ty = lhs.get_type();
+     match ty.code {
+         0x76..=0x79 => self.translate_icmp(int_cc, lhs, rhs),
+         0x7b | 0x7c => self.translate_fcmp(float_cc, lhs, rhs),
+         _ => Err(DataFusionError::Internal(format!(
+             "Unsupported type {} for {} comparison",
+             ty, what
+         ))),
+     }
+ }
+
+ /// Parses the textual literal `lit` as a value of type `ty` and emits
+ /// the matching constant instruction.
+ ///
+ /// Returns an internal error when `lit` cannot be parsed as `ty`, or
+ /// when `ty` is not a bool, integer or float type.
+ fn translate_string_lit(&mut self, lit: String, ty: JITType) -> Result<Value> {
+     /// Parses `lit`, reporting a malformed literal as an error instead
+     /// of panicking.
+     fn parse<T: std::str::FromStr>(lit: &str, ty: JITType) -> Result<T> {
+         lit.parse::<T>().map_err(|_| {
+             DataFusionError::Internal(format!(
+                 "Cannot parse literal '{}' as type {}",
+                 lit, ty
+             ))
+         })
+     }
+
+     match ty.code {
+         0x70 => {
+             let b = parse::<bool>(&lit, ty)?;
+             Ok(self.builder.ins().bconst(ty.native, b))
+         }
+         0x76 => {
+             let i = parse::<i8>(&lit, ty)?;
+             Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+         }
+         0x77 => {
+             let i = parse::<i16>(&lit, ty)?;
+             Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+         }
+         0x78 => {
+             let i = parse::<i32>(&lit, ty)?;
+             Ok(self.builder.ins().iconst(ty.native, i64::from(i)))
+         }
+         0x79 => {
+             let i = parse::<i64>(&lit, ty)?;
+             Ok(self.builder.ins().iconst(ty.native, i))
+         }
+         0x7b => {
+             let f = parse::<f32>(&lit, ty)?;
+             Ok(self.builder.ins().f32const(f))
+         }
+         0x7c => {
+             let f = parse::<f64>(&lit, ty)?;
+             Ok(self.builder.ins().f64const(f))
+         }
+         _ => internal_err!("Unsupported type {} for string literal", ty),
+     }
+ }
+
+ /// Emits IR that evaluates `expr` and stores the result in the
+ /// variable `name`.
+ ///
+ /// Returns an internal error when `name` has not been declared.
+ fn translate_assign(&mut self, name: String, expr: Expr) -> Result<()> {
+     // `def_var` is used to write the value of a variable. Note that
+     // variables can have multiple definitions. Cranelift will
+     // convert them into SSA form for itself automatically.
+     let new_value = self.translate_expr(expr)?;
+     // Report an unknown variable as an error, as `translate_expr` does
+     // for identifiers, instead of panicking.
+     let variable = *self.variables.get(&name).ok_or_else(|| {
+         DataFusionError::Internal(format!("variable {} not defined", name))
+     })?;
+     self.builder.def_var(variable, new_value);
+     Ok(())
+ }
+
+ /// Emits an integer comparison of `lhs` and `rhs` using condition code
+ /// `cmp`, then converts the boolean result to an `I64` value via `bint`.
+ fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result<Value> {
+     let left = self.translate_expr(lhs)?;
+     let right = self.translate_expr(rhs)?;
+     let flag = self.builder.ins().icmp(cmp, left, right);
+     let widened = self.builder.ins().bint(I64.native, flag);
+     Ok(widened)
+ }
+
+ /// Emits a floating point comparison of `lhs` and `rhs` using condition
+ /// code `cmp`, then converts the boolean result to an `I64` value via
+ /// `bint`.
+ fn translate_fcmp(&mut self, cmp: FloatCC, lhs: Expr, rhs: Expr) -> Result<Value> {
+     let left = self.translate_expr(lhs)?;
+     let right = self.translate_expr(rhs)?;
+     let flag = self.builder.ins().fcmp(cmp, left, right);
+     let widened = self.builder.ins().bint(I64.native, flag);
+     Ok(widened)
+ }
+
+ /// Emits IR for an `if`/`else` statement.
+ ///
+ /// The current block branches to the else block when the condition is
+ /// zero and otherwise jumps to the then block. Both arms end with a jump
+ /// to a merge block, where translation of later statements continues.
+ /// No values are passed between the blocks.
+ fn translate_if_else(
+     &mut self,
+     condition: Expr,
+     then_body: Vec<Stmt>,
+     else_body: Vec<Stmt>,
+ ) -> Result<()> {
+     let condition_value = self.translate_expr(condition)?;
+
+     let then_block = self.builder.create_block();
+     let else_block = self.builder.create_block();
+     let merge_block = self.builder.create_block();
+
+     // Test the condition: zero goes to the else block, anything else
+     // continues with the then block.
+     self.builder.ins().brz(condition_value, else_block, &[]);
+     self.builder.ins().jump(then_block, &[]);
+
+     // Both arms are emitted the same way: enter the block, seal it,
+     // translate its statements and jump to the merge block.
+     for (block, body) in [(then_block, then_body), (else_block, else_body)] {
+         self.builder.switch_to_block(block);
+         self.builder.seal_block(block);
+         for stmt in body {
+             self.translate_stmt(stmt)?;
+         }
+         self.builder.ins().jump(merge_block, &[]);
+     }
+
+     // Continue in the merge block; all of its predecessors have now been
+     // seen, so it can be sealed.
+     self.builder.switch_to_block(merge_block);
+     self.builder.seal_block(merge_block);
+     Ok(())
+ }
+
+ /// Emits IR for a `while` loop.
+ ///
+ /// A header block evaluates the condition and leaves the loop when it
+ /// is zero; the body block ends with a jump back to the header.
+ fn translate_while_loop(
+     &mut self,
+     condition: Expr,
+     loop_body: Vec<Stmt>,
+ ) -> Result<()> {
+     let header_block = self.builder.create_block();
+     let body_block = self.builder.create_block();
+     let exit_block = self.builder.create_block();
+
+     // Enter the header, which holds the loop condition.
+     self.builder.ins().jump(header_block, &[]);
+     self.builder.switch_to_block(header_block);
+
+     let condition_value = self.translate_expr(condition)?;
+     self.builder.ins().brz(condition_value, exit_block, &[]);
+     self.builder.ins().jump(body_block, &[]);
+
+     self.builder.switch_to_block(body_block);
+     self.builder.seal_block(body_block);
+
+     loop_body
+         .into_iter()
+         .try_for_each(|stmt| self.translate_stmt(stmt))?;
+     // Back edge to the header.
+     self.builder.ins().jump(header_block, &[]);
+
+     self.builder.switch_to_block(exit_block);
+
+     // The bottom of the loop has been reached: no more back edges to the
+     // header and no more exits to the bottom will be added, so both
+     // blocks can be sealed.
+     self.builder.seal_block(header_block);
+     self.builder.seal_block(exit_block);
+     Ok(())
+ }
+
+ /// Emits a call to the imported function `name` and returns its result.
+ ///
+ /// The callee's signature is built from the types of `args` and from
+ /// `ret`. Returns an internal error when `ret` is the void type (code
+ /// `0`), because such a call produces no value, or when the function
+ /// cannot be declared in the module.
+ fn translate_call_expr(
+     &mut self,
+     name: String,
+     args: Vec<Expr>,
+     ret: JITType,
+ ) -> Result<Value> {
+     if ret.code == 0 {
+         return internal_err!(
+             "Call function {}(..) has void type, it can not be an expression",
+             &name
+         );
+     }
+
+     let mut sig = self.module.make_signature();
+     // Add a parameter for each argument.
+     sig.params
+         .extend(args.iter().map(|arg| AbiParam::new(arg.get_type().native)));
+     sig.returns.push(AbiParam::new(ret.native));
+
+     // Propagate a declaration failure as an error, like `JIT::compile`
+     // does, instead of panicking.
+     let callee = self
+         .module
+         .declare_function(&name, Linkage::Import, &sig)?;
+     let local_callee = self.module.declare_func_in_func(callee, self.builder.func);
+
+     let arg_values = args
+         .into_iter()
+         .map(|arg| self.translate_expr(arg))
+         .collect::<Result<Vec<_>>>()?;
+     let call = self.builder.ins().call(local_callee, &arg_values);
+     // The signature declares exactly one return value.
+     Ok(self.builder.inst_results(call)[0])
+ }
+
+ /// Emits a call to the imported function `name`; any result is
+ /// discarded.
+ ///
+ /// The callee's signature is built from the types of `args` and from
+ /// `ret`; a `ret` with code `0` (the void type) declares no return
+ /// value. Returns an error when the function cannot be declared in the
+ /// module.
+ fn translate_call_stmt(
+     &mut self,
+     name: String,
+     args: Vec<Expr>,
+     ret: JITType,
+ ) -> Result<()> {
+     let mut sig = self.module.make_signature();
+     // Add a parameter for each argument.
+     sig.params
+         .extend(args.iter().map(|arg| AbiParam::new(arg.get_type().native)));
+     if ret.code != 0 {
+         sig.returns.push(AbiParam::new(ret.native));
+     }
+
+     // Propagate a declaration failure as an error, like `JIT::compile`
+     // does, instead of panicking.
+     let callee = self
+         .module
+         .declare_function(&name, Linkage::Import, &sig)?;
+     let local_callee = self.module.declare_func_in_func(callee, self.builder.func);
+
+     let arg_values = args
+         .into_iter()
+         .map(|arg| self.translate_expr(arg))
+         .collect::<Result<Vec<_>>>()?;
+     let _ = self.builder.ins().call(local_callee, &arg_values);
+     Ok(())
+ }
+}
+
+/// Emits the zero constant of type `typ`: `false` for bool, `0` for the
+/// integer types, `0.0` for the float types and null for the pointer
+/// types.
+///
+/// # Panics
+///
+/// Panics with `unsupported type` for any other type code.
+fn typed_zero(typ: JITType, builder: &mut FunctionBuilder) -> Value {
+    match typ.code {
+        0x70 => builder.ins().bconst(typ.native, false),
+        0x76..=0x79 => builder.ins().iconst(typ.native, 0),
+        0x7b => builder.ins().f32const(0.0),
+        0x7c => builder.ins().f64const(0.0),
+        0x7e | 0x7f => builder.ins().null(typ.native),
+        _ => panic!("unsupported type"),
+    }
+}
+
+/// Declares the variables of a function and returns the map from
+/// variable name to Cranelift `Variable`.
+///
+/// Parameters are bound to the matching parameters of `entry_block`, the
+/// return variable (if any) is initialised with the zero value of its
+/// type, and every declaration found in `stmts` is registered.
+fn declare_variables(
+    builder: &mut FunctionBuilder,
+    params: &[(String, JITType)],
+    the_return: &Option<(String, JITType)>,
+    stmts: &[Stmt],
+    entry_block: Block,
+) -> HashMap<String, Variable> {
+    let mut variables = HashMap::new();
+    let mut index = 0;
+
+    for (position, (name, typ)) in params.iter().enumerate() {
+        let incoming = builder.block_params(entry_block)[position];
+        let var = declare_variable(builder, &mut variables, &mut index, name, *typ);
+        builder.def_var(var, incoming);
+    }
+
+    if let Some((name, typ)) = the_return {
+        let zero = typed_zero(*typ, builder);
+        let return_variable =
+            declare_variable(builder, &mut variables, &mut index, name, *typ);
+        builder.def_var(return_variable, zero);
+    }
+
+    for stmt in stmts {
+        declare_variables_in_stmt(builder, &mut variables, &mut index, stmt);
+    }
+
+    variables
+}
+
+/// Registers the variables declared by `stmt`, descending recursively
+/// into the bodies of `if`/`else` and `while` statements.
+fn declare_variables_in_stmt(
+    builder: &mut FunctionBuilder,
+    variables: &mut HashMap<String, Variable>,
+    index: &mut usize,
+    stmt: &Stmt,
+) {
+    match stmt {
+        Stmt::Declare(name, typ) => {
+            declare_variable(builder, variables, index, name, *typ);
+        }
+        Stmt::IfElse(_, then_body, else_body) => {
+            // Then-branch first, else-branch second.
+            for nested in then_body.iter().chain(else_body) {
+                declare_variables_in_stmt(builder, variables, index, nested);
+            }
+        }
+        Stmt::WhileLoop(_, loop_body) => {
+            for nested in loop_body {
+                declare_variables_in_stmt(builder, variables, index, nested);
+            }
+        }
+        // Other statements declare nothing.
+        _ => {}
+    }
+}
+
+/// Declares the variable `name` of type `typ`, unless a variable with
+/// that name is already known, and returns the `Variable` bound to `name`.
+///
+/// `index` is the next free variable index; it is advanced only when a
+/// new variable is created.
+fn declare_variable(
+    builder: &mut FunctionBuilder,
+    variables: &mut HashMap<String, Variable>,
+    index: &mut usize,
+    name: &str,
+    typ: JITType,
+) -> Variable {
+    // A name that is already registered keeps its variable. Returning a
+    // fresh `Variable` here would hand the caller a variable that was
+    // never passed to `declare_var`.
+    if let Some(existing) = variables.get(name) {
+        return *existing;
+    }
+
+    let var = Variable::new(*index);
+    variables.insert(name.into(), var);
+    builder.declare_var(var, typ.native);
+    *index += 1;
+    var
+}
diff --git a/datafusion-jit/src/lib.rs b/datafusion-jit/src/lib.rs
new file mode 100644
index 0000000..5642b5a
--- /dev/null
+++ b/datafusion-jit/src/lib.rs
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Just-In-Time compilation to accelerate DataFusion physical plan execution.
+
+pub mod api;
+pub mod ast;
+pub mod jit;
+
+#[cfg(test)]
+mod tests {
+ use crate::api::{Assembler, GeneratedFunction};
+ use crate::ast::I64;
+ use crate::jit::JIT;
+ use datafusion_common::Result;
+
+ // Builds an iterative Fibonacci function with the assembler API, checks
+ // its textual form against `expected`, then compiles it and checks that
+ // calling it with 10 returns 55.
+ #[test]
+ fn iterative_fib() -> Result<()> {
+ let expected = r#"fn iterative_fib_0(n: i64) -> r: i64 {
+ if n == 0 {
+ r = 0;
+ } else {
+ n = n - 1;
+ let a: i64;
+ a = 0;
+ r = 1;
+ while n != 0 {
+ let t: i64;
+ t = r;
+ r = r + a;
+ a = t;
+ n = n - 1;
+ }
+ }
+}"#;
+ let assembler = Assembler::default();
+ // NOTE(review): the builder is given the name "iterative_fib" while the
+ // expected text says `iterative_fib_0`; presumably the assembler appends
+ // the suffix - confirm in api.rs.
+ let mut builder = assembler
+ .new_func_builder("iterative_fib")
+ .param("n", I64)
+ .ret("r", I64);
+ let mut fn_body = builder.enter_block();
+
+ // if n == 0 { r = 0 } else { ...loop... }
+ fn_body.if_block(
+ |cond| cond.eq(cond.id("n")?, cond.lit_i(0)),
+ |t| {
+ t.assign("r", t.lit_i(0))?;
+ Ok(())
+ },
+ |e| {
+ e.assign("n", e.sub(e.id("n")?, e.lit_i(1))?)?;
+ e.declare_as("a", e.lit_i(0))?;
+ e.assign("r", e.lit_i(1))?;
+ e.while_block(
+ |cond| cond.ne(cond.id("n")?, cond.lit_i(0)),
+ |w| {
+ w.declare_as("t", w.id("r")?)?;
+ w.assign("r", w.add(w.id("r")?, w.id("a")?)?)?;
+ w.assign("a", w.id("t")?)?;
+ w.assign("n", w.sub(w.id("n")?, w.lit_i(1))?)?;
+ Ok(())
+ },
+ )?;
+ Ok(())
+ },
+ )?;
+
+ let gen_func = fn_body.build();
+ // The `Display` output of the generated function must match exactly.
+ assert_eq!(format!("{}", &gen_func), expected);
+ let mut jit = assembler.create_jit();
+ assert_eq!(55, run_iterative_fib_code(&mut jit, gen_func, 10)?);
+ Ok(())
+ }
+
+ /// Compiles `code` with `jit` and calls the resulting function with
+ /// `input`.
+ ///
+ /// # Safety
+ ///
+ /// The caller must guarantee that the compiled function takes exactly
+ /// one argument that is ABI-compatible with `I` and returns a value that
+ /// is ABI-compatible with `O`.
+ unsafe fn run_code<I, O>(
+     jit: &mut JIT,
+     code: GeneratedFunction,
+     input: I,
+ ) -> Result<O> {
+     // Pass the generated function to the JIT, which returns a raw pointer
+     // to machine code.
+     let code_ptr = jit.compile(code)?;
+     // Cast the raw pointer to a typed function pointer. This is unsafe,
+     // because this is the critical point where you have to trust that the
+     // generated code is safe to be called. The generated code follows the
+     // target's native calling convention rather than the unspecified Rust
+     // ABI, so the pointer type must be `extern "C"`.
+     let code_fn = core::mem::transmute::<_, extern "C" fn(I) -> O>(code_ptr);
+     // And now we can call it!
+     Ok(code_fn(input))
+ }
+
+ /// Compiles `code` and runs it as a function taking and returning one
+ /// `isize`.
+ fn run_iterative_fib_code(
+ jit: &mut JIT,
+ code: GeneratedFunction,
+ input: isize,
+ ) -> Result<isize> {
+ // SAFETY: relies on `code` having been built with a single `I64`
+ // parameter and an `I64` return value, as the test above does.
+ // NOTE(review): `isize` only matches `I64` on 64-bit targets - confirm.
+ unsafe { run_code(jit, code, input) }
+ }
+}
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index cbba899..23eb7ce 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -50,10 +50,13 @@ force_hash_collisions = []
avro = ["avro-rs", "num-traits", "datafusion-common/avro"]
# Used to enable row format experiment
row = []
+# Used to enable JIT code generation
+jit = ["datafusion-jit"]
[dependencies]
datafusion-common = { path = "../datafusion-common", version = "7.0.0" }
datafusion-expr = { path = "../datafusion-expr", version = "7.0.0" }
+datafusion-jit = { path = "../datafusion-jit", version = "7.0.0", optional =
true }
datafusion-physical-expr = { path = "../datafusion-physical-expr", version =
"7.0.0" }
ahash = { version = "0.7", default-features = false }
hashbrown = { version = "0.12", features = ["raw"] }
@@ -121,3 +124,8 @@ harness = false
[[bench]]
name = "parquet_query_sql"
harness = false
+
+[[bench]]
+name = "jit"
+harness = false
+required-features = ["row", "jit"]
diff --git a/datafusion/benches/data_utils/mod.rs
b/datafusion/benches/data_utils/mod.rs
index 6ebeeb7..71952b4 100644
--- a/datafusion/benches/data_utils/mod.rs
+++ b/datafusion/benches/data_utils/mod.rs
@@ -35,7 +35,8 @@ use std::sync::Arc;
/// create an in-memory table given the partition len, array len, and batch
size,
/// and the result table will be of array_len in total, and then partitioned,
and batched.
-pub(crate) fn create_table_provider(
+#[allow(dead_code)]
+pub fn create_table_provider(
partitions_len: usize,
array_len: usize,
batch_size: usize,
@@ -52,7 +53,8 @@ fn seedable_rng() -> StdRng {
StdRng::seed_from_u64(42)
}
-fn create_schema() -> Schema {
+/// Create test data schema
+pub fn create_schema() -> Schema {
Schema::new(vec![
Field::new("utf8", DataType::Utf8, false),
Field::new("f32", DataType::Float32, false),
@@ -138,7 +140,9 @@ fn create_record_batch(
.unwrap()
}
-fn create_record_batches(
+/// Create record batches of `partitions_len` partitions and `batch_size` for
each batch,
+/// with a total number of `array_len` records
+pub fn create_record_batches(
schema: SchemaRef,
array_len: usize,
partitions_len: usize,
diff --git a/datafusion/benches/jit.rs b/datafusion/benches/jit.rs
new file mode 100644
index 0000000..b198b15
--- /dev/null
+++ b/datafusion/benches/jit.rs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+extern crate arrow;
+extern crate datafusion;
+
+mod data_utils;
+use crate::criterion::Criterion;
+use crate::data_utils::{create_record_batches, create_schema};
+use datafusion::row::writer::{
+ bench_write_batch, bench_write_batch_jit, bench_write_batch_jit_dummy,
+};
+use std::sync::Arc;
+
+/// Benchmarks the row serializer, the JIT row serializer, and JIT code
+/// generation on its own, using the same generated record batches.
+fn criterion_benchmark(c: &mut Criterion) {
+    let batch_size = 2048; // 2^11
+    let array_len = 32768 * 1024; // 2^25
+    let partitions_len = 8;
+
+    let schema = Arc::new(create_schema());
+    let batches =
+        create_record_batches(schema.clone(), array_len, partitions_len, batch_size);
+
+    c.bench_function("row serializer", |b| {
+        b.iter(|| {
+            let written = bench_write_batch(&batches, schema.clone()).unwrap();
+            criterion::black_box(written)
+        })
+    });
+
+    c.bench_function("row serializer jit", |b| {
+        b.iter(|| {
+            let written = bench_write_batch_jit(&batches, schema.clone()).unwrap();
+            criterion::black_box(written)
+        })
+    });
+
+    c.bench_function("row serializer jit codegen only", |b| {
+        b.iter(|| bench_write_batch_jit_dummy(schema.clone()).unwrap())
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index a9630c0..0ce6e91 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -227,7 +227,7 @@ pub use parquet;
pub(crate) mod field_util;
#[cfg(feature = "row")]
-pub(crate) mod row;
+pub mod row;
pub mod from_slice;
diff --git a/datafusion/src/row/mod.rs b/datafusion/src/row/mod.rs
index 9875b84..5cd9885 100644
--- a/datafusion/src/row/mod.rs
+++ b/datafusion/src/row/mod.rs
@@ -17,7 +17,7 @@
//! An implementation of Row backed by raw bytes
//!
-//! Each tuple consists of up to three parts: [null bit set] [values] [var
length data]
+//! Each tuple consists of up to three parts: "`null bit set`" , "`values`"
and "`var length data`"
//!
//! The null bit set is used for null tracking and is aligned to 1-byte. It
stores
//! one bit per field.
@@ -52,8 +52,8 @@ use arrow::util::bit_util::{get_bit_raw,
round_upto_power_of_2};
use std::fmt::Write;
use std::sync::Arc;
-mod reader;
-mod writer;
+pub mod reader;
+pub mod writer;
const ALL_VALID_MASK: [u8; 8] = [1, 3, 7, 15, 31, 63, 127, 255];
@@ -189,6 +189,29 @@ fn supported(schema: &Arc<Schema>) -> bool {
.all(|f| supported_type(f.data_type()))
}
+#[cfg(feature = "jit")]
+#[macro_export]
+/// Register an external function with the assembler.
+///
+/// `$ASS` is the assembler, `$FN` the path of the function, `$PARAM` and
+/// `$RET` the parameter and return descriptions that are forwarded to
+/// `register_extern_fn`. The function is registered under the name
+/// produced by `fn_name($FN)`.
+///
+/// The expansion ends with `?`, so the macro can only be used where `?`
+/// is allowed, and `fn_name` must be in scope at the call site.
+macro_rules! reg_fn {
+ ($ASS:ident, $FN: path, $PARAM: expr, $RET: expr) => {
+ $ASS.register_extern_fn(fn_name($FN), $FN as *const u8, $PARAM, $RET)?;
+ };
+}
+
+#[cfg(feature = "jit")]
+/// Returns the type name of `f` with everything up to and including the
+/// last `:` removed, i.e. the last path segment.
+fn fn_name<T>(f: T) -> &'static str {
+    // Only the type of `f` matters; the value itself is not used.
+    drop(f);
+    let full_name = std::any::type_name::<T>();
+
+    // Keep what follows the last `:`; a name without `:` is returned whole.
+    full_name
+        .rfind(':')
+        .map_or(full_name, |pos| &full_name[pos + 1..])
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -203,10 +226,16 @@ mod tests {
use crate::physical_plan::file_format::FileScanConfig;
use crate::physical_plan::{collect, ExecutionPlan};
use crate::row::reader::read_as_batch;
+ #[cfg(feature = "jit")]
+ use crate::row::reader::read_as_batch_jit;
use crate::row::writer::write_batch_unchecked;
+ #[cfg(feature = "jit")]
+ use crate::row::writer::write_batch_unchecked_jit;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{ceil, set_bit_raw, unset_bit_raw};
use arrow::{array::*, datatypes::*};
+ #[cfg(feature = "jit")]
+ use datafusion_jit::api::Assembler;
use rand::Rng;
use DataType::*;
@@ -300,7 +329,23 @@ mod tests {
let mut vector = vec![0; 1024];
let row_offsets =
{ write_batch_unchecked(&mut vector, 0, &batch, 0,
schema.clone()) };
- let output_batch = { read_as_batch(&mut vector, schema,
row_offsets)? };
+ let output_batch = { read_as_batch(&vector, schema,
row_offsets)? };
+ assert_eq!(batch, output_batch);
+ Ok(())
+ }
+
+ #[test]
+ #[allow(non_snake_case)]
+ #[cfg(feature = "jit")]
+ fn [<test_single_ $TYPE _jit>]() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
$TYPE, true)]));
+ let a = $ARRAY::from($VEC);
+ let batch = RecordBatch::try_new(schema.clone(),
vec![Arc::new(a)])?;
+ let mut vector = vec![0; 1024];
+ let assembler = Assembler::default();
+ let row_offsets =
+ { write_batch_unchecked_jit(&mut vector, 0, &batch, 0,
schema.clone(), &assembler)? };
+ let output_batch = { read_as_batch_jit(&vector, schema,
row_offsets, &assembler)? };
assert_eq!(batch, output_batch);
Ok(())
}
@@ -402,7 +447,33 @@ mod tests {
let mut vector = vec![0; 8192];
let row_offsets =
{ write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone())
};
- let output_batch = { read_as_batch(&mut vector, schema, row_offsets)?
};
+ let output_batch = { read_as_batch(&vector, schema, row_offsets)? };
+ assert_eq!(batch, output_batch);
+ Ok(())
+ }
+
+ #[test]
+ #[cfg(feature = "jit")]
+ fn test_single_binary_jit() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", Binary,
true)]));
+ let values: Vec<Option<&[u8]>> =
+ vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")];
+ let a = BinaryArray::from_opt_vec(values);
+ let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
+ let mut vector = vec![0; 8192];
+ let assembler = Assembler::default();
+ let row_offsets = {
+ write_batch_unchecked_jit(
+ &mut vector,
+ 0,
+ &batch,
+ 0,
+ schema.clone(),
+ &assembler,
+ )?
+ };
+ let output_batch =
+ { read_as_batch_jit(&vector, schema, row_offsets, &assembler)? };
assert_eq!(batch, output_batch);
Ok(())
}
@@ -421,7 +492,7 @@ mod tests {
let mut vector = vec![0; 20480];
let row_offsets =
{ write_batch_unchecked(&mut vector, 0, batch, 0, schema.clone())
};
- let output_batch = { read_as_batch(&mut vector, schema, row_offsets)?
};
+ let output_batch = { read_as_batch(&vector, schema, row_offsets)? };
assert_eq!(*batch, output_batch);
Ok(())
@@ -445,9 +516,9 @@ mod tests {
DataType::Decimal(5, 2),
false,
)]));
- let mut vector = vec![0; 1024];
+ let vector = vec![0; 1024];
let row_offsets = vec![0];
- read_as_batch(&mut vector, schema, row_offsets).unwrap();
+ read_as_batch(&vector, schema, row_offsets).unwrap();
}
async fn get_exec(
diff --git a/datafusion/src/row/reader.rs b/datafusion/src/row/reader.rs
index 779c099..213c34b 100644
--- a/datafusion/src/row/reader.rs
+++ b/datafusion/src/row/reader.rs
@@ -18,17 +18,27 @@
//! Accessing row from raw bytes
use crate::error::{DataFusionError, Result};
+#[cfg(feature = "jit")]
+use crate::reg_fn;
+#[cfg(feature = "jit")]
+use crate::row::fn_name;
use crate::row::{all_valid, get_offsets, supported, NullBitsFormatter};
use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{ceil, get_bit_raw};
+#[cfg(feature = "jit")]
+use datafusion_jit::api::Assembler;
+#[cfg(feature = "jit")]
+use datafusion_jit::api::GeneratedFunction;
+#[cfg(feature = "jit")]
+use datafusion_jit::ast::{I64, PTR};
use std::sync::Arc;
/// Read `data` of raw-bytes rows starting at `offsets` out to a record batch
pub fn read_as_batch(
- data: &mut [u8],
+ data: &[u8],
schema: Arc<Schema>,
offsets: Vec<usize>,
) -> Result<RecordBatch> {
@@ -44,6 +54,33 @@ pub fn read_as_batch(
output.output().map_err(DataFusionError::ArrowError)
}
+/// Read `data` of raw-bytes rows starting at `offsets` out to a record batch
+#[cfg(feature = "jit")]
+pub fn read_as_batch_jit(
+ data: &[u8],
+ schema: Arc<Schema>,
+ offsets: Vec<usize>,
+ assembler: &Assembler,
+) -> Result<RecordBatch> {
+ let row_num = offsets.len();
+ let mut output = MutableRecordBatch::new(row_num, schema.clone());
+ let mut row = RowReader::new(&schema, data);
+ register_read_functions(assembler)?;
+ let gen_func = gen_read_row(&schema, assembler)?;
+ let mut jit = assembler.create_jit();
+ let code_ptr = jit.compile(gen_func)?;
+ let code_fn = unsafe {
+ std::mem::transmute::<_, fn(&RowReader, &mut
MutableRecordBatch)>(code_ptr)
+ };
+
+ for offset in offsets.iter().take(row_num) {
+ row.point_to(*offset);
+ code_fn(&row, &mut output);
+ }
+
+ output.output().map_err(DataFusionError::ArrowError)
+}
+
macro_rules! get_idx {
($NATIVE: ident, $SELF: ident, $IDX: ident, $WIDTH: literal) => {{
$SELF.assert_index_valid($IDX);
@@ -260,6 +297,114 @@ fn read_row(row: &RowReader, batch: &mut
MutableRecordBatch, schema: &Arc<Schema
}
}
+#[cfg(feature = "jit")]
+fn get_array_mut(
+ batch: &mut MutableRecordBatch,
+ col_idx: usize,
+) -> &mut Box<dyn ArrayBuilder> {
+ let arrays: &mut [Box<dyn ArrayBuilder>] = batch.arrays.as_mut();
+ &mut arrays[col_idx]
+}
+
+#[cfg(feature = "jit")]
+fn register_read_functions(asm: &Assembler) -> Result<()> {
+ let reader_param = vec![PTR, I64, PTR];
+ reg_fn!(asm, get_array_mut, vec![PTR, I64], Some(PTR));
+ reg_fn!(asm, read_field_bool, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u8, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u16, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u32, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u64, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i8, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i16, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i32, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i64, reader_param.clone(), None);
+ reg_fn!(asm, read_field_f32, reader_param.clone(), None);
+ reg_fn!(asm, read_field_f64, reader_param.clone(), None);
+ reg_fn!(asm, read_field_date32, reader_param.clone(), None);
+ reg_fn!(asm, read_field_date64, reader_param.clone(), None);
+ reg_fn!(asm, read_field_utf8, reader_param.clone(), None);
+ reg_fn!(asm, read_field_binary, reader_param.clone(), None);
+ reg_fn!(asm, read_field_bool_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u8_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u16_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u32_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_u64_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i8_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i16_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i32_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_i64_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_f32_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_f64_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_date32_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_date64_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_utf8_nf, reader_param.clone(), None);
+ reg_fn!(asm, read_field_binary_nf, reader_param, None);
+ Ok(())
+}
+
+#[cfg(feature = "jit")]
+fn gen_read_row(
+ schema: &Arc<Schema>,
+ assembler: &Assembler,
+) -> Result<GeneratedFunction> {
+ use DataType::*;
+ let mut builder = assembler
+ .new_func_builder("read_row")
+ .param("row", PTR)
+ .param("batch", PTR);
+ let mut b = builder.enter_block();
+ for (i, f) in schema.fields().iter().enumerate() {
+ let dt = f.data_type();
+ let arr = format!("a{}", i);
+ b.declare_as(
+ &arr,
+ b.call("get_array_mut", vec![b.id("batch")?, b.lit_i(i as i64)])?,
+ )?;
+ let params = vec![b.id(&arr)?, b.lit_i(i as i64), b.id("row")?];
+ if f.is_nullable() {
+ match dt {
+ Boolean => b.call_stmt("read_field_bool", params)?,
+ UInt8 => b.call_stmt("read_field_u8", params)?,
+ UInt16 => b.call_stmt("read_field_u16", params)?,
+ UInt32 => b.call_stmt("read_field_u32", params)?,
+ UInt64 => b.call_stmt("read_field_u64", params)?,
+ Int8 => b.call_stmt("read_field_i8", params)?,
+ Int16 => b.call_stmt("read_field_i16", params)?,
+ Int32 => b.call_stmt("read_field_i32", params)?,
+ Int64 => b.call_stmt("read_field_i64", params)?,
+ Float32 => b.call_stmt("read_field_f32", params)?,
+ Float64 => b.call_stmt("read_field_f64", params)?,
+ Date32 => b.call_stmt("read_field_date32", params)?,
+ Date64 => b.call_stmt("read_field_date64", params)?,
+ Utf8 => b.call_stmt("read_field_utf8", params)?,
+ Binary => b.call_stmt("read_field_binary", params)?,
+ _ => unimplemented!(),
+ }
+ } else {
+ match dt {
+ Boolean => b.call_stmt("read_field_bool_nf", params)?,
+ UInt8 => b.call_stmt("read_field_u8_nf", params)?,
+ UInt16 => b.call_stmt("read_field_u16_nf", params)?,
+ UInt32 => b.call_stmt("read_field_u32_nf", params)?,
+ UInt64 => b.call_stmt("read_field_u64_nf", params)?,
+ Int8 => b.call_stmt("read_field_i8_nf", params)?,
+ Int16 => b.call_stmt("read_field_i16_nf", params)?,
+ Int32 => b.call_stmt("read_field_i32_nf", params)?,
+ Int64 => b.call_stmt("read_field_i64_nf", params)?,
+ Float32 => b.call_stmt("read_field_f32_nf", params)?,
+ Float64 => b.call_stmt("read_field_f64_nf", params)?,
+ Date32 => b.call_stmt("read_field_date32_nf", params)?,
+ Date64 => b.call_stmt("read_field_date64_nf", params)?,
+ Utf8 => b.call_stmt("read_field_utf8_nf", params)?,
+ Binary => b.call_stmt("read_field_binary_nf", params)?,
+ _ => unimplemented!(),
+ }
+ }
+ }
+ Ok(b.build())
+}
+
macro_rules! fn_read_field {
($NATIVE: ident, $ARRAY: ident) => {
paste::item! {
diff --git a/datafusion/src/row/writer.rs b/datafusion/src/row/writer.rs
index 698f797..2206e35 100644
--- a/datafusion/src/row/writer.rs
+++ b/datafusion/src/row/writer.rs
@@ -17,11 +17,22 @@
//! Reusable row writer backed by Vec<u8> to stitch attributes together
+use crate::error::Result;
+#[cfg(feature = "jit")]
+use crate::reg_fn;
+#[cfg(feature = "jit")]
+use crate::row::fn_name;
use crate::row::{estimate_row_width, fixed_size, get_offsets, supported};
-use arrow::array::Array;
+use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{ceil, round_upto_power_of_2, set_bit_raw,
unset_bit_raw};
+use datafusion_jit::api::CodeBlock;
+#[cfg(feature = "jit")]
+use datafusion_jit::api::{Assembler, GeneratedFunction};
+use datafusion_jit::ast::Expr;
+#[cfg(feature = "jit")]
+use datafusion_jit::ast::{BOOL, I64, PTR};
use std::cmp::max;
use std::sync::Arc;
@@ -50,6 +61,103 @@ pub fn write_batch_unchecked(
offsets
}
+/// Append batch from `row_idx` to `output` buffer start from `offset`
+/// # Panics
+///
+/// This function will panic if the output buffer doesn't have enough space to
hold all the rows
+#[cfg(feature = "jit")]
+pub fn write_batch_unchecked_jit(
+ output: &mut [u8],
+ offset: usize,
+ batch: &RecordBatch,
+ row_idx: usize,
+ schema: Arc<Schema>,
+ assembler: &Assembler,
+) -> Result<Vec<usize>> {
+ let mut writer = RowWriter::new(&schema);
+ let mut current_offset = offset;
+ let mut offsets = vec![];
+ register_write_functions(assembler)?;
+ let gen_func = gen_write_row(&schema, assembler)?;
+ let mut jit = assembler.create_jit();
+ let code_ptr = jit.compile(gen_func)?;
+
+ let code_fn = unsafe {
+ std::mem::transmute::<_, fn(&mut RowWriter, usize,
&RecordBatch)>(code_ptr)
+ };
+
+ for cur_row in row_idx..batch.num_rows() {
+ offsets.push(current_offset);
+ code_fn(&mut writer, cur_row, batch);
+ writer.end_padding();
+ let row_width = writer.row_width;
+ output[current_offset..current_offset + row_width]
+ .copy_from_slice(writer.get_row());
+ current_offset += row_width;
+ writer.reset()
+ }
+ Ok(offsets)
+}
+
+#[cfg(feature = "jit")]
+/// bench interpreted version write
+pub fn bench_write_batch(
+ batches: &[Vec<RecordBatch>],
+ schema: Arc<Schema>,
+) -> Result<Vec<usize>> {
+ let mut writer = RowWriter::new(&schema);
+ let mut lengths = vec![];
+
+ for batch in batches.iter().flatten() {
+ for cur_row in 0..batch.num_rows() {
+ let row_width = write_row(&mut writer, cur_row, batch);
+ lengths.push(row_width);
+ writer.reset()
+ }
+ }
+
+ Ok(lengths)
+}
+
+#[cfg(feature = "jit")]
+/// bench jit version write
+pub fn bench_write_batch_jit(
+ batches: &[Vec<RecordBatch>],
+ schema: Arc<Schema>,
+) -> Result<Vec<usize>> {
+ let assembler = Assembler::default();
+ let mut writer = RowWriter::new(&schema);
+ let mut lengths = vec![];
+ register_write_functions(&assembler)?;
+ let gen_func = gen_write_row(&schema, &assembler)?;
+ let mut jit = assembler.create_jit();
+ let code_ptr = jit.compile(gen_func)?;
+ let code_fn = unsafe {
+ std::mem::transmute::<_, fn(&mut RowWriter, usize,
&RecordBatch)>(code_ptr)
+ };
+
+ for batch in batches.iter().flatten() {
+ for cur_row in 0..batch.num_rows() {
+ code_fn(&mut writer, cur_row, batch);
+ writer.end_padding();
+ lengths.push(writer.row_width);
+ writer.reset()
+ }
+ }
+ Ok(lengths)
+}
+
+#[cfg(feature = "jit")]
+/// bench code generation cost
+pub fn bench_write_batch_jit_dummy(schema: Arc<Schema>) -> Result<()> {
+ let assembler = Assembler::default();
+ register_write_functions(&assembler)?;
+ let gen_func = gen_write_row(&schema, &assembler)?;
+ let mut jit = assembler.create_jit();
+ let _: *const u8 = jit.compile(gen_func)?;
+ Ok(())
+}
+
macro_rules! set_idx {
($WIDTH: literal, $SELF: ident, $IDX: ident, $VALUE: ident) => {{
$SELF.assert_index_valid($IDX);
@@ -233,7 +341,6 @@ fn write_row(row: &mut RowWriter, row_idx: usize, batch:
&RecordBatch) -> usize
.zip(batch.columns().iter())
{
if !col.is_null(row_idx) {
- row.set_non_null_at(i);
write_field(i, row_idx, col, f.data_type(), row);
} else {
row.set_null_at(i);
@@ -244,6 +351,197 @@ fn write_row(row: &mut RowWriter, row_idx: usize, batch:
&RecordBatch) -> usize
row.row_width
}
+// we could remove this function wrapper once we find a way to call the trait
method directly.
+#[cfg(feature = "jit")]
+fn is_null(col: &Arc<dyn Array>, row_idx: usize) -> bool {
+ col.is_null(row_idx)
+}
+
+#[cfg(feature = "jit")]
+fn register_write_functions(asm: &Assembler) -> Result<()> {
+ let reader_param = vec![PTR, I64, PTR];
+ reg_fn!(asm, RecordBatch::column, vec![PTR, I64], Some(PTR));
+ reg_fn!(asm, RowWriter::set_null_at, vec![PTR, I64], None);
+ reg_fn!(asm, RowWriter::set_non_null_at, vec![PTR, I64], None);
+ reg_fn!(asm, is_null, vec![PTR, I64], Some(BOOL));
+ reg_fn!(asm, write_field_bool, reader_param.clone(), None);
+ reg_fn!(asm, write_field_u8, reader_param.clone(), None);
+ reg_fn!(asm, write_field_u16, reader_param.clone(), None);
+ reg_fn!(asm, write_field_u32, reader_param.clone(), None);
+ reg_fn!(asm, write_field_u64, reader_param.clone(), None);
+ reg_fn!(asm, write_field_i8, reader_param.clone(), None);
+ reg_fn!(asm, write_field_i16, reader_param.clone(), None);
+ reg_fn!(asm, write_field_i32, reader_param.clone(), None);
+ reg_fn!(asm, write_field_i64, reader_param.clone(), None);
+ reg_fn!(asm, write_field_f32, reader_param.clone(), None);
+ reg_fn!(asm, write_field_f64, reader_param.clone(), None);
+ reg_fn!(asm, write_field_date32, reader_param.clone(), None);
+ reg_fn!(asm, write_field_date64, reader_param.clone(), None);
+ reg_fn!(asm, write_field_utf8, reader_param.clone(), None);
+ reg_fn!(asm, write_field_binary, reader_param, None);
+ Ok(())
+}
+
+#[cfg(feature = "jit")]
+fn gen_write_row(
+ schema: &Arc<Schema>,
+ assembler: &Assembler,
+) -> Result<GeneratedFunction> {
+ let mut builder = assembler
+ .new_func_builder("write_row")
+ .param("row", PTR)
+ .param("row_idx", I64)
+ .param("batch", PTR);
+ let mut b = builder.enter_block();
+ for (i, f) in schema.fields().iter().enumerate() {
+ let dt = f.data_type();
+ let arr = format!("a{}", i);
+ b.declare_as(
+ &arr,
+ b.call("column", vec![b.id("batch")?, b.lit_i(i as i64)])?,
+ )?;
+ if f.is_nullable() {
+ b.if_block(
+ |c| c.call("is_null", vec![c.id(&arr)?, c.id("row_idx")?]),
+ |t| {
+ t.call_stmt("set_null_at", vec![t.id("row")?, t.lit_i(i as
i64)])?;
+ Ok(())
+ },
+ |e| {
+ e.call_stmt(
+ "set_non_null_at",
+ vec![e.id("row")?, e.lit_i(i as i64)],
+ )?;
+ let params = vec![
+ e.id("row")?,
+ e.id(&arr)?,
+ e.lit_i(i as i64),
+ e.id("row_idx")?,
+ ];
+ write_typed_field_stmt(dt, e, params)?;
+ Ok(())
+ },
+ )?;
+ } else {
+ b.call_stmt("set_non_null_at", vec![b.id("row")?, b.lit_i(i as
i64)])?;
+ let params = vec![
+ b.id("row")?,
+ b.id(&arr)?,
+ b.lit_i(i as i64),
+ b.id("row_idx")?,
+ ];
+ write_typed_field_stmt(dt, &mut b, params)?;
+ }
+ }
+ Ok(b.build())
+}
+
+#[cfg(feature = "jit")]
+fn write_typed_field_stmt<'a>(
+ dt: &DataType,
+ b: &mut CodeBlock<'a>,
+ params: Vec<Expr>,
+) -> Result<()> {
+ use DataType::*;
+ match dt {
+ Boolean => b.call_stmt("write_field_bool", params)?,
+ UInt8 => b.call_stmt("write_field_u8", params)?,
+ UInt16 => b.call_stmt("write_field_u16", params)?,
+ UInt32 => b.call_stmt("write_field_u32", params)?,
+ UInt64 => b.call_stmt("write_field_u64", params)?,
+ Int8 => b.call_stmt("write_field_i8", params)?,
+ Int16 => b.call_stmt("write_field_i16", params)?,
+ Int32 => b.call_stmt("write_field_i32", params)?,
+ Int64 => b.call_stmt("write_field_i64", params)?,
+ Float32 => b.call_stmt("write_field_f32", params)?,
+ Float64 => b.call_stmt("write_field_f64", params)?,
+ Date32 => b.call_stmt("write_field_date32", params)?,
+ Date64 => b.call_stmt("write_field_date64", params)?,
+ Utf8 => b.call_stmt("write_field_utf8", params)?,
+ Binary => b.call_stmt("write_field_binary", params)?,
+ _ => unimplemented!(),
+ }
+ Ok(())
+}
+
+macro_rules! fn_write_field {
+ ($NATIVE: ident, $ARRAY: ident) => {
+ paste::item! {
+ fn [<write_field_ $NATIVE>](to: &mut RowWriter, from: &Arc<dyn
Array>, col_idx: usize, row_idx: usize) {
+ let from = from
+ .as_any()
+ .downcast_ref::<$ARRAY>()
+ .unwrap();
+ to.[<set_ $NATIVE>](col_idx, from.value(row_idx));
+ }
+ }
+ };
+}
+
+fn_write_field!(bool, BooleanArray);
+fn_write_field!(u8, UInt8Array);
+fn_write_field!(u16, UInt16Array);
+fn_write_field!(u32, UInt32Array);
+fn_write_field!(u64, UInt64Array);
+fn_write_field!(i8, Int8Array);
+fn_write_field!(i16, Int16Array);
+fn_write_field!(i32, Int32Array);
+fn_write_field!(i64, Int64Array);
+fn_write_field!(f32, Float32Array);
+fn_write_field!(f64, Float64Array);
+
+fn write_field_date32(
+ to: &mut RowWriter,
+ from: &Arc<dyn Array>,
+ col_idx: usize,
+ row_idx: usize,
+) {
+ let from = from.as_any().downcast_ref::<Date32Array>().unwrap();
+ to.set_date32(col_idx, from.value(row_idx));
+}
+
+fn write_field_date64(
+ to: &mut RowWriter,
+ from: &Arc<dyn Array>,
+ col_idx: usize,
+ row_idx: usize,
+) {
+ let from = from.as_any().downcast_ref::<Date64Array>().unwrap();
+ to.set_date64(col_idx, from.value(row_idx));
+}
+
+fn write_field_utf8(
+ to: &mut RowWriter,
+ from: &Arc<dyn Array>,
+ col_idx: usize,
+ row_idx: usize,
+) {
+ let from = from.as_any().downcast_ref::<StringArray>().unwrap();
+ let s = from.value(row_idx);
+ let new_width = to.current_width() + s.as_bytes().len();
+ if new_width > to.data.capacity() {
+ // double the capacity to avoid repeated resize
+ to.data.resize(max(to.data.capacity() * 2, new_width), 0);
+ }
+ to.set_utf8(col_idx, s);
+}
+
+fn write_field_binary(
+ to: &mut RowWriter,
+ from: &Arc<dyn Array>,
+ col_idx: usize,
+ row_idx: usize,
+) {
+ let from = from.as_any().downcast_ref::<BinaryArray>().unwrap();
+ let s = from.value(row_idx);
+ let new_width = to.current_width() + s.len();
+ if new_width > to.data.capacity() {
+ // double the capacity to avoid repeated resize
+ to.data.resize(max(to.data.capacity() * 2, new_width), 0);
+ }
+ to.set_binary(col_idx, s);
+}
+
fn write_field(
col_idx: usize,
row_idx: usize,
@@ -251,82 +549,24 @@ fn write_field(
dt: &DataType,
row: &mut RowWriter,
) {
- // TODO: JIT compile this
- use arrow::array::*;
use DataType::*;
+ row.set_non_null_at(col_idx);
match dt {
- Boolean => {
- let c = col.as_any().downcast_ref::<BooleanArray>().unwrap();
- row.set_bool(col_idx, c.value(row_idx));
- }
- UInt8 => {
- let c = col.as_any().downcast_ref::<UInt8Array>().unwrap();
- row.set_u8(col_idx, c.value(row_idx));
- }
- UInt16 => {
- let c = col.as_any().downcast_ref::<UInt16Array>().unwrap();
- row.set_u16(col_idx, c.value(row_idx));
- }
- UInt32 => {
- let c = col.as_any().downcast_ref::<UInt32Array>().unwrap();
- row.set_u32(col_idx, c.value(row_idx));
- }
- UInt64 => {
- let c = col.as_any().downcast_ref::<UInt64Array>().unwrap();
- row.set_u64(col_idx, c.value(row_idx));
- }
- Int8 => {
- let c = col.as_any().downcast_ref::<Int8Array>().unwrap();
- row.set_i8(col_idx, c.value(row_idx));
- }
- Int16 => {
- let c = col.as_any().downcast_ref::<Int16Array>().unwrap();
- row.set_i16(col_idx, c.value(row_idx));
- }
- Int32 => {
- let c = col.as_any().downcast_ref::<Int32Array>().unwrap();
- row.set_i32(col_idx, c.value(row_idx));
- }
- Int64 => {
- let c = col.as_any().downcast_ref::<Int64Array>().unwrap();
- row.set_i64(col_idx, c.value(row_idx));
- }
- Float32 => {
- let c = col.as_any().downcast_ref::<Float32Array>().unwrap();
- row.set_f32(col_idx, c.value(row_idx));
- }
- Float64 => {
- let c = col.as_any().downcast_ref::<Float64Array>().unwrap();
- row.set_f64(col_idx, c.value(row_idx));
- }
- Date32 => {
- let c = col.as_any().downcast_ref::<Date32Array>().unwrap();
- row.set_date32(col_idx, c.value(row_idx));
- }
- Date64 => {
- let c = col.as_any().downcast_ref::<Date64Array>().unwrap();
- row.set_date64(col_idx, c.value(row_idx));
- }
- Utf8 => {
- let c = col.as_any().downcast_ref::<StringArray>().unwrap();
- let s = c.value(row_idx);
- let new_width = row.current_width() + s.as_bytes().len();
- if new_width > row.data.capacity() {
- // double the capacity to avoid repeated resize
- row.data.resize(max(row.data.capacity() * 2, new_width), 0);
- }
- row.set_utf8(col_idx, s);
- }
- Binary => {
- let c = col.as_any().downcast_ref::<BinaryArray>().unwrap();
- let binary = c.value(row_idx);
- let new_width = row.current_width() + binary.len();
- if new_width > row.data.capacity() {
- // double the capacity to avoid repeated resize
- row.data.resize(max(row.data.capacity() * 2, new_width), 0);
- }
- row.set_binary(col_idx, binary);
- }
+ Boolean => write_field_bool(row, col, col_idx, row_idx),
+ UInt8 => write_field_u8(row, col, col_idx, row_idx),
+ UInt16 => write_field_u16(row, col, col_idx, row_idx),
+ UInt32 => write_field_u32(row, col, col_idx, row_idx),
+ UInt64 => write_field_u64(row, col, col_idx, row_idx),
+ Int8 => write_field_i8(row, col, col_idx, row_idx),
+ Int16 => write_field_i16(row, col, col_idx, row_idx),
+ Int32 => write_field_i32(row, col, col_idx, row_idx),
+ Int64 => write_field_i64(row, col, col_idx, row_idx),
+ Float32 => write_field_f32(row, col, col_idx, row_idx),
+ Float64 => write_field_f64(row, col, col_idx, row_idx),
+ Date32 => write_field_date32(row, col, col_idx, row_idx),
+ Date64 => write_field_date64(row, col, col_idx, row_idx),
+ Utf8 => write_field_utf8(row, col, col_idx, row_idx),
+ Binary => write_field_binary(row, col, col_idx, row_idx),
_ => unimplemented!(),
}
}