You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/05/16 22:41:45 UTC

git commit: SPARK-1487 [SQL] Support record filtering via predicate pushdown in Parquet

Repository: spark
Updated Branches:
  refs/heads/master 032d6632a -> 40d6acd6b


SPARK-1487 [SQL] Support record filtering via predicate pushdown in Parquet

Simple filter predicates such as LessThan, GreaterThan, etc., where one side is a literal and the other one a NamedExpression are now pushed down to the underlying ParquetTableScan. Here are some results for a microbenchmark with a simple schema of six fields of different types where most records failed the test:

             | Uncompressed  | Compressed
-------------| ------------- | -------------
File size  |     10 GB  | 2 GB
Speedup |      2         | 1.8

Since mileage may vary I added a new option to SparkConf:

`spark.sql.hints.parquetFilterPushdown`

The default value is `true`; setting it to `false` disables the pushdown. When most rows are expected to pass the filter, or when there are only a few fields, performance can be better with pushdown disabled. The default should fit situations with a reasonable number of (possibly nested) fields where, on average, not too many records pass the filter.

Because of an issue with Parquet ([see here](https://github.com/Parquet/parquet-mr/issues/371)), currently only predicates on non-nullable attributes are pushed down. If one knew that no optional fields have missing values for a given table, one could also allow overriding this.

Author: Andre Schumacher <an...@iki.fi>

Closes #511 from AndreSchumacher/parquet_filter and squashes the following commits:

16bfe83 [Andre Schumacher] Removing leftovers from merge during rebase
7b304ca [Andre Schumacher] Fixing formatting
c36d5cb [Andre Schumacher] Scalastyle
3da98db [Andre Schumacher] Second round of review feedback
7a78265 [Andre Schumacher] Fixing broken formatting in ParquetFilter
a86553b [Andre Schumacher] First round of code review feedback
b0f7806 [Andre Schumacher] Optimizing imports in ParquetTestData
85fea2d [Andre Schumacher] Adding SparkConf setting to disable filter predicate pushdown
f0ad3cf [Andre Schumacher] Undoing changes not needed for this PR
210e9cb [Andre Schumacher] Adding disjunctive filter predicates
a93a588 [Andre Schumacher] Adding unit test for filtering
6d22666 [Andre Schumacher] Extending ParquetFilters
93e8192 [Andre Schumacher] First commit Parquet record filtering


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40d6acd6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40d6acd6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40d6acd6

Branch: refs/heads/master
Commit: 40d6acd6ba2feccc600301f5c47d4f90157138b1
Parents: 032d663
Author: Andre Schumacher <an...@iki.fi>
Authored: Fri May 16 13:41:41 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Fri May 16 13:41:41 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  31 +-
 .../spark/sql/parquet/ParquetFilters.scala      | 436 +++++++++++++++++++
 .../sql/parquet/ParquetTableOperations.scala    |  90 +++-
 .../spark/sql/parquet/ParquetTestData.scala     |  90 +++-
 .../spark/sql/parquet/ParquetQuerySuite.scala   | 135 +++++-
 5 files changed, 731 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40d6acd6/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f763106..394a597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,12 +140,35 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
         InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
-      case PhysicalOperation(projectList, filters, relation: ParquetRelation) =>
-        // TODO: Should be pushing down filters as well.
+      case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => {
+        val remainingFilters =
+          if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+            filters.filter {
+              // Note: filters cannot be pushed down to Parquet if they contain more complex
+              // expressions than simple "Attribute cmp Literal" comparisons. Here we remove
+              // all filters that have been pushed down. Note that a predicate such as
+              // "(A AND B) OR C" can result in "A OR C" being pushed down.
+              filter =>
+                val recordFilter = ParquetFilters.createFilter(filter)
+                if (!recordFilter.isDefined) {
+                  // First case: the pushdown did not result in any record filter.
+                  true
+                } else {
+                  // Second case: a record filter was created; here we are conservative in
+                  // the sense that even if "A" was pushed and we check for "A AND B" we
+                  // still want to keep "A AND B" in the higher-level filter, not just "B".
+                  !ParquetFilters.findExpression(recordFilter.get, filter).isDefined
+                }
+            }
+          } else {
+            filters
+          }
         pruneFilterProject(
           projectList,
-          filters,
-          ParquetTableScan(_, relation, None)(sparkContext)) :: Nil
+          remainingFilters,
+          ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+      }
+
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/40d6acd6/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
new file mode 100644
index 0000000..052b0a9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.apache.hadoop.conf.Configuration
+
+import parquet.filter._
+import parquet.filter.ColumnPredicates._
+import parquet.column.ColumnReader
+
+import com.google.common.io.BaseEncoding
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+
+object ParquetFilters {
+  val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
+  // set this to false if pushdown should be disabled
+  val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown"
+
+  def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = {
+    val filters: Seq[CatalystFilter] = filterExpressions.collect {
+      case (expression: Expression) if createFilter(expression).isDefined =>
+        createFilter(expression).get
+    }
+    if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null
+  }
+
+  def createFilter(expression: Expression): Option[CatalystFilter] = {
+    def createEqualityFilter(
+        name: String,
+        literal: Literal,
+        predicate: CatalystPredicate) = literal.dataType match {
+      case BooleanType =>
+        ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate)
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(
+          name,
+          (x: Int) => x == literal.value.asInstanceOf[Int],
+          predicate)
+      case LongType =>
+        ComparisonFilter.createLongFilter(
+          name,
+          (x: Long) => x == literal.value.asInstanceOf[Long],
+          predicate)
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x == literal.value.asInstanceOf[Double],
+          predicate)
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x == literal.value.asInstanceOf[Float],
+          predicate)
+      case StringType =>
+        ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate)
+    }
+    def createLessThanFilter(
+        name: String,
+        literal: Literal,
+        predicate: CatalystPredicate) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(
+          name,
+          (x: Int) => x < literal.value.asInstanceOf[Int],
+          predicate)
+      case LongType =>
+        ComparisonFilter.createLongFilter(
+          name,
+          (x: Long) => x < literal.value.asInstanceOf[Long],
+          predicate)
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x < literal.value.asInstanceOf[Double],
+          predicate)
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x < literal.value.asInstanceOf[Float],
+          predicate)
+    }
+    def createLessThanOrEqualFilter(
+        name: String,
+        literal: Literal,
+        predicate: CatalystPredicate) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(
+          name,
+          (x: Int) => x <= literal.value.asInstanceOf[Int],
+          predicate)
+      case LongType =>
+        ComparisonFilter.createLongFilter(
+          name,
+          (x: Long) => x <= literal.value.asInstanceOf[Long],
+          predicate)
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x <= literal.value.asInstanceOf[Double],
+          predicate)
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x <= literal.value.asInstanceOf[Float],
+          predicate)
+    }
+    // TODO: combine these two types somehow?
+    def createGreaterThanFilter(
+        name: String,
+        literal: Literal,
+        predicate: CatalystPredicate) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(
+          name,
+          (x: Int) => x > literal.value.asInstanceOf[Int],
+          predicate)
+      case LongType =>
+        ComparisonFilter.createLongFilter(
+          name,
+          (x: Long) => x > literal.value.asInstanceOf[Long],
+          predicate)
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x > literal.value.asInstanceOf[Double],
+          predicate)
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x > literal.value.asInstanceOf[Float],
+          predicate)
+    }
+    def createGreaterThanOrEqualFilter(
+        name: String,
+        literal: Literal,
+        predicate: CatalystPredicate) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(
+          name, (x: Int) => x >= literal.value.asInstanceOf[Int],
+          predicate)
+      case LongType =>
+        ComparisonFilter.createLongFilter(
+          name,
+          (x: Long) => x >= literal.value.asInstanceOf[Long],
+          predicate)
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x >= literal.value.asInstanceOf[Double],
+          predicate)
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x >= literal.value.asInstanceOf[Float],
+          predicate)
+    }
+
+    /**
+     * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
+     * https://github.com/Parquet/parquet-mr/issues/371
+     * has been resolved.
+     */
+    expression match {
+      case p @ Or(left: Expression, right: Expression)
+          if createFilter(left).isDefined && createFilter(right).isDefined => {
+        // If either side of this Or-predicate is empty then this means
+        // it contains a more complex comparison than between attribute and literal
+        // (e.g., it contained a CAST). The only safe thing to do is then to disregard
+        // this disjunction, which could be contained in a conjunction. If it stands
+        // alone then it is also safe to drop it, since a `None` return value of this
+        // function is interpreted as having no filters at all.
+        val leftFilter = createFilter(left).get
+        val rightFilter = createFilter(right).get
+        Some(new OrFilter(leftFilter, rightFilter))
+      }
+      case p @ And(left: Expression, right: Expression) => {
+        // This treats nested conjunctions; since either side of the conjunction
+        // may contain more complex filter expressions we may actually generate
+        // strictly weaker filter predicates in the process.
+        val leftFilter = createFilter(left)
+        val rightFilter = createFilter(right)
+        (leftFilter, rightFilter) match {
+          case (None, Some(filter)) => Some(filter)
+          case (Some(filter), None) => Some(filter)
+          case (_, _) =>
+            Some(new AndFilter(leftFilter.get, rightFilter.get))
+        }
+      }
+      case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable =>
+        Some(createEqualityFilter(right.name, left, p))
+      case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable =>
+        Some(createEqualityFilter(left.name, right, p))
+      case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable =>
+        Some(createLessThanFilter(right.name, left, p))
+      case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable =>
+        Some(createLessThanFilter(left.name, right, p))
+      case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+        Some(createLessThanOrEqualFilter(right.name, left, p))
+      case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+        Some(createLessThanOrEqualFilter(left.name, right, p))
+      case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable =>
+        Some(createGreaterThanFilter(right.name, left, p))
+      case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable =>
+        Some(createGreaterThanFilter(left.name, right, p))
+      case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+        Some(createGreaterThanOrEqualFilter(right.name, left, p))
+      case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+        Some(createGreaterThanOrEqualFilter(left.name, right, p))
+      case _ => None
+    }
+  }
+
+  /**
+   * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+   * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+   * the actual filter predicate.
+   */
+  def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = {
+    if (filters.length > 0) {
+      val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters)
+      val encoded: String = BaseEncoding.base64().encode(serialized)
+      conf.set(PARQUET_FILTER_DATA, encoded)
+    }
+  }
+
+  /**
+   * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+   * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+   * the actual filter predicate.
+   */
+  def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = {
+    val data = conf.get(PARQUET_FILTER_DATA)
+    if (data != null) {
+      val decoded: Array[Byte] = BaseEncoding.base64().decode(data)
+      SparkSqlSerializer.deserialize(decoded)
+    } else {
+      Seq()
+    }
+  }
+
+  /**
+   * Try to find the given expression in the tree of filters in order to
+   * determine whether it is safe to remove it from the higher level filters. Note
+   * that strictly speaking we could stop the search whenever an expression is found
+   * that contains this expression as subexpression (e.g., when searching for "a"
+   * and "(a or c)" is found) but we don't care about optimizations here since the
+   * filter tree is assumed to be small.
+   *
+   * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand
+   *               and search
+   * @param expression The expression to look for
+   * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that
+   *         contains the expression.
+   */
+  def findExpression(
+      filter: CatalystFilter,
+      expression: Expression): Option[CatalystFilter] = filter match {
+    case f @ OrFilter(_, leftFilter, rightFilter, _) =>
+      if (f.predicate == expression) {
+        Some(f)
+      } else {
+        val left = findExpression(leftFilter, expression)
+        if (left.isDefined) left else findExpression(rightFilter, expression)
+      }
+    case f @ AndFilter(_, leftFilter, rightFilter, _) =>
+      if (f.predicate == expression) {
+        Some(f)
+      } else {
+        val left = findExpression(leftFilter, expression)
+        if (left.isDefined) left else findExpression(rightFilter, expression)
+      }
+    case f @ ComparisonFilter(_, _, predicate) =>
+      if (predicate == expression) Some(f) else None
+    case _ => None
+  }
+}
+
+abstract private[parquet] class CatalystFilter(
+    @transient val predicate: CatalystPredicate) extends UnboundRecordFilter
+
+private[parquet] case class ComparisonFilter(
+    val columnName: String,
+    private var filter: UnboundRecordFilter,
+    @transient override val predicate: CatalystPredicate)
+  extends CatalystFilter(predicate) {
+  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+    filter.bind(readers)
+  }
+}
+
+private[parquet] case class OrFilter(
+    private var filter: UnboundRecordFilter,
+    @transient val left: CatalystFilter,
+    @transient val right: CatalystFilter,
+    @transient override val predicate: Or)
+  extends CatalystFilter(predicate) {
+  def this(l: CatalystFilter, r: CatalystFilter) =
+    this(
+      OrRecordFilter.or(l, r),
+      l,
+      r,
+      Or(l.predicate, r.predicate))
+
+  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+    filter.bind(readers)
+  }
+}
+
+private[parquet] case class AndFilter(
+    private var filter: UnboundRecordFilter,
+    @transient val left: CatalystFilter,
+    @transient val right: CatalystFilter,
+    @transient override val predicate: And)
+  extends CatalystFilter(predicate) {
+  def this(l: CatalystFilter, r: CatalystFilter) =
+    this(
+      AndRecordFilter.and(l, r),
+      l,
+      r,
+      And(l.predicate, r.predicate))
+
+  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+    filter.bind(readers)
+  }
+}
+
+private[parquet] object ComparisonFilter {
+  def createBooleanFilter(
+      columnName: String,
+      value: Boolean,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToBoolean(
+          new BooleanPredicateFunction {
+            def functionToApply(input: Boolean): Boolean = input == value
+          }
+      )),
+      predicate)
+
+  def createStringFilter(
+      columnName: String,
+      value: String,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToString (
+          new ColumnPredicates.PredicateFunction[String]  {
+            def functionToApply(input: String): Boolean = input == value
+          }
+      )),
+      predicate)
+
+  def createIntFilter(
+      columnName: String,
+      func: Int => Boolean,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToInteger(
+          new IntegerPredicateFunction {
+            def functionToApply(input: Int) = func(input)
+          }
+      )),
+      predicate)
+
+  def createLongFilter(
+      columnName: String,
+      func: Long => Boolean,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToLong(
+          new LongPredicateFunction {
+            def functionToApply(input: Long) = func(input)
+          }
+      )),
+      predicate)
+
+  def createDoubleFilter(
+      columnName: String,
+      func: Double => Boolean,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToDouble(
+          new DoublePredicateFunction {
+            def functionToApply(input: Double) = func(input)
+          }
+      )),
+      predicate)
+
+  def createFloatFilter(
+      columnName: String,
+      func: Float => Boolean,
+      predicate: CatalystPredicate): CatalystFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToFloat(
+          new FloatPredicateFunction {
+            def functionToApply(input: Float) = func(input)
+          }
+      )),
+      predicate)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/40d6acd6/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f825ca3..65ba124 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -27,26 +27,27 @@ import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter}
 
-import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.api.ReadSupport
 import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType
 
-import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 
 /**
  * Parquet table scan operator. Imports the file that backs the given
- * [[ParquetRelation]] as a RDD[Row].
+ * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``.
  */
 case class ParquetTableScan(
     // note: output cannot be transient, see
     // https://issues.apache.org/jira/browse/SPARK-1367
     output: Seq[Attribute],
     relation: ParquetRelation,
-    columnPruningPred: Option[Expression])(
+    columnPruningPred: Seq[Expression])(
     @transient val sc: SparkContext)
   extends LeafNode {
 
@@ -62,18 +63,30 @@ case class ParquetTableScan(
     for (path <- fileList if !path.getName.startsWith("_")) {
       NewFileInputFormat.addInputPath(job, path)
     }
+
+    // Store Parquet schema in `Configuration`
     conf.set(
         RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
         ParquetTypesConverter.convertFromAttributes(output).toString)
-    // TODO: think about adding record filters
-    /* Comments regarding record filters: it would be nice to push down as much filtering
-      to Parquet as possible. However, currently it seems we cannot pass enough information
-      to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's
-      ``FilteredRecordReader`` (via Configuration, for example). Simple
-      filter-rows-by-column-values however should be supported.
-    */
-    sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row])
-    .map(_._2)
+
+    // Store record filtering predicate in `Configuration`
+    // Note 1: the input format ignores all predicates that cannot be expressed
+    // as simple column predicate filters in Parquet. Here we just record
+    // the whole pruning predicate.
+    // Note 2: you can disable filter predicate pushdown by setting
+    // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf.
+    if (columnPruningPred.length > 0 &&
+      sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+      ParquetFilters.serializeFilterExpressions(columnPruningPred, conf)
+    }
+
+    sc.newAPIHadoopRDD(
+      conf,
+      classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat],
+      classOf[Void],
+      classOf[Row])
+      .map(_._2)
+      .filter(_ != null) // Parquet's record filters may produce null values
   }
 
   override def otherCopyArgs = sc :: Nil
@@ -184,10 +197,19 @@ case class InsertIntoParquetTable(
 
   override def otherCopyArgs = sc :: Nil
 
-  // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]]
-  // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2?
-  // .. then we could use the default one and could use [[MutablePair]]
-  // instead of ``Tuple2``
+  /**
+   * Stores the given Row RDD as a Hadoop file.
+   *
+   * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]]
+   * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses
+   * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing
+   * directory and need to determine which was the largest written file index before starting to
+   * write.
+   *
+   * @param rdd The [[org.apache.spark.rdd.RDD]] to write
+   * @param path The directory to write to.
+   * @param conf A [[org.apache.hadoop.conf.Configuration]].
+   */
   private def saveAsHadoopFile(
       rdd: RDD[Row],
       path: String,
@@ -244,8 +266,10 @@ case class InsertIntoParquetTable(
   }
 }
 
-// TODO: this will be able to append to directories it created itself, not necessarily
-// to imported ones
+/**
+ * TODO: this will be able to append to directories it created itself, not necessarily
+ * to imported ones.
+ */
 private[parquet] class AppendingParquetOutputFormat(offset: Int)
   extends parquet.hadoop.ParquetOutputFormat[Row] {
   // override to accept existing directories as valid output directory
@@ -262,6 +286,30 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
   }
 }
 
+/**
+ * We extend ParquetInputFormat in order to have more control over which
+ * RecordFilter we want to use.
+ */
+private[parquet] class FilteringParquetRowInputFormat
+  extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
+  override def createRecordReader(
+      inputSplit: InputSplit,
+      taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
+    val readSupport: ReadSupport[Row] = new RowReadSupport()
+
+    val filterExpressions =
+      ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
+    if (filterExpressions.length > 0) {
+      logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}")
+      new ParquetRecordReader[Row](
+        readSupport,
+        ParquetFilters.createRecordFilter(filterExpressions))
+    } else {
+      new ParquetRecordReader[Row](readSupport)
+    }
+  }
+}
+
 private[parquet] object FileSystemHelper {
   def listFiles(pathStr: String, conf: Configuration): Seq[Path] = {
     val origPath = new Path(pathStr)
@@ -278,7 +326,9 @@ private[parquet] object FileSystemHelper {
     fs.listStatus(path).map(_.getPath)
   }
 
-  // finds the maximum taskid in the output file names at the given path
+    /**
+     * Finds the maximum taskid in the output file names at the given path.
+     */
   def findMaxTaskId(pathStr: String, conf: Configuration): Int = {
     val files = FileSystemHelper.listFiles(pathStr, conf)
     // filename pattern is part-r-<int>.parquet

http://git-wip-us.apache.org/repos/asf/spark/blob/40d6acd6/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index f37976f..46c7172 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -19,15 +19,34 @@ package org.apache.spark.sql.parquet
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.Job
 
+import parquet.example.data.{GroupWriter, Group}
+import parquet.example.data.simple.SimpleGroup
 import parquet.hadoop.ParquetWriter
-import parquet.hadoop.util.ContextUtil
+import parquet.hadoop.api.WriteSupport
+import parquet.hadoop.api.WriteSupport.WriteContext
+import parquet.io.api.RecordConsumer
 import parquet.schema.{MessageType, MessageTypeParser}
 
-import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.util.Utils
 
+// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
+// with an empty configuration (it is after all not intended to be used in this way?)
+// and members are private so we need to make our own in order to pass the schema
+// to the writer.
+private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
+  var groupWriter: GroupWriter = null
+  override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+    groupWriter = new GroupWriter(recordConsumer, schema)
+  }
+  override def init(configuration: Configuration): WriteContext = {
+    new WriteContext(schema, new java.util.HashMap[String, String]())
+  }
+  override def write(record: Group) {
+    groupWriter.write(record)
+  }
+}
+
 private[sql] object ParquetTestData {
 
   val testSchema =
@@ -43,7 +62,7 @@ private[sql] object ParquetTestData {
   // field names for test assertion error messages
   val testSchemaFieldNames = Seq(
     "myboolean:Boolean",
-    "mtint:Int",
+    "myint:Int",
     "mystring:String",
     "mylong:Long",
     "myfloat:Float",
@@ -58,6 +77,18 @@ private[sql] object ParquetTestData {
       |}
     """.stripMargin
 
+  val testFilterSchema =
+    """
+      |message myrecord {
+      |required boolean myboolean;
+      |required int32 myint;
+      |required binary mystring;
+      |required int64 mylong;
+      |required float myfloat;
+      |required double mydouble;
+      |}
+    """.stripMargin
+
   // field names for test assertion error messages
   val subTestSchemaFieldNames = Seq(
     "myboolean:Boolean",
@@ -65,36 +96,57 @@ private[sql] object ParquetTestData {
   )
 
   val testDir = Utils.createTempDir()
+  val testFilterDir = Utils.createTempDir()
 
   lazy val testData = new ParquetRelation(testDir.toURI.toString)
 
   def writeFile() = {
     testDir.delete
     val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
-    val job = new Job()
-    val configuration: Configuration = ContextUtil.getConfiguration(job)
     val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
+    val writeSupport = new TestGroupWriteSupport(schema)
+    val writer = new ParquetWriter[Group](path, writeSupport)
 
-    val writeSupport = new RowWriteSupport()
-    writeSupport.setSchema(schema, configuration)
-    val writer = new ParquetWriter(path, writeSupport)
     for(i <- 0 until 15) {
-      val data = new Array[Any](6)
+      val record = new SimpleGroup(schema)
       if (i % 3 == 0) {
-        data.update(0, true)
+        record.add(0, true)
       } else {
-        data.update(0, false)
+        record.add(0, false)
       }
       if (i % 5 == 0) {
-        data.update(1, 5)
+        record.add(1, 5)
+      }
+      record.add(2, "abc")
+      record.add(3, i.toLong << 33)
+      record.add(4, 2.5F)
+      record.add(5, 4.5D)
+      writer.write(record)
+    }
+    writer.close()
+  }
+
+  def writeFilterFile(records: Int = 200) = {
+    // for microbenchmark use: records = 300000000
+    testFilterDir.delete
+    val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet"))
+    val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema)
+    val writeSupport = new TestGroupWriteSupport(schema)
+    val writer = new ParquetWriter[Group](path, writeSupport)
+
+    for(i <- 0 to records) {
+      val record = new SimpleGroup(schema)
+      if (i % 4 == 0) {
+        record.add(0, true)
       } else {
-        data.update(1, null) // optional
+        record.add(0, false)
       }
-      data.update(2, "abc")
-      data.update(3, i.toLong << 33)
-      data.update(4, 2.5F)
-      data.update(5, 4.5D)
-      writer.write(new GenericRow(data.toArray))
+      record.add(1, i)
+      record.add(2, i.toString)
+      record.add(3, i.toLong)
+      record.add(4, i.toFloat + 0.5f)
+      record.add(5, i.toDouble + 0.5d)
+      writer.write(record)
     }
     writer.close()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/40d6acd6/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index ff1677e..65f4c17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,25 +17,25 @@
 
 package org.apache.spark.sql.parquet
 
-import java.io.File
-
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.hadoop.mapreduce.Job
 
 import parquet.hadoop.ParquetFileWriter
-import parquet.schema.MessageTypeParser
 import parquet.hadoop.util.ContextUtil
+import parquet.schema.MessageTypeParser
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.TestData
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.Equals
+import org.apache.spark.sql.catalyst.types.IntegerType
 import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
-import org.apache.spark.sql.{parquet, SchemaRDD}
 
 // Implicits
 import org.apache.spark.sql.test.TestSQLContext._
@@ -64,12 +64,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
 
   override def beforeAll() {
     ParquetTestData.writeFile()
+    ParquetTestData.writeFilterFile()
     testRDD = parquetFile(ParquetTestData.testDir.toString)
     testRDD.registerAsTable("testsource")
+    parquetFile(ParquetTestData.testFilterDir.toString)
+      .registerAsTable("testfiltersource")
   }
 
   override def afterAll() {
     Utils.deleteRecursively(ParquetTestData.testDir)
+    Utils.deleteRecursively(ParquetTestData.testFilterDir)
     // here we should also unregister the table??
   }
 
@@ -120,7 +124,7 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
     val scanner = new ParquetTableScan(
       ParquetTestData.testData.output,
       ParquetTestData.testData,
-      None)(TestSQLContext.sparkContext)
+      Seq())(TestSQLContext.sparkContext)
     val projected = scanner.pruneColumns(ParquetTypesConverter
       .convertToAttributes(MessageTypeParser
       .parseMessageType(ParquetTestData.subTestSchema)))
@@ -196,7 +200,6 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
     assert(true)
   }
 
-
   test("insert (appending) to same table via Scala API") {
     sql("INSERT INTO testsource SELECT * FROM testsource").collect()
     val double_rdd = sql("SELECT * FROM testsource").collect()
@@ -239,5 +242,121 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
     Utils.deleteRecursively(file)
     assert(true)
   }
+
+  test("create RecordFilter for simple predicates") {
+    val attribute1 = new AttributeReference("first", IntegerType, false)()
+    val predicate1 = new Equals(attribute1, new Literal(1, IntegerType))
+    val filter1 = ParquetFilters.createFilter(predicate1)
+    assert(filter1.isDefined)
+    assert(filter1.get.predicate == predicate1, "predicates do not match")
+    assert(filter1.get.isInstanceOf[ComparisonFilter])
+    val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
+    assert(cmpFilter1.columnName == "first", "column name incorrect")
+
+    val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
+    val filter2 = ParquetFilters.createFilter(predicate2)
+    assert(filter2.isDefined)
+    assert(filter2.get.predicate == predicate2, "predicates do not match")
+    assert(filter2.get.isInstanceOf[ComparisonFilter])
+    val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
+    assert(cmpFilter2.columnName == "first", "column name incorrect")
+
+    val predicate3 = new And(predicate1, predicate2)
+    val filter3 = ParquetFilters.createFilter(predicate3)
+    assert(filter3.isDefined)
+    assert(filter3.get.predicate == predicate3, "predicates do not match")
+    assert(filter3.get.isInstanceOf[AndFilter])
+
+    val predicate4 = new Or(predicate1, predicate2)
+    val filter4 = ParquetFilters.createFilter(predicate4)
+    assert(filter4.isDefined)
+    assert(filter4.get.predicate == predicate4, "predicates do not match")
+    assert(filter4.get.isInstanceOf[OrFilter])
+
+    val attribute2 = new AttributeReference("second", IntegerType, false)()
+    val predicate5 = new GreaterThan(attribute1, attribute2)
+    val badfilter = ParquetFilters.createFilter(predicate5)
+    assert(badfilter.isDefined === false)
+  }
+
+  test("test filter by predicate pushdown") {
+    for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
+      println(s"testing field $myval")
+      val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100")
+      assert(
+        query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+        "Top operator should be ParquetTableScan after pushdown")
+      val result1 = query1.collect()
+      assert(result1.size === 50)
+      assert(result1(0)(1) === 100)
+      assert(result1(49)(1) === 149)
+      val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200")
+      assert(
+        query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+        "Top operator should be ParquetTableScan after pushdown")
+      val result2 = query2.collect()
+      assert(result2.size === 50)
+      if (myval == "myint" || myval == "mylong") {
+        assert(result2(0)(1) === 151)
+        assert(result2(49)(1) === 200)
+      } else {
+        assert(result2(0)(1) === 150)
+        assert(result2(49)(1) === 199)
+      }
+    }
+    for(myval <- Seq("myint", "mylong")) {
+      val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10")
+      assert(
+        query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+        "Top operator should be ParquetTableScan after pushdown")
+      val result3 = query3.collect()
+      assert(result3.size === 20)
+      assert(result3(0)(1) === 0)
+      assert(result3(9)(1) === 9)
+      assert(result3(10)(1) === 191)
+      assert(result3(19)(1) === 200)
+    }
+    for(myval <- Seq("mydouble", "myfloat")) {
+      val result4 =
+        if (myval == "mydouble") {
+          val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0")
+          assert(
+            query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+            "Top operator should be ParquetTableScan after pushdown")
+          query4.collect()
+        } else {
+          // CASTs are problematic. Here myfloat will be cast to a double, and it seems there is
+          // currently no way to specify float constants in SqlParser.
+          sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect()
+        }
+      assert(result4.size === 20)
+      assert(result4(0)(1) === 0)
+      assert(result4(9)(1) === 9)
+      assert(result4(10)(1) === 191)
+      assert(result4(19)(1) === 200)
+    }
+    val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40")
+    assert(
+      query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+      "Top operator should be ParquetTableScan after pushdown")
+    val booleanResult = query5.collect()
+    assert(booleanResult.size === 10)
+    for(i <- 0 until 10) {
+      if (!booleanResult(i).getBoolean(0)) {
+        fail(s"Boolean value in result row $i not true")
+      }
+      if (booleanResult(i).getInt(1) != i * 4) {
+        fail(s"Int value in result row $i should be ${4*i}")
+      }
+    }
+    val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"")
+    assert(
+      query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+      "Top operator should be ParquetTableScan after pushdown")
+    val stringResult = query6.collect()
+    assert(stringResult.size === 1)
+    assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
+    assert(stringResult(0).getInt(1) === 100)
+  }
 }