You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/05/24 14:35:20 UTC

[arrow-datafusion] branch master updated: Evaluate JIT'd expression over arrays (#2587)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a9cc38a89 Evaluate JIT'd expression over arrays (#2587)
a9cc38a89 is described below

commit a9cc38a89ae8194caf0ce7dd54472a0428828e7c
Author: Ruihang Xia <wa...@gmail.com>
AuthorDate: Tue May 24 22:35:14 2022 +0800

    Evaluate JIT'd expression over arrays (#2587)
    
    * pipeline expr wrapper
    
    * clean up
    
    Signed-off-by: Ruihang Xia <wa...@gmail.com>
    
    * Apply suggestions from code review
    
    Add doc for `deref()` and `store()`
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * CR improvement: doc, naming and hardcode
    
    Signed-off-by: Ruihang Xia <wa...@gmail.com>
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/jit/src/api.rs     |  13 +++
 datafusion/jit/src/ast.rs     |  50 ++++++----
 datafusion/jit/src/compile.rs | 207 ++++++++++++++++++++++++++++++++++++++++++
 datafusion/jit/src/jit.rs     |  14 +++
 datafusion/jit/src/lib.rs     |   1 +
 datafusion/row/src/lib.rs     |   4 +
 6 files changed, 271 insertions(+), 18 deletions(-)

diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs
index d95f9ccc7..7020985a7 100644
--- a/datafusion/jit/src/api.rs
+++ b/datafusion/jit/src/api.rs
@@ -153,6 +153,7 @@ impl FunctionBuilder {
     }
 
     /// Add one more parameter to the function.
+    #[must_use]
     pub fn param(mut self, name: impl Into<String>, ty: JITType) -> Self {
         let name = name.into();
         assert!(!self.fields.back().unwrap().contains_key(&name));
@@ -163,6 +164,7 @@ impl FunctionBuilder {
 
     /// Set return type for the function. Functions are of `void` type by default if
     /// you do not set the return type.
+    #[must_use]
     pub fn ret(mut self, name: impl Into<String>, ty: JITType) -> Self {
         let name = name.into();
         assert!(!self.fields.back().unwrap().contains_key(&name));
@@ -604,6 +606,17 @@ impl<'a> CodeBlock<'a> {
             internal_err!("No func with the name {} exist", fn_name)
         }
     }
+
+    /// Return the value pointed to by the ptr stored in `ptr`
+    pub fn load(&self, ptr: Expr, ty: JITType) -> Result<Expr> {
+        Ok(Expr::Load(Box::new(ptr), ty))
+    }
+
+    /// Store the value in `value` to the address in `ptr`
+    pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> {
+        self.stmts.push(Stmt::Store(Box::new(value), Box::new(ptr)));
+        Ok(())
+    }
 }
 
 impl Display for GeneratedFunction {
diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs
index fd10a909e..55731a650 100644
--- a/datafusion/jit/src/ast.rs
+++ b/datafusion/jit/src/ast.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::datatypes::DataType;
 use cranelift::codegen::ir;
 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
 use std::fmt::{Display, Formatter};
@@ -32,6 +33,8 @@ pub enum Stmt {
     Call(String, Vec<Expr>),
     /// declare a new variable of type
     Declare(String, JITType),
+    /// store value (the first expr) to an address (the second expr)
+    Store(Box<Expr>, Box<Expr>),
 }
 
 #[derive(Copy, Clone, Debug, PartialEq)]
@@ -54,6 +57,8 @@ pub enum Expr {
     Binary(BinaryExpr),
     /// call function expression
     Call(String, Vec<Expr>, JITType),
+    /// Load a value from pointer
+    Load(Box<Expr>, JITType),
 }
 
 impl Expr {
@@ -63,6 +68,7 @@ impl Expr {
             Expr::Identifier(_, ty) => *ty,
             Expr::Binary(bin) => bin.get_type(),
             Expr::Call(_, _, ty) => *ty,
+            Expr::Load(_, ty) => *ty,
         }
     }
 }
@@ -174,19 +180,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr {
                 let field = schema.field_from_column(col)?;
                 let ty = field.data_type();
 
-                let jit_type = match ty {
-                    arrow::datatypes::DataType::Int64 => I64,
-                    arrow::datatypes::DataType::Float32 => F32,
-                    arrow::datatypes::DataType::Float64 => F64,
-                    arrow::datatypes::DataType::Boolean => BOOL,
-
-                    _ => {
-                        return Err(DataFusionError::NotImplemented(format!(
-                        "Compiling Expression with type {} not yet supported in JIT mode",
-                        ty
-                    )))
-                    }
-                };
+                let jit_type = JITType::try_from(ty)?;
 
                 Ok(Expr::Identifier(field.qualified_name(), jit_type))
             }
@@ -272,12 +266,28 @@ pub const R64: JITType = JITType {
     native: ir::types::R64,
     code: 0x7f,
 };
+pub const PTR_SIZE: usize = std::mem::size_of::<usize>();
 /// The pointer type to use based on our currently target.
-pub const PTR: JITType = if std::mem::size_of::<usize>() == 8 {
-    R64
-} else {
-    R32
-};
+pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 };
+
+impl TryFrom<&DataType> for JITType {
+    type Error = DataFusionError;
+
+    /// Try to convert Arrow's [DataType] to [JITType]
+    fn try_from(df_type: &DataType) -> Result<Self, Self::Error> {
+        match df_type {
+            DataType::Int64 => Ok(I64),
+            DataType::Float32 => Ok(F32),
+            DataType::Float64 => Ok(F64),
+            DataType::Boolean => Ok(BOOL),
+
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "Compiling Expression with type {} not yet supported in JIT mode",
+                df_type
+            ))),
+        }
+    }
+}
 
 impl Stmt {
     /// print the statement with indentation
@@ -323,6 +333,9 @@ impl Stmt {
             Stmt::Declare(name, ty) => {
                 writeln!(f, "{}let {}: {};", ident_str, name, ty)
             }
+            Stmt::Store(value, ptr) => {
+                writeln!(f, "{}*({}) = {}", ident_str, ptr, value)
+            }
         }
     }
 }
@@ -352,6 +365,7 @@ impl Display for Expr {
                         .join(", ")
                 )
             }
+            Expr::Load(ptr, _) => write!(f, "*({})", ptr,),
         }
     }
 }
diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs
new file mode 100644
index 000000000..4e68b5210
--- /dev/null
+++ b/datafusion/jit/src/compile.rs
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Compile DataFusion Expr to JIT'd function.
+
+use datafusion_common::Result;
+
+use crate::api::Assembler;
+use crate::ast::{JITType, I32};
+use crate::{
+    api::GeneratedFunction,
+    ast::{Expr as JITExpr, I64, PTR_SIZE},
+};
+
+/// Wrap JIT Expr to array compute function.
+pub fn build_calc_fn(
+    assembler: &Assembler,
+    jit_expr: JITExpr,
+    inputs: Vec<(String, JITType)>,
+    ret_type: JITType,
+) -> Result<GeneratedFunction> {
+    // Alias pointer type.
+    // The raw pointer `R64` or `R32` is not compatible with integers.
+    const PTR_TYPE: JITType = if PTR_SIZE == 8 { I64 } else { I32 };
+
+    let mut builder = assembler.new_func_builder("calc_fn");
+    // Declare in-param.
+    // Each input takes one position, followed by a pointer to place the result,
+    // and the last is the length of the input/output arrays.
+    for (name, _) in &inputs {
+        builder = builder.param(format!("{}_array", name), PTR_TYPE);
+    }
+    let mut builder = builder.param("result", ret_type).param("len", I64);
+
+    // Start build function body.
+    // It's a loop that calculates the result one element at a time.
+    let mut fn_body = builder.enter_block();
+    fn_body.declare_as("index", fn_body.lit_i(0))?;
+    fn_body.while_block(
+        |cond| cond.lt(cond.id("index")?, cond.id("len")?),
+        |w| {
+            w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?;
+            for (name, ty) in &inputs {
+                w.declare_as(
+                    format!("{}_ptr", name),
+                    w.add(w.id(format!("{}_array", name))?, w.id("offset")?)?,
+                )?;
+                w.declare_as(name, w.load(w.id(format!("{}_ptr", name))?, *ty)?)?;
+            }
+            w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?;
+            w.declare_as("res", jit_expr.clone())?;
+            w.store(w.id("res")?, w.id("res_ptr")?)?;
+
+            w.assign("index", w.add(w.id("index")?, w.lit_i(1))?)?;
+            Ok(())
+        },
+    )?;
+
+    let gen_func = fn_body.build();
+    Ok(gen_func)
+}
+
+#[cfg(test)]
+mod test {
+    use std::{collections::HashMap, sync::Arc};
+
+    use arrow::{
+        array::{Array, PrimitiveArray},
+        datatypes::{DataType, Int64Type},
+    };
+    use datafusion_common::{DFField, DFSchema, DataFusionError};
+    use datafusion_expr::Expr as DFExpr;
+
+    use crate::ast::BinaryExpr;
+
+    use super::*;
+
+    fn run_df_expr(
+        df_expr: DFExpr,
+        schema: Arc<DFSchema>,
+        lhs: PrimitiveArray<Int64Type>,
+        rhs: PrimitiveArray<Int64Type>,
+    ) -> Result<PrimitiveArray<Int64Type>> {
+        if lhs.null_count() != 0 || rhs.null_count() != 0 {
+            return Err(DataFusionError::NotImplemented(
+                "Computing on nullable array not yet supported".to_string(),
+            ));
+        }
+        if lhs.len() != rhs.len() {
+            return Err(DataFusionError::NotImplemented(
+                "Computing on different length arrays not yet supported".to_string(),
+            ));
+        }
+
+        // translate DF Expr to JIT Expr; each schema field becomes one `(name, JITType)` input of the calc fn
+        let input_fields = schema
+            .fields()
+            .iter()
+            .map(|field| {
+                Ok((
+                    field.qualified_name(),
+                    JITType::try_from(field.data_type())?,
+                ))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let jit_expr: JITExpr = (df_expr, schema).try_into()?;
+
+        // allocate the output buffer; it is `mut` because the JIT'd code stores into it through a raw pointer
+        let len = lhs.len();
+        let mut result = vec![0i64; len];
+
+        // compile and run JIT code; the signature below assumes a two-field schema: (lhs, rhs, result, len)
+        let assembler = Assembler::default();
+        let gen_func = build_calc_fn(&assembler, jit_expr, input_fields, I64)?;
+        let mut jit = assembler.create_jit();
+        let code_ptr = jit.compile(gen_func)?;
+        let code_fn = unsafe {
+            core::mem::transmute::<_, fn(*const i64, *const i64, *mut i64, i64) -> ()>(
+                code_ptr,
+            )
+        };
+        code_fn(
+            lhs.values().as_ptr(),
+            rhs.values().as_ptr(),
+            result.as_mut_ptr(),
+            len as i64,
+        );
+
+        let result_array = PrimitiveArray::<Int64Type>::from_iter(result);
+        Ok(result_array)
+    }
+
+    #[test]
+    fn array_add() {
+        let array_a: PrimitiveArray<Int64Type> =
+            PrimitiveArray::from_iter_values((0..10).map(|x| x + 1));
+        let array_b: PrimitiveArray<Int64Type> =
+            PrimitiveArray::from_iter_values((10..20).map(|x| x + 1));
+        let expected =
+            arrow::compute::kernels::arithmetic::add(&array_a, &array_b).unwrap();
+
+        let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b");
+        let schema = Arc::new(
+            DFSchema::new_with_metadata(
+                vec![
+                    DFField::new(Some("table1"), "a", DataType::Int64, false),
+                    DFField::new(Some("table1"), "b", DataType::Int64, false),
+                ],
+                HashMap::new(),
+            )
+            .unwrap(),
+        );
+
+        let result = run_df_expr(df_expr, schema, array_a, array_b).unwrap();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn calc_fn_builder() {
+        let expr = JITExpr::Binary(BinaryExpr::Add(
+            Box::new(JITExpr::Identifier("table1.a".to_string(), I64)),
+            Box::new(JITExpr::Identifier("table1.b".to_string(), I64)),
+        ));
+        let fields = vec![("table1.a".to_string(), I64), ("table1.b".to_string(), I64)];
+
+        let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () {
+    let index: i64;
+    index = 0;
+    while index < len {
+        let offset: i64;
+        offset = index * 8;
+        let table1.a_ptr: i64;
+        table1.a_ptr = table1.a_array + offset;
+        let table1.a: i64;
+        table1.a = *(table1.a_ptr);
+        let table1.b_ptr: i64;
+        table1.b_ptr = table1.b_array + offset;
+        let table1.b: i64;
+        table1.b = *(table1.b_ptr);
+        let res_ptr: i64;
+        res_ptr = result + offset;
+        let res: i64;
+        res = table1.a + table1.b;
+        *(res_ptr) = res
+        index = index + 1;
+    }
+}"#;
+
+        let assembler = Assembler::default();
+        let gen_func = build_calc_fn(&assembler, expr, fields, I64).unwrap();
+        assert_eq!(format!("{}", &gen_func), expected);
+    }
+}
diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs
index 0460cc805..21b0d44fb 100644
--- a/datafusion/jit/src/jit.rs
+++ b/datafusion/jit/src/jit.rs
@@ -263,6 +263,7 @@ impl<'a> FunctionTranslator<'a> {
                 Ok(())
             }
             Stmt::Declare(_, _) => Ok(()),
+            Stmt::Store(value, ptr) => self.translate_store(*ptr, *value),
         }
     }
 
@@ -289,6 +290,7 @@ impl<'a> FunctionTranslator<'a> {
             }
             Expr::Binary(b) => self.translate_binary_expr(b),
             Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret),
+            Expr::Load(ptr, ty) => self.translate_deref(*ptr, ty),
         }
     }
 
@@ -462,6 +464,18 @@ impl<'a> FunctionTranslator<'a> {
         Ok(())
     }
 
+    fn translate_deref(&mut self, ptr: Expr, ty: JITType) -> Result<Value> {
+        let ptr = self.translate_expr(ptr)?;
+        Ok(self.builder.ins().load(ty.native, MemFlags::new(), ptr, 0))
+    }
+
+    fn translate_store(&mut self, ptr: Expr, value: Expr) -> Result<()> {
+        let ptr = self.translate_expr(ptr)?;
+        let value = self.translate_expr(value)?;
+        self.builder.ins().store(MemFlags::new(), value, ptr, 0);
+        Ok(())
+    }
+
     fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result<Value> {
         let lhs = self.translate_expr(lhs)?;
         let rhs = self.translate_expr(rhs)?;
diff --git a/datafusion/jit/src/lib.rs b/datafusion/jit/src/lib.rs
index dff27da31..377d32d8a 100644
--- a/datafusion/jit/src/lib.rs
+++ b/datafusion/jit/src/lib.rs
@@ -19,6 +19,7 @@
 
 pub mod api;
 pub mod ast;
+pub mod compile;
 pub mod jit;
 
 #[cfg(test)]
diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs
index c05cbcd0e..d77c37063 100644
--- a/datafusion/row/src/lib.rs
+++ b/datafusion/row/src/lib.rs
@@ -30,10 +30,12 @@
 //!       we append their actual content to the end of the var length region and
 //!       store their offset relative to row base and their length, packed into an 8-byte word.
 //!
+//! ```plaintext
 //! ┌────────────────┬──────────────────────────┬───────────────────────┐        ┌───────────────────────┬────────────┐
 //! │Validity Bitmask│    Fixed Width Field     │ Variable Width Field  │   ...  │     vardata area      │  padding   │
 //! │ (byte aligned) │   (native type width)    │(vardata offset + len) │        │   (variable length)   │   bytes    │
 //! └────────────────┴──────────────────────────┴───────────────────────┘        └───────────────────────┴────────────┘
+//! ```
 //!
 //!  For example, given the schema (Int8, Utf8, Float32, Utf8)
 //!
@@ -41,10 +43,12 @@
 //!
 //!  Requires 32 bytes (31 bytes payload and 1 byte padding to make each tuple 8-bytes aligned):
 //!
+//! ```plaintext
 //! ┌──────────┬──────────┬──────────────────────┬──────────────┬──────────────────────┬───────────────────────┬──────────┐
 //! │0b00001011│   0x01   │0x00000016  0x00000006│  0x00000000  │0x0000001C  0x00000003│       FooBarbaz       │   0x00   │
 //! └──────────┴──────────┴──────────────────────┴──────────────┴──────────────────────┴───────────────────────┴──────────┘
 //! 0          1          2                     10              14                     22                     31         32
+//! ```
 //!
 
 use arrow::array::{make_builder, ArrayBuilder, ArrayRef};