You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2024/02/26 13:39:23 UTC

(spark) branch master updated: [SPARK-47009][SQL] Enable create table support for collation

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 298134fd5e98 [SPARK-47009][SQL] Enable create table support for collation
298134fd5e98 is described below

commit 298134fd5e987a982f468a70454ae94f1b8565e3
Author: Stefan Kandic <st...@databricks.com>
AuthorDate: Mon Feb 26 21:38:56 2024 +0800

    [SPARK-47009][SQL] Enable create table support for collation
    
    ### What changes were proposed in this pull request?
    
    Add support for CREATE TABLE with collated string columns, using Parquet as the storage format.
    
    We will map collated string types to the regular Parquet string type. This means that we won't support cross-engine compatibility for now.
    
    I will soon open a follow-up PR to fix Parquet filter pushdown. Initially we will disable pushdown completely for collated strings, but later we should look into using sort keys as the min/max values so that pushdown can be supported.
    
    ### Why are the changes needed?
    
    To support basic DDL operations for collations.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users are now able to create tables with collated columns, e.g. `CREATE TABLE t (c1 STRING COLLATE 'UCS_BASIC_LCASE') USING PARQUET`.
    
    ### How was this patch tested?
    
    With unit tests added to CollationSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45105 from stefankandic/SPARK-47009-createTableCollation.
    
    Authored-by: Stefan Kandic <st...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   8 +-
 .../sql/catalyst/parser/DataTypeAstBuilder.scala   |  21 +++-
 .../org/apache/spark/sql/types/DataType.scala      |   6 +-
 .../org/apache/spark/sql/types/StringType.scala    |   2 +-
 .../spark/sql/catalyst/expressions/Cast.scala      |  12 +--
 .../spark/sql/catalyst/expressions/hash.scala      |   2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   4 +-
 .../parquet/ParquetVectorUpdaterFactory.java       |   2 +-
 .../sql/execution/aggregate/HashMapGenerator.scala |   2 +-
 .../datasources/parquet/ParquetRowConverter.scala  |   2 +-
 .../parquet/ParquetSchemaConverter.scala           |   2 +-
 .../datasources/parquet/ParquetWriteSupport.scala  |   2 +-
 .../org/apache/spark/sql/CollationSuite.scala      | 119 ++++++++++++++++++++-
 13 files changed, 161 insertions(+), 23 deletions(-)

diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 1109e4a7bdfc..ca01de4ffdc2 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -989,7 +989,7 @@ primaryExpression
     | CASE whenClause+ (ELSE elseExpression=expression)? END                                   #searchedCase
     | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END                  #simpleCase
     | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN                     #cast
-    | primaryExpression COLLATE stringLit                                                      #collate
+    | primaryExpression collateClause                                                      #collate
     | primaryExpression DOUBLE_COLON dataType                                                  #castByColon
     | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct
     | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN                                  #first
@@ -1095,6 +1095,10 @@ colPosition
     : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier
     ;
 
+collateClause
+    : COLLATE collationName=stringLit
+    ;
+
 type
     : BOOLEAN
     | TINYINT | BYTE
@@ -1105,7 +1109,7 @@ type
     | DOUBLE
     | DATE
     | TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
-    | STRING
+    | STRING collateClause?
     | CHARACTER | CHAR
     | VARCHAR
     | BINARY
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index 3a2e704ffe9f..0d2822e13efc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -24,6 +24,7 @@ import org.antlr.v4.runtime.Token
 import org.antlr.v4.runtime.tree.ParseTree
 
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
 import org.apache.spark.sql.errors.QueryParsingErrors
 import org.apache.spark.sql.internal.SqlApiConf
@@ -58,8 +59,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
    * Resolve/create a primitive type.
    */
   override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
-    val typeName = ctx.`type`.start.getType
-    (typeName, ctx.INTEGER_VALUE().asScala.toList) match {
+    val typeCtx = ctx.`type`
+    (typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match {
       case (BOOLEAN, Nil) => BooleanType
       case (TINYINT | BYTE, Nil) => ByteType
       case (SMALLINT | SHORT, Nil) => ShortType
@@ -71,7 +72,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
       case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
       case (TIMESTAMP_NTZ, Nil) => TimestampNTZType
       case (TIMESTAMP_LTZ, Nil) => TimestampType
-      case (STRING, Nil) => StringType
+      case (STRING, Nil) =>
+        typeCtx.children.asScala.toSeq match {
+          case Seq(_) => StringType
+          case Seq(_, ctx: CollateClauseContext) =>
+            val collationName = visitCollateClause(ctx)
+            val collationId = CollationFactory.collationNameToId(collationName)
+            StringType(collationId)
+        }
       case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
       case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
       case (BINARY, Nil) => BinaryType
@@ -205,4 +213,11 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
   override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
     string(visitStringLit(ctx.stringLit))
   }
+
+  /**
+   * Returns a collation name.
+   */
+  override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
+    string(visitStringLit(ctx.stringLit))
+  }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 319a5eccbb6d..fb0a8e586e0c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -31,8 +31,8 @@ import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis
 import org.apache.spark.sql.catalyst.parser.DataTypeParser
+import org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat}
 import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
-import org.apache.spark.sql.catalyst.util.StringConcat
 import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.DayTimeIntervalType._
@@ -117,6 +117,7 @@ object DataType {
   private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
   private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
   private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
+  private val COLLATED_STRING_TYPE = """string\s+COLLATE\s+([\w_]+)""".r
 
   def fromDDL(ddl: String): DataType = {
     parseTypeWithFallback(
@@ -181,6 +182,9 @@ object DataType {
   /** Given the string representation of a type, return its DataType */
   private def nameToType(name: String): DataType = {
     name match {
+      case COLLATED_STRING_TYPE(collation) =>
+        val collationId = CollationFactory.collationNameToId(collation)
+        StringType(collationId)
       case "decimal" => DecimalType.USER_DEFAULT
       case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
       case CHAR_TYPE(length) => CharType(length.toInt)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 501d86433847..3fe0e1c9ce3f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -39,7 +39,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
    */
   override def typeName: String =
     if (isDefaultCollation) "string"
-    else s"string(${CollationFactory.fetchCollation(collationId).collationName})"
+    else s"string COLLATE ${CollationFactory.fetchCollation(collationId).collationName}"
 
   override def equals(obj: Any): Boolean =
     obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index eae112a6a398..66907dc6c353 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -93,7 +93,7 @@ object Cast extends QueryErrorsBase {
 
     case (NullType, _) => true
 
-    case (_, StringType) => true
+    case (_, _: StringType) => true
 
     case (StringType, _: BinaryType) => true
 
@@ -301,8 +301,8 @@ object Cast extends QueryErrorsBase {
     case _ if from == to => true
     case (NullType, _) => true
     case (_: NumericType, _: NumericType) => true
-    case (_: AtomicType, StringType) => true
-    case (_: CalendarIntervalType, StringType) => true
+    case (_: AtomicType, _: StringType) => true
+    case (_: CalendarIntervalType, _: StringType) => true
     case (_: DatetimeType, _: DatetimeType) => true
 
     case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
@@ -574,7 +574,7 @@ case class Cast(
 
   // BinaryConverter
   private[this] def castToBinary(from: DataType): Any => Any = from match {
-    case StringType => buildCast[UTF8String](_, _.getBytes)
+    case _: StringType => buildCast[UTF8String](_, _.getBytes)
     case ByteType => buildCast[Byte](_, NumberConverter.toBinary)
     case ShortType => buildCast[Short](_, NumberConverter.toBinary)
     case IntegerType => buildCast[Int](_, NumberConverter.toBinary)
@@ -1109,7 +1109,7 @@ case class Cast(
     } else {
       to match {
         case dt if dt == from => identity[Any]
-        case StringType => castToString(from)
+        case _: StringType => castToString(from)
         case BinaryType => castToBinary(from)
         case DateType => castToDate(from)
         case decimal: DecimalType => castToDecimal(from, decimal)
@@ -1198,7 +1198,7 @@ case class Cast(
 
     case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
     case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
-    case StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
+    case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
     case BinaryType => castToBinaryCode(from)
     case DateType => castToDateCode(from, ctx)
     case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 0f5d4707a164..5ad3011fb88b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -491,7 +491,7 @@ abstract class HashExpression[E] extends Expression {
     case _: DayTimeIntervalType => genHashLong(input, result)
     case _: YearMonthIntervalType => genHashInt(input, result)
     case BinaryType => genHashBytes(input, result)
-    case StringType => genHashString(input, result)
+    case _: StringType => genHashString(input, result)
     case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
     case MapType(kt, vt, valueContainsNull) =>
       genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 99486ae282a8..ea549cec12f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2186,8 +2186,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
    * Create a [[Collate]] expression.
    */
   override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) {
-    val collation = string(visitStringLit(ctx.stringLit))
-    Collate(expression(ctx.primaryExpression), collation)
+    val collationName = visitCollateClause(ctx.collateClause())
+    Collate(expression(ctx.primaryExpression), collationName)
   }
 
   /**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index f369688597b9..abb44915cbcd 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -198,7 +198,7 @@ public class ParquetVectorUpdaterFactory {
         }
       }
       case BINARY -> {
-        if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
+        if (sparkType instanceof  StringType || sparkType == DataTypes.BinaryType ||
           canReadAsBinaryDecimal(descriptor, sparkType)) {
           return new BinaryUpdater();
         } else if (canReadAsDecimal(descriptor, sparkType)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index 8a88ad0a57e3..6154706231ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -173,7 +173,7 @@ abstract class HashMapGenerator(
             ${hashBytes(bytes)}
           """
         }
-      case StringType => hashBytes(s"$input.getBytes()")
+      case _: StringType => hashBytes(s"$input.getBytes()")
       case CalendarIntervalType => hashInt(s"$input.hashCode()")
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index b2222f4297e9..36c13e72993b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -395,7 +395,7 @@ private[parquet] class ParquetRowConverter(
         throw QueryExecutionErrors.cannotCreateParquetConverterForDecimalTypeError(
           t, parquetType.toString)
 
-      case StringType =>
+      case _: StringType =>
         new ParquetStringConverter(updater)
 
       // As long as the parquet type is INT64 timestamp, whether logical annotation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 59c99cb998ca..963b1520b3c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -544,7 +544,7 @@ class SparkToParquetSchemaConverter(
       case DoubleType =>
         Types.primitive(DOUBLE, repetition).named(field.name)
 
-      case StringType =>
+      case _: StringType =>
         Types.primitive(BINARY, repetition)
           .as(LogicalTypeAnnotation.stringType()).named(field.name)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 7194033e603b..89a1cd5d4375 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -199,7 +199,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addDouble(row.getDouble(ordinal))
 
-      case StringType =>
+      case _: StringType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addBinary(
             Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 465afc7e2006..37bcdbbcd569 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -17,13 +17,21 @@
 
 package org.apache.spark.sql
 
+import scala.collection.immutable.Seq
+import scala.jdk.CollectionConverters.MapHasAsJava
+
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.util.CollationFactory
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
+import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.types.StringType
 
-class CollationSuite extends QueryTest with SharedSparkSession {
+class CollationSuite extends DatasourceV2SQLBase {
+  protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
+
   test("collate returns proper type") {
     Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName =>
       checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa"))
@@ -174,4 +182,111 @@ class CollationSuite extends QueryTest with SharedSparkSession {
           Row(expected))
     }
   }
+
+  test("create table with collation") {
+    val tableName = "parquet_dummy_tbl"
+    val collationName = "UCS_BASIC_LCASE"
+    val collationId = CollationFactory.collationNameToId(collationName)
+
+    withTable(tableName) {
+      sql(
+        s"""
+           |CREATE TABLE $tableName (c1 STRING COLLATE '$collationName')
+           |USING PARQUET
+           |""".stripMargin)
+
+      sql(s"INSERT INTO $tableName VALUES ('aaa')")
+      sql(s"INSERT INTO $tableName VALUES ('AAA')")
+
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
+      assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
+    }
+  }
+
+  test("create table with collations inside a struct") {
+    val tableName = "struct_collation_tbl"
+    val collationName = "UCS_BASIC_LCASE"
+    val collationId = CollationFactory.collationNameToId(collationName)
+
+    withTable(tableName) {
+      sql(
+        s"""
+           |CREATE TABLE $tableName
+           |(c1 STRUCT<name: STRING COLLATE '$collationName', age: INT>)
+           |USING PARQUET
+           |""".stripMargin)
+
+      sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))")
+      sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))")
+
+      checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
+        Seq(Row(collationName)))
+      assert(sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
+    }
+  }
+
+  test("add collated column with alter table") {
+    val tableName = "alter_column_tbl"
+    val defaultCollation = "UCS_BASIC"
+    val collationName = "UCS_BASIC_LCASE"
+    val collationId = CollationFactory.collationNameToId(collationName)
+
+    withTable(tableName) {
+      sql(
+        s"""
+           |CREATE TABLE $tableName (c1 STRING)
+           |USING PARQUET
+           |""".stripMargin)
+
+      sql(s"INSERT INTO $tableName VALUES ('aaa')")
+      sql(s"INSERT INTO $tableName VALUES ('AAA')")
+
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
+          Seq(Row(defaultCollation)))
+
+      sql(
+      s"""
+         |ALTER TABLE $tableName
+         |ADD COLUMN c2 STRING COLLATE '$collationName'
+         |""".stripMargin)
+
+      sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')")
+      sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")
+
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"),
+        Seq(Row(collationName)))
+      assert(sql(s"select c2 FROM $tableName").schema.head.dataType == StringType(collationId))
+    }
+  }
+
+  test("create v2 table with collation column") {
+    val tableName = "testcat.table_name"
+    val collationName = "UCS_BASIC_LCASE"
+    val collationId = CollationFactory.collationNameToId(collationName)
+
+    withTable(tableName) {
+      sql(
+        s"""
+           |CREATE TABLE $tableName (c1 string COLLATE '$collationName')
+           |USING $v2Source
+           |""".stripMargin)
+
+      val testCatalog = catalog("testcat").asTableCatalog
+      val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+      assert(table.name == tableName)
+      assert(table.partitioning.isEmpty)
+      assert(table.properties == withDefaultOwnership(Map("provider" -> v2Source)).asJava)
+      assert(table.columns().head.dataType() == StringType(collationId))
+
+      val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+      checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)
+
+      sql(s"INSERT INTO $tableName VALUES ('a'), ('A')")
+
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
+        Seq(Row(collationName)))
+      assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org