Posted to commits@spark.apache.org by we...@apache.org on 2020/12/22 14:32:23 UTC

[spark] branch branch-3.1 updated: [SPARK-33876][SQL] Add length-check for reading char/varchar from tables w/ an external location

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 17fe038b [SPARK-33876][SQL] Add length-check for reading char/varchar from tables w/ an external location
17fe038b is described below

commit 17fe038b6f307a08d18cdaf9e79b1d4e974455c9
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Tue Dec 22 14:24:12 2020 +0000

    [SPARK-33876][SQL] Add length-check for reading char/varchar from tables w/ an external location
    
    ### What changes were proposed in this pull request?
    This PR adds a length check to the existing ApplyCharTypePadding rule. Tables can have external
    locations when users execute SET LOCATION or CREATE TABLE ... LOCATION (see the sketch below).
    If the location contains over-length values, we should fail on read.
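    
    A minimal illustration of the two paths (table name, column type, format, and path are placeholders):
    
    ```sql
    -- external location supplied at creation time
    CREATE TABLE t (col CHAR(5)) USING parquet LOCATION '/tmp/hive_one/t2';
    -- or attached to an existing table afterwards
    ALTER TABLE t SET LOCATION '/tmp/hive_one/t2';
    ```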
    
    ### Why are the changes needed?
    
    ```sql
    spark-sql> INSERT INTO t2 VALUES ('1', 'b12345');
    Time taken: 0.141 seconds
    spark-sql> alter table t set location '/tmp/hive_one/t2';
    Time taken: 0.095 seconds
    spark-sql> select * from t;
    1 b1234
    ```
    The above case should fail rather than implicitly apply truncation, as sketched below.
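    
    With this patch, the same read is expected to raise a runtime error instead (error text mirrors
    the new tests; the limit of 5 assumes t was declared as CHAR(5)/VARCHAR(5)):
    
    ```sql
    spark-sql> select * from t;
    -- fails with: input string of length 6 exceeds char type length limitation: 5
    ```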
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new tests
    
    Closes #30882 from yaooqinn/SPARK-33876.
    
    Authored-by: Kent Yao <ya...@hotmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 6da5cdf1dbfc35cee0ce32aa9e44c0b4187373d9)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/util/CharVarcharUtils.scala | 29 ++++++++----
 ...a => PaddingAndLengthCheckForCharVarchar.scala} | 20 ++++----
 .../sql/internal/BaseSessionStateBuilder.scala     |  2 +-
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 55 ++++++++++++++++++++++
 .../spark/sql/hive/HiveSessionStateBuilder.scala   |  2 +-
 5 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index e42e384..cfdc50d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -127,25 +127,36 @@ object CharVarcharUtils extends Logging {
   }
 
   /**
-   * Returns expressions to apply read-side char type padding for the given attributes. String
-   * values should be right-padded to N characters if it's from a CHAR(N) column/field.
+   * Returns expressions to apply read-side char type padding for the given attributes.
+   *
+   * For a CHAR(N) column/field with a string value of length M:
+   * If M > N, raise a runtime error.
+   * If M <= N, right-pad the value to N characters.
+   *
+   * For a VARCHAR(N) column/field with a string value of length M:
+   * If M > N, raise a runtime error.
+   * If M <= N, keep the value unchanged.
    */
-  def charTypePadding(output: Seq[AttributeReference]): Seq[NamedExpression] = {
+  def paddingWithLengthCheck(output: Seq[AttributeReference]): Seq[NamedExpression] = {
     output.map { attr =>
       getRawType(attr.metadata).filter { rawType =>
-        rawType.existsRecursively(_.isInstanceOf[CharType])
+        rawType.existsRecursively(dt => dt.isInstanceOf[CharType] || dt.isInstanceOf[VarcharType])
       }.map { rawType =>
-        Alias(charTypePadding(attr, rawType), attr.name)(explicitMetadata = Some(attr.metadata))
+        Alias(paddingWithLengthCheck(attr, rawType), attr.name)(
+          explicitMetadata = Some(attr.metadata))
       }.getOrElse(attr)
     }
   }
 
-  private def charTypePadding(expr: Expression, dt: DataType): Expression = dt match {
-    case CharType(length) => StringRPad(expr, Literal(length))
+  private def paddingWithLengthCheck(expr: Expression, dt: DataType): Expression = dt match {
+    case CharType(length) => StringRPad(stringLengthCheck(expr, dt), Literal(length))
+
+    case VarcharType(_) => stringLengthCheck(expr, dt)
 
     case StructType(fields) =>
       val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) =>
-        Seq(Literal(f.name), charTypePadding(GetStructField(expr, i, Some(f.name)), f.dataType))
+        Seq(Literal(f.name),
+          paddingWithLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType))
       })
       if (expr.nullable) {
         If(IsNull(expr), Literal(null, struct.dataType), struct)
@@ -166,7 +177,7 @@ object CharVarcharUtils extends Logging {
   private def charTypePaddingInArray(
       arr: Expression, et: DataType, containsNull: Boolean): Expression = {
     val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull)
-    val func = LambdaFunction(charTypePadding(param, et), Seq(param))
+    val func = LambdaFunction(paddingWithLengthCheck(param, et), Seq(param))
     ArrayTransform(arr, func)
   }
 
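A note for readers of the diff above: the following is a minimal, standalone Scala sketch of the new read-side semantics. It is a conceptual model only, not Spark's Catalyst API; the object and method names are made up for illustration, and the error text mirrors the messages asserted in the new tests.

```scala
// Conceptual model of the read-side length check and padding (NOT Spark code).
object CharVarcharReadModel {

  // CHAR(n): an over-length value fails the read; otherwise right-pad to n characters.
  def readChar(value: String, n: Int): String = {
    if (value.length > n) throw new RuntimeException(
      s"input string of length ${value.length} exceeds char type length limitation: $n")
    value.padTo(n, ' ')
  }

  // VARCHAR(n): an over-length value fails the read; otherwise keep the value as-is.
  def readVarchar(value: String, n: Int): String = {
    if (value.length > n) throw new RuntimeException(
      s"input string of length ${value.length} exceeds varchar type length limitation: $n")
    value
  }
}
```

In the actual change, the same logic is built from Catalyst expressions (StringRPad wrapping a length check for CHAR, the bare length check for VARCHAR) and injected as a projection over the relation's output by the renamed rule below.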
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PaddingAndLengthCheckForCharVarchar.scala
similarity index 86%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PaddingAndLengthCheckForCharVarchar.scala
index 35bb86f..f268d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PaddingAndLengthCheckForCharVarchar.scala
@@ -27,17 +27,21 @@ import org.apache.spark.sql.types.{CharType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * This rule applies char type padding in two places:
- *   1. When reading values from column/field of type CHAR(N), right-pad the values to length N.
- *   2. When comparing char type column/field with string literal or char type column/field,
- *      right-pad the shorter one to the longer length.
+ * This rule performs char type padding and length checks for both char and varchar.
+ *
+ * When reading values from a column/field of type CHAR(N) or VARCHAR(N), the underlying string
+ * value might exceed the length limit (e.g. tables w/ external locations); reading fails in that
+ * case. Otherwise, right-pad the values to length N for CHAR(N) and keep them as-is for VARCHAR(N).
+ *
+ * When comparing char type column/field with string literal or char type column/field,
+ * right-pad the shorter one to the longer length.
  */
-object ApplyCharTypePadding extends Rule[LogicalPlan] {
+object PaddingAndLengthCheckForCharVarchar extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val padded = plan.resolveOperatorsUpWithNewOutput {
       case r: LogicalRelation =>
-        val projectList = CharVarcharUtils.charTypePadding(r.output)
+        val projectList = CharVarcharUtils.paddingWithLengthCheck(r.output)
         if (projectList == r.output) {
           r -> Nil
         } else {
@@ -47,7 +51,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
         }
 
       case r: DataSourceV2Relation =>
-        val projectList = CharVarcharUtils.charTypePadding(r.output)
+        val projectList = CharVarcharUtils.paddingWithLengthCheck(r.output)
         if (projectList == r.output) {
           r -> Nil
         } else {
@@ -57,7 +61,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
         }
 
       case r: HiveTableRelation =>
-        val projectList = CharVarcharUtils.charTypePadding(r.output)
+        val projectList = CharVarcharUtils.paddingWithLengthCheck(r.output)
         if (projectList == r.output) {
           r -> Nil
         } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 75e2738..c71634f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -179,7 +179,7 @@ abstract class BaseSessionStateBuilder(
         PreprocessTableCreation(session) +:
         PreprocessTableInsertion +:
         DataSourceAnalysis +:
-        ApplyCharTypePadding +:
+        PaddingAndLengthCheckForCharVarchar +:
         customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index b0f1198..d7b84a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -528,6 +528,61 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
   override protected def sparkConf: SparkConf = {
     super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
   }
+
+  test("create table w/ location and fit length values") {
+    Seq("char", "varchar").foreach { typ =>
+      withTempPath { dir =>
+        withTable("t") {
+          sql("SELECT '12' as col").write.format(format).save(dir.toString)
+          sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'")
+          val df = sql("select * from t")
+          checkAnswer(df, Row("12"))
+        }
+      }
+    }
+  }
+
+  test("create table w/ location and over length values") {
+    Seq("char", "varchar").foreach { typ =>
+      withTempPath { dir =>
+        withTable("t") {
+          sql("SELECT '123456' as col").write.format(format).save(dir.toString)
+          sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'")
+          val e = intercept[SparkException] { sql("select * from t").collect() }
+          assert(e.getCause.getMessage.contains(
+            s"input string of length 6 exceeds $typ type length limitation: 2"))
+        }
+      }
+    }
+  }
+
+  test("alter table set location w/ fit length values") {
+    Seq("char", "varchar").foreach { typ =>
+      withTempPath { dir =>
+        withTable("t") {
+          sql("SELECT '12' as col").write.format(format).save(dir.toString)
+          sql(s"CREATE TABLE t (col $typ(2)) using $format")
+          sql(s"ALTER TABLE t SET LOCATION '$dir'")
+          checkAnswer(spark.table("t"), Row("12"))
+        }
+      }
+    }
+  }
+
+  test("alter table set location w/ over length values") {
+    Seq("char", "varchar").foreach { typ =>
+      withTempPath { dir =>
+        withTable("t") {
+          sql("SELECT '123456' as col").write.format(format).save(dir.toString)
+          sql(s"CREATE TABLE t (col $typ(2)) using $format")
+          sql(s"ALTER TABLE t SET LOCATION '$dir'")
+          val e = intercept[SparkException] { spark.table("t").collect() }
+          assert(e.getCause.getMessage.contains(
+            s"input string of length 6 exceeds $typ type length limitation: 2"))
+        }
+      }
+    }
+  }
 }
 
 class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index da37b61..5963a71 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -90,7 +90,7 @@ class HiveSessionStateBuilder(
         PreprocessTableCreation(session) +:
         PreprocessTableInsertion +:
         DataSourceAnalysis +:
-        ApplyCharTypePadding +:
+        PaddingAndLengthCheckForCharVarchar +:
         HiveAnalysis +:
         customPostHocResolutionRules
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org