Posted to commits@carbondata.apache.org by qi...@apache.org on 2020/07/21 02:32:28 UTC

[carbondata] branch master updated: [CARBONDATA-3849] Push down array_contains filter for array<primitive type> column

This is an automated email from the ASF dual-hosted git repository.

qiangcai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/carbondata.git


The following commit(s) were added to refs/heads/master by this push:
     new d8a5b7b  [CARBONDATA-3849] Push down array_contains filter for array<primitive type> column
d8a5b7b is described below

commit d8a5b7b6e27330d6efe345d1001b55b1cd6b5a45
Author: ajantha-bhat <aj...@gmail.com>
AuthorDate: Thu May 21 22:39:12 2020 +0530

    [CARBONDATA-3849] Push down array_contains filter for array<primitive type> column
    
    Why is this PR needed?
    Currently the array_contains() UDF is not pushed down to Carbon, so Carbon
    has to scan every row for a query that uses this UDF. Scanning all rows
    degrades query performance, hence the need to push the filter down.
    
    What changes were proposed in this PR?
    Push down array_contains() on arrays of all primitive types as an EqualTo
    filter. With an EqualTo filter, the scan of an array's elements can stop as
    soon as a match is found (see the example below).
    
    Does this PR introduce any user interface change?
    No
    
    Is any new testcase added?
    Yes
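
    For example (taken from the new test case for array<string>), the query
    below is now planned with a pushed EqualTo filter, which shows up in its
    explain output:

        sql("create table complex1 (arr array<string>) stored as carbondata")
        sql("insert into complex1 select array('sd','df','gh')")
        sql("explain select * from complex1 where array_contains(arr,'sd')")
        // the explain result now contains: PushedFilters: [*EqualTo(arr,sd)]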
    
    This closes #3771
---
 .../core/datastore/block/SegmentProperties.java    |   3 +
 .../core/scan/complextypes/ArrayQueryType.java     |  27 +++
 .../core/scan/complextypes/ComplexQueryType.java   |   3 +-
 .../executer/RowLevelFilterExecutorImpl.java       | 132 +++++++---
 .../sql/execution/CastExpressionOptimization.scala |  10 +
 .../strategy/CarbonLateDecodeStrategy.scala        |  34 ++-
 .../apache/spark/sql/optimizer/CarbonFilters.scala |  16 +-
 .../scala/org/apache/spark/util/SparkUtil.scala    |  18 ++
 .../complexType/TestArrayContainsPushDown.scala    | 267 +++++++++++++++++++++
 9 files changed, 464 insertions(+), 46 deletions(-)

diff --git a/core/src/main/java/org/apache/carbondata/core/datastore/block/SegmentProperties.java b/core/src/main/java/org/apache/carbondata/core/datastore/block/SegmentProperties.java
index 1d291d2..fe28a37 100644
--- a/core/src/main/java/org/apache/carbondata/core/datastore/block/SegmentProperties.java
+++ b/core/src/main/java/org/apache/carbondata/core/datastore/block/SegmentProperties.java
@@ -466,6 +466,9 @@ public class SegmentProperties {
    * @return
    */
   public CarbonDimension getDimensionFromCurrentBlock(CarbonDimension queryDimension) {
+    if (queryDimension.isComplex()) {
+      return CarbonUtil.getDimensionFromCurrentBlock(this.complexDimensions, queryDimension);
+    }
     return CarbonUtil.getDimensionFromCurrentBlock(this.dimensions, queryDimension);
   }
 
diff --git a/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ArrayQueryType.java b/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ArrayQueryType.java
index 71d3b1e..8a41384 100644
--- a/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ArrayQueryType.java
+++ b/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ArrayQueryType.java
@@ -128,4 +128,31 @@ public class ArrayQueryType extends ComplexQueryType implements GenericQueryType
     throw new UnsupportedOperationException("Operation Unsupported for ArrayType");
   }
 
+  public int[][] getNumberOfChild(DimensionRawColumnChunk[] rawColumnChunks,
+      DimensionColumnPage[][] dimensionColumnPages, int numberOfRows, int pageNumber) {
+    DimensionColumnPage page =
+        getDecodedDimensionPage(dimensionColumnPages, rawColumnChunks[columnIndex], pageNumber);
+    int[][] numberOfChild = new int[numberOfRows][2];
+    for (int i = 0; i < numberOfRows; i++) {
+      byte[] input = page.getChunkData(i);
+      ByteBuffer wrap = ByteBuffer.wrap(input);
+      int[] metadata = new int[2];
+      metadata[0] = wrap.getInt();
+      if (metadata[0] > 0) {
+        metadata[1] = wrap.getInt();
+      }
+      numberOfChild[i] = metadata;
+    }
+    return numberOfChild;
+  }
+
+  public DimensionColumnPage parseBlockAndReturnChildData(DimensionRawColumnChunk[] rawColumnChunks,
+      DimensionColumnPage[][] dimensionColumnPages, int pageNumber) {
+    PrimitiveQueryType queryType = (PrimitiveQueryType) children;
+    return queryType.getDecodedDimensionPage(
+        dimensionColumnPages,
+        rawColumnChunks[queryType.columnIndex],
+        pageNumber);
+  }
+
 }
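
A minimal standalone sketch (in Scala, names hypothetical) of the per-row
metadata decoding that getNumberOfChild performs above, assuming the layout
the code implies: each row's chunk data begins with an int element count,
followed by an int offset into the child page only when the count is
positive.

  import java.nio.ByteBuffer

  // Decode (elementCount, childDataOffset) for each raw array-column row.
  def decodeArrayMetadata(rows: Array[Array[Byte]]): Array[(Int, Int)] =
    rows.map { input =>
      val wrap = ByteBuffer.wrap(input)
      val count = wrap.getInt                         // number of array elements
      val offset = if (count > 0) wrap.getInt else 0  // start index in child page
      (count, offset)
    }
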
diff --git a/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ComplexQueryType.java b/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ComplexQueryType.java
index d0cc8ae..2288ae2 100644
--- a/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ComplexQueryType.java
+++ b/core/src/main/java/org/apache/carbondata/core/scan/complextypes/ComplexQueryType.java
@@ -60,11 +60,12 @@ public class ComplexQueryType {
     }
   }
 
-  private DimensionColumnPage getDecodedDimensionPage(DimensionColumnPage[][] dimensionColumnPages,
+  public DimensionColumnPage getDecodedDimensionPage(DimensionColumnPage[][] dimensionColumnPages,
       DimensionRawColumnChunk dimensionRawColumnChunk, int pageNumber) {
     if (dimensionColumnPages == null || null == dimensionColumnPages[columnIndex]) {
       return dimensionRawColumnChunk.decodeColumnPage(pageNumber);
     }
     return dimensionColumnPages[columnIndex][pageNumber];
   }
+
 }
diff --git a/core/src/main/java/org/apache/carbondata/core/scan/filter/executer/RowLevelFilterExecutorImpl.java b/core/src/main/java/org/apache/carbondata/core/scan/filter/executer/RowLevelFilterExecutorImpl.java
index 8a8a841..a5b3a42 100644
--- a/core/src/main/java/org/apache/carbondata/core/scan/filter/executer/RowLevelFilterExecutorImpl.java
+++ b/core/src/main/java/org/apache/carbondata/core/scan/filter/executer/RowLevelFilterExecutorImpl.java
@@ -36,14 +36,19 @@ import org.apache.carbondata.core.datastore.chunk.DimensionColumnPage;
 import org.apache.carbondata.core.datastore.chunk.impl.VariableLengthDimensionColumnPage;
 import org.apache.carbondata.core.datastore.chunk.store.ColumnPageWrapper;
 import org.apache.carbondata.core.datastore.page.ColumnPage;
+import org.apache.carbondata.core.keygenerator.directdictionary.timestamp.DateDirectDictionaryGenerator;
+import org.apache.carbondata.core.keygenerator.directdictionary.timestamp.TimeStampGranularityTypeValue;
 import org.apache.carbondata.core.metadata.AbsoluteTableIdentifier;
 import org.apache.carbondata.core.metadata.datatype.DataType;
 import org.apache.carbondata.core.metadata.datatype.DataTypes;
 import org.apache.carbondata.core.metadata.schema.table.column.CarbonDimension;
 import org.apache.carbondata.core.metadata.schema.table.column.CarbonMeasure;
+import org.apache.carbondata.core.scan.complextypes.ArrayQueryType;
 import org.apache.carbondata.core.scan.executor.util.RestructureUtil;
 import org.apache.carbondata.core.scan.expression.Expression;
+import org.apache.carbondata.core.scan.expression.LiteralExpression;
 import org.apache.carbondata.core.scan.expression.MatchExpression;
+import org.apache.carbondata.core.scan.expression.conditional.EqualToExpression;
 import org.apache.carbondata.core.scan.expression.exception.FilterIllegalMemberException;
 import org.apache.carbondata.core.scan.expression.exception.FilterUnsupportedException;
 import org.apache.carbondata.core.scan.filter.FilterUtil;
@@ -54,6 +59,7 @@ import org.apache.carbondata.core.scan.filter.resolver.resolverinfo.DimColumnRes
 import org.apache.carbondata.core.scan.filter.resolver.resolverinfo.MeasureColumnResolvedFilterInfo;
 import org.apache.carbondata.core.scan.processor.RawBlockletColumnChunks;
 import org.apache.carbondata.core.util.BitSetGroup;
+import org.apache.carbondata.core.util.ByteUtil;
 import org.apache.carbondata.core.util.DataTypeUtil;
 
 import org.apache.log4j.Logger;
@@ -222,49 +228,103 @@ public class RowLevelFilterExecutorImpl implements FilterExecutor {
       }
     }
     BitSetGroup bitSetGroup = new BitSetGroup(pageNumbers);
-    for (int i = 0; i < pageNumbers; i++) {
-      BitSet set = new BitSet(numberOfRows[i]);
-      RowIntf row = new RowImpl();
-      BitSet prvBitset = null;
-      // if bitset pipeline is enabled then use row ids from the previous bitset
-      // otherwise use the older flow
-      if (!useBitsetPipeLine ||
-          null == rawBlockletColumnChunks.getBitSetGroup() ||
-          null == bitSetGroup.getBitSet(i) ||
-          rawBlockletColumnChunks.getBitSetGroup().getBitSet(i).isEmpty()) {
+    if (isDimensionPresentInCurrentBlock.length == 1 && isDimensionPresentInCurrentBlock[0]
+        && dimColEvaluatorInfoList.get(0).getDimension().getDataType().isComplexType()
+        && exp instanceof EqualToExpression) {
+      LiteralExpression literalExp = (LiteralExpression) (((EqualToExpression) exp).getRight());
+      // convert filter value to byte[] to compare with byte[] data from columnPage
+      Object literalExpValue = literalExp.getLiteralExpValue();
+      DataType literalExpDataType = literalExp.getLiteralExpDataType();
+      if (literalExpDataType == DataTypes.TIMESTAMP) {
+        if ((long) literalExpValue == 0) {
+          literalExpValue = null;
+        } else {
+          literalExpValue =
+              (long) literalExpValue / TimeStampGranularityTypeValue.MILLIS_SECONDS.getValue();
+        }
+      } else if (literalExpDataType == DataTypes.DATE) {
+        // change data type to INT for the byte[] filter value, as DATE is direct dictionary
+        literalExpDataType = DataTypes.INT;
+        if (literalExpValue == null) {
+          literalExpValue = CarbonCommonConstants.DIRECT_DICT_VALUE_NULL;
+        } else {
+          literalExpValue =
+              (int) literalExpValue + DateDirectDictionaryGenerator.cutOffDate;
+        }
+      }
+      byte[] filterValueInBytes = DataTypeUtil.getBytesDataDataTypeForNoDictionaryColumn(
+          literalExpValue,
+          literalExpDataType);
+      ArrayQueryType complexType =
+          (ArrayQueryType) complexDimensionInfoMap.get(dimensionChunkIndex[0]);
+      // check all the pages
+      for (int i = 0; i < pageNumbers; i++) {
+        BitSet set = new BitSet(numberOfRows[i]);
+        int[][] numberOfChild = complexType
+            .getNumberOfChild(rawBlockletColumnChunks.getDimensionRawColumnChunks(), null,
+                numberOfRows[i], i);
+        DimensionColumnPage page = complexType
+            .parseBlockAndReturnChildData(rawBlockletColumnChunks.getDimensionRawColumnChunks(),
+                null, i);
+        // check every row
         for (int index = 0; index < numberOfRows[i]; index++) {
-          createRow(rawBlockletColumnChunks, row, i, index);
-          Boolean result = false;
-          try {
-            result = exp.evaluate(row).getBoolean();
-          }
-          // Any invalid member encountered during evaluation shall be ignored. The
-          // error is logged only once, since evaluation happens for every row and
-          // logging each failure would flood the log.
-          catch (FilterIllegalMemberException e) {
-            FilterUtil.logError(e, false);
-          }
-          if (null != result && result) {
-            set.set(index);
+          int dataOffset = numberOfChild[index][1];
+          // loop the children
+          for (int j = 0; j < numberOfChild[index][0]; j++) {
+            byte[] obj = page.getChunkData(dataOffset++);
+            if (ByteUtil.UnsafeComparer.INSTANCE.compareTo(obj, filterValueInBytes) == 0) {
+              set.set(index);
+              break;
+            }
           }
         }
-      } else {
-        prvBitset = rawBlockletColumnChunks.getBitSetGroup().getBitSet(i);
-        for (int index = prvBitset.nextSetBit(0);
-             index >= 0; index = prvBitset.nextSetBit(index + 1)) {
-          createRow(rawBlockletColumnChunks, row, i, index);
-          Boolean result = false;
-          try {
-            result = exp.evaluate(row).getBoolean();
-          } catch (FilterIllegalMemberException e) {
-            FilterUtil.logError(e, false);
+        bitSetGroup.setBitSet(set, i);
+      }
+    } else {
+      for (int i = 0; i < pageNumbers; i++) {
+        BitSet set = new BitSet(numberOfRows[i]);
+        RowIntf row = new RowImpl();
+        BitSet prvBitset = null;
+        // if bitset pipeline is enabled then use row ids from the previous bitset
+        // otherwise use the older flow
+        if (!useBitsetPipeLine ||
+            null == rawBlockletColumnChunks.getBitSetGroup() ||
+            null == bitSetGroup.getBitSet(i) ||
+            rawBlockletColumnChunks.getBitSetGroup().getBitSet(i).isEmpty()) {
+          for (int index = 0; index < numberOfRows[i]; index++) {
+            createRow(rawBlockletColumnChunks, row, i, index);
+            Boolean result = false;
+            try {
+              result = exp.evaluate(row).getBoolean();
+            }
+            // Any invalid member encountered during evaluation shall be ignored. The
+            // error is logged only once, since evaluation happens for every row and
+            // logging each failure would flood the log.
+            catch (FilterIllegalMemberException e) {
+              FilterUtil.logError(e, false);
+            }
+            if (null != result && result) {
+              set.set(index);
+            }
           }
-          if (null != result && result) {
-            set.set(index);
+        } else {
+          prvBitset = rawBlockletColumnChunks.getBitSetGroup().getBitSet(i);
+          for (int index = prvBitset.nextSetBit(0);
+               index >= 0; index = prvBitset.nextSetBit(index + 1)) {
+            createRow(rawBlockletColumnChunks, row, i, index);
+            Boolean rslt = false;
+            try {
+              rslt = exp.evaluate(row).getBoolean();
+            } catch (FilterIllegalMemberException e) {
+              FilterUtil.logError(e, false);
+            }
+            if (null != rslt && rslt) {
+              set.set(index);
+            }
           }
         }
+        bitSetGroup.setBitSet(set, i);
       }
-      bitSetGroup.setBitSet(set, i);
     }
     return bitSetGroup;
   }
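
A minimal Scala sketch of the literal normalization performed above before
the byte-wise comparison. The two constants are stubbed assumptions: the
real values come from TimeStampGranularityTypeValue.MILLIS_SECONDS and
DateDirectDictionaryGenerator.cutOffDate, which this diff does not show, and
null date literals (mapped to DIRECT_DICT_VALUE_NULL above) are omitted.

  // Stubbed constants (assumptions, see lead-in).
  val millisGranularity: Long = 1000L
  val dateCutOffDays: Int = 0

  // A timestamp literal of 0 encodes null; any other value is scaled down
  // to the granularity the column stores.
  def normalizeTimestamp(micros: Long): Option[Long] =
    if (micros == 0L) None else Some(micros / millisGranularity)

  // A date literal (days since epoch) is shifted by the direct-dictionary
  // cut-off to obtain the stored INT surrogate.
  def normalizeDate(days: Int): Int = days + dateCutOffDays
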
diff --git a/integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala b/integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala
index 57fb3f0..90558ac 100644
--- a/integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala
+++ b/integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala
@@ -162,6 +162,16 @@ object CastExpressionOptimization {
             updateFilterForInt(v, c)
           case s: ShortType if t.sameType(IntegerType) =>
             updateFilterForShort(v, c)
+          case arr: ArrayType =>
+            arr.elementType match {
+              case ts@(_: DateType | _: TimestampType) if t.sameType(StringType) =>
+                updateFilterForTimeStamp(v, c, ts)
+              case i: IntegerType if t.sameType(DoubleType) =>
+                updateFilterForInt(v, c)
+              case s: ShortType if t.sameType(IntegerType) =>
+                updateFilterForShort(v, c)
+              case _ => Some(CastExpr(c))
+            }
           case _ => Some(CastExpr(c))
         }
       case c@EqualTo(Literal(v, t), Cast(a: Attribute, _)) =>
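
The new ArrayType branch reuses the existing cast-removal rules, but keyed
on the array's element type. A condensed Boolean sketch of that dispatch
(the method name is hypothetical; the real code returns rewritten filters,
not a Boolean):

  import org.apache.spark.sql.types._

  // True when the cast Spark inserts around an array column can be removed
  // for a literal of the given type.
  def castRemovableForArray(arr: ArrayType, literalType: DataType): Boolean =
    arr.elementType match {
      case _: DateType | _: TimestampType => literalType.sameType(StringType)
      case _: IntegerType => literalType.sameType(DoubleType)
      case _: ShortType => literalType.sameType(IntegerType)
      case _ => false
    }
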
diff --git a/integration/spark/src/main/scala/org/apache/spark/sql/execution/strategy/CarbonLateDecodeStrategy.scala b/integration/spark/src/main/scala/org/apache/spark/sql/execution/strategy/CarbonLateDecodeStrategy.scala
index d05d8f9..56547ef 100644
--- a/integration/spark/src/main/scala/org/apache/spark/sql/execution/strategy/CarbonLateDecodeStrategy.scala
+++ b/integration/spark/src/main/scala/org/apache/spark/sql/execution/strategy/CarbonLateDecodeStrategy.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.optimizer.CarbonFilters
 import org.apache.spark.sql.secondaryindex.joins.BroadCastSIFilterPushJoin
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.CarbonReflectionUtils
+import org.apache.spark.util.{CarbonReflectionUtils, SparkUtil}
 
 import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException
 import org.apache.carbondata.common.logging.LogServiceFactory
@@ -517,7 +517,8 @@ private[sql] class CarbonLateDecodeStrategy extends SparkStrategy {
       val supportBatch =
         supportBatchedDataSource(relation.relation.sqlContext,
           updateRequestedColumns) && extraRdd.getOrElse((null, true))._2
-      if (!vectorPushRowFilters && !supportBatch && !implicitExisted) {
+      if (!vectorPushRowFilters && !supportBatch && !implicitExisted && filterSet.nonEmpty &&
+          !filterSet.baseSet.exists(_.a.dataType.isInstanceOf[ArrayType])) {
         // revert for row scan
         updateRequestedColumns = requestedColumns
       }
@@ -679,10 +680,13 @@ private[sql] class CarbonLateDecodeStrategy extends SparkStrategy {
     // In case of ComplexType data types, no filters should be pushed down. IsNotNull is
     // explicitly added and pushed by Spark; that filter also has to be handled here
     // and returned to Spark for evaluation.
-    val predicatesWithoutComplex = predicates.filter(predicate =>
+    // allow array_contains() push down
+    val filteredPredicates = predicates.filter { predicate =>
+      predicate.isInstanceOf[ArrayContains] ||
       predicate.collect {
-      case a: Attribute if isComplexAttribute(a) => a
-    }.size == 0 )
+        case a: Attribute if isComplexAttribute(a) => a
+      }.isEmpty
+    }
 
     // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
     // called `predicate`s, while all data source filters of type `sources.Filter` are simply called
@@ -690,7 +694,7 @@ private[sql] class CarbonLateDecodeStrategy extends SparkStrategy {
     // Todo: handle when lucene and normal query filter is supported
 
     var count = 0
-    val translated: Seq[(Expression, Filter)] = predicatesWithoutComplex.flatMap {
+    val translated: Seq[(Expression, Filter)] = filteredPredicates.flatMap {
       predicate =>
         if (predicate.isInstanceOf[ScalaUDF]) {
           predicate match {
@@ -865,7 +869,23 @@ private[sql] class CarbonLateDecodeStrategy extends SparkStrategy {
         Some(CarbonContainsWith(c))
       case c@Literal(v, t) if (v == null) =>
         Some(FalseExpr())
-      case others => None
+      case c@ArrayContains(a: Attribute, Literal(v, t)) =>
+        a.dataType match {
+          case arrayType: ArrayType =>
+            if (SparkUtil.isPrimitiveType(arrayType.elementType)) {
+              Some(sources.EqualTo(a.name, v))
+            } else {
+              None
+            }
+          case _ =>
+            None
+        }
+      case c@ArrayContains(Cast(a: Attribute, _), Literal(v, t)) =>
+        CastExpressionOptimization.checkIfCastCanBeRemove(
+          EqualTo(
+            predicate.asInstanceOf[ArrayContains].left,
+            predicate.asInstanceOf[ArrayContains].right))
+      case _ => None
     }
   }
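
A condensed, self-contained sketch of the ArrayContains case added to the
filter translation above; the Cast variant is delegated to
CastExpressionOptimization.checkIfCastCanBeRemove and is omitted here:

  import org.apache.spark.sql.catalyst.expressions.{ArrayContains, Attribute, Literal}
  import org.apache.spark.sql.sources
  import org.apache.spark.sql.types.ArrayType
  import org.apache.spark.util.SparkUtil

  // array_contains on an array-of-primitive column becomes a data source
  // EqualTo filter on the array column's name.
  def translateArrayContains(pred: ArrayContains): Option[sources.Filter] =
    pred match {
      case ArrayContains(a: Attribute, Literal(v, _)) =>
        a.dataType match {
          case arr: ArrayType if SparkUtil.isPrimitiveType(arr.elementType) =>
            Some(sources.EqualTo(a.name, v))
          case _ => None
        }
      case _ => None
    }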
 
diff --git a/integration/spark/src/main/scala/org/apache/spark/sql/optimizer/CarbonFilters.scala b/integration/spark/src/main/scala/org/apache/spark/sql/optimizer/CarbonFilters.scala
index 4929a4c..ff3a583 100644
--- a/integration/spark/src/main/scala/org/apache/spark/sql/optimizer/CarbonFilters.scala
+++ b/integration/spark/src/main/scala/org/apache/spark/sql/optimizer/CarbonFilters.scala
@@ -152,13 +152,25 @@ object CarbonFilters {
     }
 
     def getCarbonExpression(name: String) = {
+      var sparkDatatype = dataTypeOf(name)
+      sparkDatatype match {
+        case arrayType: ArrayType =>
+          sparkDatatype = arrayType.elementType
+        case _ =>
+      }
       new CarbonColumnExpression(name,
-        CarbonSparkDataSourceUtil.convertSparkToCarbonDataType(dataTypeOf(name)))
+        CarbonSparkDataSourceUtil.convertSparkToCarbonDataType(sparkDatatype))
     }
 
     def getCarbonLiteralExpression(name: String, value: Any): CarbonExpression = {
+      var sparkDatatype = dataTypeOf(name)
+      sparkDatatype match {
+        case arrayType: ArrayType =>
+          sparkDatatype = arrayType.elementType
+        case _ =>
+      }
       val dataTypeOfAttribute =
-        CarbonSparkDataSourceUtil.convertSparkToCarbonDataType(dataTypeOf(name))
+        CarbonSparkDataSourceUtil.convertSparkToCarbonDataType(sparkDatatype)
       val dataType = if (Option(value).isDefined
                          && dataTypeOfAttribute == CarbonDataTypes.STRING
                          && value.isInstanceOf[Double]) {
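
Both helpers above now unwrap the element type for array columns, since the
pushed EqualTo filter compares individual elements rather than the whole
array. A minimal sketch of that unwrapping (function name hypothetical):

  import org.apache.spark.sql.types.{ArrayType, DataType}

  // For an array column, build the Carbon expression from the element type.
  def unwrapArrayElementType(sparkType: DataType): DataType =
    sparkType match {
      case arrayType: ArrayType => arrayType.elementType
      case other => other
    }
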
diff --git a/integration/spark/src/main/scala/org/apache/spark/util/SparkUtil.scala b/integration/spark/src/main/scala/org/apache/spark/util/SparkUtil.scala
index ba782e5..1111287 100644
--- a/integration/spark/src/main/scala/org/apache/spark/util/SparkUtil.scala
+++ b/integration/spark/src/main/scala/org/apache/spark/util/SparkUtil.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
 import org.apache.spark.{SPARK_VERSION, TaskContext}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ID_KEY
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampType}
 
 /*
  * this object is used to handle file splits
@@ -67,4 +68,21 @@ object SparkUtil {
     }
   }
 
+  def isPrimitiveType(datatype : DataType): Boolean = {
+    datatype match {
+      case StringType => true
+      case ByteType => true
+      case ShortType => true
+      case IntegerType => true
+      case LongType => true
+      case FloatType => true
+      case DoubleType => true
+      case BinaryType => true
+      case BooleanType => true
+      case DateType => true
+      case TimestampType => true
+      case DecimalType() => true
+      case _ => false
+    }
+  }
 }
diff --git a/integration/spark/src/test/scala/org/apache/carbondata/integration/spark/testsuite/complexType/TestArrayContainsPushDown.scala b/integration/spark/src/test/scala/org/apache/carbondata/integration/spark/testsuite/complexType/TestArrayContainsPushDown.scala
new file mode 100644
index 0000000..ae2a062
--- /dev/null
+++ b/integration/spark/src/test/scala/org/apache/carbondata/integration/spark/testsuite/complexType/TestArrayContainsPushDown.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.carbondata.integration.spark.testsuite.complexType
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.util.QueryTest
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.carbondata.core.constants.CarbonCommonConstants
+import org.apache.carbondata.core.util.CarbonProperties
+
+class TestArrayContainsPushDown extends QueryTest with BeforeAndAfterAll {
+
+  override protected def afterAll(): Unit = {
+    CarbonProperties.getInstance()
+      .addProperty(CarbonCommonConstants.CARBON_TIMESTAMP_FORMAT,
+        CarbonCommonConstants.CARBON_TIMESTAMP_DEFAULT_FORMAT)
+    sql("DROP TABLE IF EXISTS compactComplex")
+  }
+
+  test("test array contains pushdown for array of string") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<String>) stored as carbondata")
+    sql("insert into complex1 select array('as') union all " +
+        "select array('sd','df','gh') union all " +
+        "select array('rt','ew','rtyu','jk',null) union all " +
+        "select array('ghsf','dbv','','ty') union all " +
+        "select array('hjsd','fggb','nhj','sd','asd')")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,'sd')"),
+      true,
+      "PushedFilters: [*EqualTo(arr,sd)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,'sd')"),
+      true,
+      "PushedFilters: [*EqualTo(arr,sd)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,'sd')"),
+      Seq(Row(mutable.WrappedArray.make(Array("sd", "df", "gh"))),
+        Row(mutable.WrappedArray.make(Array("hjsd", "fggb", "nhj", "sd", "asd")))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,'sd')"), Seq(Row(2)))
+    // test for empty
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,'')"),
+      Seq(Row(mutable.WrappedArray.make(Array("ghsf", "dbv","","ty")))))
+
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of boolean") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<boolean>) stored as carbondata")
+    sql("insert into complex1 select array(true) union all " +
+        "select array(false, null, false) union all " +
+        "select array(false, false, true, false) union all " +
+        "select array(true, true, true, false) union all " +
+        "select array(false)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,true)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,true)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,true)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,true)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,true)"),
+      Seq(Row(mutable.WrappedArray.make(Array(true))),
+        Row(mutable.WrappedArray.make(Array(false, false, true, false))),
+          Row(mutable.WrappedArray.make(Array(true, true, true, false)))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,true)"), Seq(Row(3)))
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of short") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<short>) stored as carbondata")
+    sql("insert into complex1 select array(12) union all " +
+        "select array(20, 30, 31000) union all " +
+        "select array(11, 12, 13, 14, 15, 16, -31000) union all " +
+        "select array(20, 31000, 60, 80) union all " +
+        "select array(41, 41, -41, -42)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,31000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,31000)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,31000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,31000)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,31000)"),
+      Seq(Row(mutable.WrappedArray.make(Array(20, 30, 31000))),
+        Row(mutable.WrappedArray.make(Array(20, 31000, 60, 80)))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,31000)"), Seq(Row(2)))
+
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of int") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<int>) stored as carbondata")
+    sql("insert into complex1 select array(12) union all " +
+        "select array(20, 30, 33000) union all " +
+        "select array(11, 12, 13, 14, 15, 16, -33000) union all " +
+        "select array(20, 33000, 60, 80) union all " +
+        "select array(41, 41, -41, -42)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,33000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,33000)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,33000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,33000)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,33000)"),
+      Seq(Row(mutable.WrappedArray.make(Array(20, 30, 33000))),
+        Row(mutable.WrappedArray.make(Array(20, 33000, 60, 80)))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,33000)"), Seq(Row(2)))
+
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of long") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<long>) stored as carbondata")
+    sql("insert into complex1 select array(12) union all " +
+        "select array(20, 30, 33000) union all " +
+        "select array(11, 12, 13, 14, 15, 16, -33000) union all " +
+        "select array(20, 33000, 60, 80) union all " +
+        "select array(41, 41, -41, -42)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,33000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,33000)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,33000)"),
+      true,
+      "PushedFilters: [*EqualTo(arr,33000)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,33000)"),
+      Seq(Row(mutable.WrappedArray.make(Array(20, 30, 33000))),
+        Row(mutable.WrappedArray.make(Array(20, 33000, 60, 80)))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,33000)"), Seq(Row(2)))
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of double") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<double>) stored as carbondata")
+    sql("insert into complex1 select array(2.2) union all " +
+        "select array(3.3, 4.4) union all " +
+        "select array(3.3, 4.4, 2.2) union all " +
+        "select array(-2.2, 3.3, 4.4)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,cast(2.2 as double))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,2.2)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,cast(2.2 as double))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,2.2)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,cast(2.2 as double))"),
+      Seq(Row(mutable.WrappedArray.make(Array(2.2))),
+        Row(mutable.WrappedArray.make(Array(3.3, 4.4, 2.2)))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,cast(2.2 as double))"), Seq(Row(2)))
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of decimal") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<decimal(5,2)>) stored as carbondata")
+    sql("insert into complex1 select array(2.2) union all " +
+        "select array(3.3, 4.4) union all " +
+        "select array(3.3, 4.4, 2.2) union all " +
+        "select array(-2.2, 3.3, 4.4)")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,cast(2.2 as decimal(5,2)))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,2.20)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,cast(2.2 as decimal(5,2)))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,2.20)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,cast(2.2 as decimal(5,2)))"),
+      Seq(Row(mutable.WrappedArray.make(Array(java.math.BigDecimal.valueOf(2.20).setScale(2)))),
+        Row(mutable.WrappedArray.make(Array(
+          java.math.BigDecimal.valueOf(3.30).setScale(2),
+          java.math.BigDecimal.valueOf(4.40).setScale(2),
+          java.math.BigDecimal.valueOf(2.20).setScale(2))))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,cast(2.2 as decimal(5,2)))"), Seq(Row(2)))
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of timestamp") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<timestamp>) stored as carbondata")
+    sql("insert into complex1 select array('2017-01-01 00:00:00','2018-01-01 00:00:00') union all " +
+        "select array('2019-01-01 00:00:00') union all " +
+        "select array('2018-01-01 00:00:00') ")
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,cast('2018-01-01 00:00:00' as timestamp))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,1514793600000000)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,cast('2018-01-01 00:00:00' as timestamp))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,1514793600000000)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,cast('2018-01-01 00:00:00' as timestamp))"),
+      Seq(Row(mutable.WrappedArray.make(Array(Timestamp.valueOf("2017-01-01 00:00:00.0"),Timestamp.valueOf("2018-01-01 00:00:00.0")))),
+        Row(mutable.WrappedArray.make(Array(Timestamp.valueOf("2018-01-01 00:00:00.0"))))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,cast('2018-01-01 00:00:00' as timestamp))"), Seq(Row(2)))
+    sql("drop table complex1")
+  }
+
+  test("test array contains pushdown for array of date") {
+    sql("drop table if exists complex1")
+    sql("create table complex1 (arr array<date>) stored as carbondata")
+    sql("insert into complex1 select array('2017-01-01','2018-01-01') union all " +
+        "select array('2019-01-01') union all " +
+        "select array('2018-01-01') ")
+
+    checkExistence(sql(" explain select * from complex1 where array_contains(arr,cast('2018-01-01' as date))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,17532)]")
+
+    checkExistence(sql(" explain select count(*) from complex1 where array_contains(arr,cast('2018-01-01' as date))"),
+      true,
+      "PushedFilters: [*EqualTo(arr,17532)]")
+
+    checkAnswer(sql(" select * from complex1 where array_contains(arr,cast('2018-01-01' as date))"),
+      Seq(Row(mutable.WrappedArray.make(Array(Date.valueOf("2017-01-01"),Date.valueOf("2018-01-01")))),
+        Row(mutable.WrappedArray.make(Array(Date.valueOf("2018-01-01"))))))
+
+    checkAnswer(sql(" select count(*) from complex1 where array_contains(arr,cast('2018-01-01' as date))"), Seq(Row(2)))
+    sql("drop table complex1")
+  }
+}