You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/05/24 14:35:20 UTC
[arrow-datafusion] branch master updated: Evaluate JIT'd expression over arrays (#2587)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a9cc38a89 Evaluate JIT'd expression over arrays (#2587)
a9cc38a89 is described below
commit a9cc38a89ae8194caf0ce7dd54472a0428828e7c
Author: Ruihang Xia <wa...@gmail.com>
AuthorDate: Tue May 24 22:35:14 2022 +0800
Evaluate JIT'd expression over arrays (#2587)
* pipeline expr wrapper
* clean up
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* Apply suggestions from code review
Add doc for `deref()` and `store()`
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* CR improvement: doc, naming and hardcode
Signed-off-by: Ruihang Xia <wa...@gmail.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/jit/src/api.rs | 13 +++
datafusion/jit/src/ast.rs | 50 ++++++----
datafusion/jit/src/compile.rs | 207 ++++++++++++++++++++++++++++++++++++++++++
datafusion/jit/src/jit.rs | 14 +++
datafusion/jit/src/lib.rs | 1 +
datafusion/row/src/lib.rs | 4 +
6 files changed, 271 insertions(+), 18 deletions(-)
diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs
index d95f9ccc7..7020985a7 100644
--- a/datafusion/jit/src/api.rs
+++ b/datafusion/jit/src/api.rs
@@ -153,6 +153,7 @@ impl FunctionBuilder {
}
/// Add one more parameter to the function.
+ #[must_use]
pub fn param(mut self, name: impl Into<String>, ty: JITType) -> Self {
let name = name.into();
assert!(!self.fields.back().unwrap().contains_key(&name));
@@ -163,6 +164,7 @@ impl FunctionBuilder {
/// Set return type for the function. Functions are of `void` type by default if
/// you do not set the return type.
+ #[must_use]
pub fn ret(mut self, name: impl Into<String>, ty: JITType) -> Self {
let name = name.into();
assert!(!self.fields.back().unwrap().contains_key(&name));
@@ -604,6 +606,17 @@ impl<'a> CodeBlock<'a> {
internal_err!("No func with the name {} exist", fn_name)
}
}
+
+    /// Build an expression that reads a value of type `ty` from the address
+    /// that `ptr` evaluates to.
+    ///
+    /// No statement is appended to the block; the caller receives an
+    /// `Expr::Load` to embed in other expressions or statements. This
+    /// constructor always returns `Ok`.
+    pub fn load(&self, ptr: Expr, ty: JITType) -> Result<Expr> {
+        let address = Box::new(ptr);
+        Ok(Expr::Load(address, ty))
+    }
+
+    /// Append a statement that writes `value` to the address that `ptr`
+    /// evaluates to.
+    ///
+    /// Mind the argument order: the value comes first, the destination
+    /// address second. Always returns `Ok(())`.
+    pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> {
+        let stmt = Stmt::Store(Box::new(value), Box::new(ptr));
+        self.stmts.push(stmt);
+        Ok(())
+    }
}
impl Display for GeneratedFunction {
diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs
index fd10a909e..55731a650 100644
--- a/datafusion/jit/src/ast.rs
+++ b/datafusion/jit/src/ast.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::datatypes::DataType;
use cranelift::codegen::ir;
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
use std::fmt::{Display, Formatter};
@@ -32,6 +33,8 @@ pub enum Stmt {
Call(String, Vec<Expr>),
/// declare a new variable of type
Declare(String, JITType),
+ /// store value (the first expr) to an address (the second expr)
+ Store(Box<Expr>, Box<Expr>),
}
#[derive(Copy, Clone, Debug, PartialEq)]
@@ -54,6 +57,8 @@ pub enum Expr {
Binary(BinaryExpr),
/// call function expression
Call(String, Vec<Expr>, JITType),
+ /// Load a value from pointer
+ Load(Box<Expr>, JITType),
}
impl Expr {
@@ -63,6 +68,7 @@ impl Expr {
Expr::Identifier(_, ty) => *ty,
Expr::Binary(bin) => bin.get_type(),
Expr::Call(_, _, ty) => *ty,
+ Expr::Load(_, ty) => *ty,
}
}
}
@@ -174,19 +180,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr {
let field = schema.field_from_column(col)?;
let ty = field.data_type();
- let jit_type = match ty {
- arrow::datatypes::DataType::Int64 => I64,
- arrow::datatypes::DataType::Float32 => F32,
- arrow::datatypes::DataType::Float64 => F64,
- arrow::datatypes::DataType::Boolean => BOOL,
-
- _ => {
- return Err(DataFusionError::NotImplemented(format!(
- "Compiling Expression with type {} not yet supported in JIT mode",
- ty
- )))
- }
- };
+ let jit_type = JITType::try_from(ty)?;
Ok(Expr::Identifier(field.qualified_name(), jit_type))
}
@@ -272,12 +266,28 @@ pub const R64: JITType = JITType {
native: ir::types::R64,
code: 0x7f,
};
+pub const PTR_SIZE: usize = std::mem::size_of::<usize>();
/// The pointer type to use based on our currently target.
-pub const PTR: JITType = if std::mem::size_of::<usize>() == 8 {
- R64
-} else {
- R32
-};
+pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 };
+
+impl TryFrom<&DataType> for JITType {
+    type Error = DataFusionError;
+
+    /// Map an arrow [DataType] onto the [JITType] used to represent it in
+    /// JIT-compiled code.
+    ///
+    /// Only `Int64`, `Float32`, `Float64` and `Boolean` are handled; every
+    /// other type yields `DataFusionError::NotImplemented`.
+    fn try_from(df_type: &DataType) -> Result<Self, Self::Error> {
+        let jit_type = match df_type {
+            DataType::Int64 => I64,
+            DataType::Float32 => F32,
+            DataType::Float64 => F64,
+            DataType::Boolean => BOOL,
+
+            _ => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Compiling Expression with type {} not yet supported in JIT mode",
+                    df_type
+                )))
+            }
+        };
+        Ok(jit_type)
+    }
+}
impl Stmt {
/// print the statement with indentation
@@ -323,6 +333,9 @@ impl Stmt {
Stmt::Declare(name, ty) => {
writeln!(f, "{}let {}: {};", ident_str, name, ty)
}
+ Stmt::Store(value, ptr) => {
+ writeln!(f, "{}*({}) = {}", ident_str, ptr, value)
+ }
}
}
}
@@ -352,6 +365,7 @@ impl Display for Expr {
.join(", ")
)
}
+ Expr::Load(ptr, _) => write!(f, "*({})", ptr,),
}
}
}
diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs
new file mode 100644
index 000000000..4e68b5210
--- /dev/null
+++ b/datafusion/jit/src/compile.rs
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Compile DataFusion Expr to JIT'd function.
+
+use datafusion_common::Result;
+
+use crate::api::Assembler;
+use crate::ast::{JITType, I32};
+use crate::{
+ api::GeneratedFunction,
+ ast::{Expr as JITExpr, I64, PTR_SIZE},
+};
+
+/// Wrap a JIT expression into a function that evaluates it element by
+/// element over input arrays.
+///
+/// The generated function takes one `<name>_array` address parameter per
+/// entry of `inputs` (in order), then a `result` parameter, then an `i64`
+/// `len`. Its body loops `index` from 0 while `index < len`; each iteration
+/// loads one value per input into a variable named after that input,
+/// evaluates `jit_expr` into `res`, and stores `res` at `result + offset`.
+///
+/// # Errors
+///
+/// Propagates any error returned by the builder calls used to assemble the
+/// function body.
+pub fn build_calc_fn(
+ assembler: &Assembler,
+ jit_expr: JITExpr,
+ inputs: Vec<(String, JITType)>,
+ ret_type: JITType,
+) -> Result<GeneratedFunction> {
+ // Alias pointer type.
+ // The raw pointer `R64` or `R32` is not compatible with integers.
+ const PTR_TYPE: JITType = if PTR_SIZE == 8 { I64 } else { I32 };
+
+ let mut builder = assembler.new_func_builder("calc_fn");
+ // Declare in-params.
+ // Each input takes one position, followed by a pointer to place the result,
+ // and the last is the length of the input/output arrays.
+ for (name, _) in &inputs {
+ builder = builder.param(format!("{}_array", name), PTR_TYPE);
+ }
+ // NOTE(review): `result` is declared with `ret_type`, yet below it is used as
+ // an address (`result + offset`). That only lines up when `ret_type` equals
+ // `PTR_TYPE` -- presumably it should be `PTR_TYPE`; confirm.
+ let mut builder = builder.param("result", ret_type).param("len", I64);
+
+ // Start building the function body.
+ // It's a loop that calculates the results one by one.
+ let mut fn_body = builder.enter_block();
+ fn_body.declare_as("index", fn_body.lit_i(0))?;
+ fn_body.while_block(
+ |cond| cond.lt(cond.id("index")?, cond.id("len")?),
+ |w| {
+ // NOTE(review): the byte offset uses `PTR_SIZE` as the element stride for
+ // every array, whatever the input/return type -- TODO confirm this is
+ // intended for element types whose width differs from the pointer width.
+ w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?;
+ for (name, ty) in &inputs {
+ w.declare_as(
+ format!("{}_ptr", name),
+ w.add(w.id(format!("{}_array", name))?, w.id("offset")?)?,
+ )?;
+ // Bind the loaded element to the input's own name so that the
+ // identifiers inside `jit_expr` resolve to it.
+ w.declare_as(name, w.load(w.id(format!("{}_ptr", name))?, *ty)?)?;
+ }
+ w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?;
+ w.declare_as("res", jit_expr.clone())?;
+ w.store(w.id("res")?, w.id("res_ptr")?)?;
+
+ w.assign("index", w.add(w.id("index")?, w.lit_i(1))?)?;
+ Ok(())
+ },
+ )?;
+
+ let gen_func = fn_body.build();
+ Ok(gen_func)
+}
+
+#[cfg(test)]
+mod test {
+ use std::{collections::HashMap, sync::Arc};
+
+ use arrow::{
+ array::{Array, PrimitiveArray},
+ datatypes::{DataType, Int64Type},
+ };
+ use datafusion_common::{DFField, DFSchema, DataFusionError};
+ use datafusion_expr::Expr as DFExpr;
+
+ use crate::ast::BinaryExpr;
+
+ use super::*;
+
+    /// Compile `df_expr` against `schema` into a JIT'd array function, run it
+    /// over `lhs` and `rhs`, and return the results as a new array.
+    ///
+    /// Returns `NotImplemented` when either input contains nulls or when the
+    /// two inputs differ in length; other errors come from expression
+    /// translation, function building, or compilation.
+    fn run_df_expr(
+        df_expr: DFExpr,
+        schema: Arc<DFSchema>,
+        lhs: PrimitiveArray<Int64Type>,
+        rhs: PrimitiveArray<Int64Type>,
+    ) -> Result<PrimitiveArray<Int64Type>> {
+        if lhs.null_count() != 0 || rhs.null_count() != 0 {
+            return Err(DataFusionError::NotImplemented(
+                "Computing on nullable array not yet supported".to_string(),
+            ));
+        }
+        if lhs.len() != rhs.len() {
+            return Err(DataFusionError::NotImplemented(
+                "Computing on different length arrays not yet supported".to_string(),
+            ));
+        }
+
+        // translate DF Expr to JIT Expr
+        let input_fields = schema
+            .fields()
+            .iter()
+            .map(|field| {
+                Ok((
+                    field.qualified_name(),
+                    JITType::try_from(field.data_type())?,
+                ))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let jit_expr: JITExpr = (df_expr, schema).try_into()?;
+
+        // allocate memory for calc result; it must be mutable because the
+        // compiled function writes every element through a raw pointer
+        let len = lhs.len();
+        let mut result = vec![0i64; len];
+
+        // compile and run JIT code
+        let assembler = Assembler::default();
+        let gen_func = build_calc_fn(&assembler, jit_expr, input_fields, I64)?;
+        let mut jit = assembler.create_jit();
+        let code_ptr = jit.compile(gen_func)?;
+        // SAFETY: `gen_func` was built by `build_calc_fn` from a two-field
+        // schema, so it takes (input, input, result, len). Both inputs and
+        // `result` hold `len` i64 values (lengths checked above), and the
+        // result buffer is handed over as a mutable pointer.
+        // NOTE(review): assumes the JIT output uses a calling convention
+        // compatible with a Rust `fn` pointer -- confirm against the JIT setup.
+        let code_fn = unsafe {
+            core::mem::transmute::<_, fn(*const i64, *const i64, *mut i64, i64) -> ()>(
+                code_ptr,
+            )
+        };
+        code_fn(
+            lhs.values().as_ptr(),
+            rhs.values().as_ptr(),
+            result.as_mut_ptr(),
+            len as i64,
+        );
+
+        let result_array = PrimitiveArray::<Int64Type>::from_iter(result);
+        Ok(result_array)
+    }
+
+    /// End-to-end check: `a + b` compiled through the JIT must match arrow's
+    /// own `add` kernel on the same inputs.
+    #[test]
+    fn array_add() {
+        // Inputs are a = [1, ..., 10] and b = [11, ..., 20].
+        let left: PrimitiveArray<Int64Type> = PrimitiveArray::from_iter_values(1..=10);
+        let right: PrimitiveArray<Int64Type> = PrimitiveArray::from_iter_values(11..=20);
+        let expected = arrow::compute::kernels::arithmetic::add(&left, &right).unwrap();
+
+        let fields = vec![
+            DFField::new(Some("table1"), "a", DataType::Int64, false),
+            DFField::new(Some("table1"), "b", DataType::Int64, false),
+        ];
+        let schema = Arc::new(DFSchema::new_with_metadata(fields, HashMap::new()).unwrap());
+        let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b");
+
+        let actual = run_df_expr(df_expr, schema, left, right).unwrap();
+        assert_eq!(actual, expected);
+    }
+
+ /// Checks the textual form of the function that `build_calc_fn` generates
+ /// for `table1.a + table1.b` over two `I64` inputs.
+ #[test]
+ fn calc_fn_builder() {
+ let expr = JITExpr::Binary(BinaryExpr::Add(
+ Box::new(JITExpr::Identifier("table1.a".to_string(), I64)),
+ Box::new(JITExpr::Identifier("table1.b".to_string(), I64)),
+ ));
+ let fields = vec![("table1.a".to_string(), I64), ("table1.b".to_string(), I64)];
+
+ // The expected text hard-codes `i64` address parameters and a stride of 8,
+ // i.e. what `build_calc_fn` produces when `PTR_SIZE == 8`.
+ let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () {
let index: i64;
index = 0;
while index < len {
let offset: i64;
offset = index * 8;
let table1.a_ptr: i64;
table1.a_ptr = table1.a_array + offset;
let table1.a: i64;
table1.a = *(table1.a_ptr);
let table1.b_ptr: i64;
table1.b_ptr = table1.b_array + offset;
let table1.b: i64;
table1.b = *(table1.b_ptr);
let res_ptr: i64;
res_ptr = result + offset;
let res: i64;
res = table1.a + table1.b;
*(res_ptr) = res
index = index + 1;
}
}"#;
+
+ let assembler = Assembler::default();
+ let gen_func = build_calc_fn(&assembler, expr, fields, I64).unwrap();
+ // Compare the `Display` rendering of the generated function verbatim.
+ assert_eq!(format!("{}", &gen_func), expected);
+ }
+}
diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs
index 0460cc805..21b0d44fb 100644
--- a/datafusion/jit/src/jit.rs
+++ b/datafusion/jit/src/jit.rs
@@ -263,6 +263,7 @@ impl<'a> FunctionTranslator<'a> {
Ok(())
}
Stmt::Declare(_, _) => Ok(()),
+ Stmt::Store(value, ptr) => self.translate_store(*ptr, *value),
}
}
@@ -289,6 +290,7 @@ impl<'a> FunctionTranslator<'a> {
}
Expr::Binary(b) => self.translate_binary_expr(b),
Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret),
+ Expr::Load(ptr, ty) => self.translate_deref(*ptr, ty),
}
}
@@ -462,6 +464,18 @@ impl<'a> FunctionTranslator<'a> {
Ok(())
}
+    /// Lower a load expression: evaluate `ptr` to an address, then emit a
+    /// load of a `ty.native` value from it with offset 0 and default
+    /// `MemFlags`.
+    fn translate_deref(&mut self, ptr: Expr, ty: JITType) -> Result<Value> {
+        let address = self.translate_expr(ptr)?;
+        let loaded = self
+            .builder
+            .ins()
+            .load(ty.native, MemFlags::new(), address, 0);
+        Ok(loaded)
+    }
+
+    /// Lower a store statement: evaluate the destination address and the
+    /// value, then emit a store with offset 0 and default `MemFlags`.
+    fn translate_store(&mut self, ptr: Expr, value: Expr) -> Result<()> {
+        // The address expression is lowered before the value expression.
+        let address = self.translate_expr(ptr)?;
+        let stored = self.translate_expr(value)?;
+        self.builder
+            .ins()
+            .store(MemFlags::new(), stored, address, 0);
+        Ok(())
+    }
+
fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result<Value> {
let lhs = self.translate_expr(lhs)?;
let rhs = self.translate_expr(rhs)?;
diff --git a/datafusion/jit/src/lib.rs b/datafusion/jit/src/lib.rs
index dff27da31..377d32d8a 100644
--- a/datafusion/jit/src/lib.rs
+++ b/datafusion/jit/src/lib.rs
@@ -19,6 +19,7 @@
pub mod api;
pub mod ast;
+pub mod compile;
pub mod jit;
#[cfg(test)]
diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs
index c05cbcd0e..d77c37063 100644
--- a/datafusion/row/src/lib.rs
+++ b/datafusion/row/src/lib.rs
@@ -30,10 +30,12 @@
//! we append their actual content to the end of the var length region and
//! store their offset relative to row base and their length, packed into an 8-byte word.
//!
+//! ```plaintext
//! ┌────────────────┬──────────────────────────┬───────────────────────┐ ┌───────────────────────┬────────────┐
//! │Validity Bitmask│ Fixed Width Field │ Variable Width Field │ ... │ vardata area │ padding │
//! │ (byte aligned) │ (native type width) │(vardata offset + len) │ │ (variable length) │ bytes │
//! └────────────────┴──────────────────────────┴───────────────────────┘ └───────────────────────┴────────────┘
+//! ```
//!
//! For example, given the schema (Int8, Utf8, Float32, Utf8)
//!
@@ -41,10 +43,12 @@
//!
//! Requires 32 bytes (31 bytes payload and 1 byte padding to make each tuple 8-bytes aligned):
//!
+//! ```plaintext
//! ┌──────────┬──────────┬──────────────────────┬──────────────┬──────────────────────┬───────────────────────┬──────────┐
//! │0b00001011│ 0x01 │0x00000016 0x00000006│ 0x00000000 │0x0000001C 0x00000003│ FooBarbaz │ 0x00 │
//! └──────────┴──────────┴──────────────────────┴──────────────┴──────────────────────┴───────────────────────┴──────────┘
//! 0 1 2 10 14 22 31 32
+//! ```
//!
use arrow::array::{make_builder, ArrayBuilder, ArrayRef};