Posted to reviews@spark.apache.org by "chenhao-db (via GitHub)" <gi...@apache.org> on 2024/03/25 20:22:31 UTC

[PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

chenhao-db opened a new pull request, #45708:
URL: https://github.com/apache/spark/pull/45708

   ### What changes were proposed in this pull request?
   
   This PR adds a new `VariantGet` expression that extracts a variant value and casts it to a concrete type. It is exposed as two SQL expressions, `variant_get` and `try_variant_get`. If the extracted path doesn't exist in the source variant value, both return null. The difference is at the cast step: when the cast fails, `variant_get` throws an exception and `try_variant_get` returns null.
   
   The cast behavior is NOT affected by the `spark.sql.ansi.enabled` flag: `variant_get` always uses ANSI cast semantics, while `try_variant_get` always uses TRY cast semantics. For example, casting a variant long to an int never silently overflows into a wrapped int value, whereas a regular long-to-int cast may silently overflow in LEGACY mode.
   
   The current path extraction only supports array index access and case-sensitive object key access.
   
   Usage examples:
   
   ```
   
   > SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int');
    1
   > SELECT variant_get(parse_json('{"a": 1}'), '$.b', 'int');
    NULL
   > SELECT variant_get(parse_json('[1, "2"]'), '$[1]', 'string');
    2
   > SELECT variant_get(parse_json('[1, "2"]'), '$[2]', 'string');
    NULL
   > SELECT variant_get(parse_json('[1, "hello"]'), '$[1]'); -- when the target type is not specified, it returns variant by default (i.e., only extracts a sub-variant without casting)
    "hello"
   > SELECT try_variant_get(parse_json('[1, "hello"]'), '$[1]', 'int'); -- "hello" cannot be cast into int
    NULL
   ```
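
   For illustration, a hedged spark-shell sketch of the ANSI vs. TRY difference described above (assumes a `SparkSession` named `spark` with variant support enabled; 3000000000 does not fit in a 32-bit int):

   ```scala
   // variant_get applies ANSI cast semantics: an overflowing cast raises a cast error instead of wrapping.
   spark.sql("SELECT variant_get(parse_json('3000000000'), '$', 'int')").show()
   // => throws an invalid-cast error

   // try_variant_get applies TRY cast semantics: the same failed cast yields SQL NULL.
   spark.sql("SELECT try_variant_get(parse_json('3000000000'), '$', 'int')").show()
   // => NULL
   ```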
   
   ### How was this patch tested?
   
   Unit tests.




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544616525


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   Thanks for pointing this out! Actually, an int can be cast into a boolean in ANSI mode: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala#L101. But your point is valid, and there is another example: an int cannot be cast into binary in ANSI mode, but calling `Cast.eval` would still succeed. I have added the check.
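
   For illustration, a minimal sketch of that kind of guard (hedged: the helper `Cast.canAnsiCast` and the surrounding shape are assumptions for illustration, not necessarily the exact check added in the PR):

   ```scala
   import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Literal}
   import org.apache.spark.sql.types.DataType

   // Reject source/target combinations that ANSI cast does not allow (e.g. int -> binary)
   // before delegating to a TRY-mode Cast; the caller can then throw or return NULL.
   def castScalarOrNull(input: Any, targetType: DataType, zoneId: Option[String]): Any = {
     val lit = Literal(input)
     if (!Cast.canAnsiCast(lit.dataType, targetType)) {
       null // illegal type combination: variant_get turns this into an error, try_variant_get into NULL
     } else {
       Cast(lit, targetType, zoneId, EvalMode.TRY).eval()
     }
   }
   ```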





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546049361


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |

Review Comment:
   shall we match `_: DatetimeType` so that TimestampNTZType is also included?
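
   For reference, a minimal sketch of the suggested match (assuming `DatetimeType` is the shared abstract parent of `DateType`, `TimestampType`, and `TimestampNTZType` in this Spark version):

   ```scala
   import org.apache.spark.sql.types._

   // Matching the abstract DatetimeType in one arm admits DateType, TimestampType, and
   // TimestampNTZType, instead of listing TimestampType | DateType explicitly.
   def isAllowedScalarType(dt: DataType): Boolean = dt match {
     case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType | VariantType =>
       true
     case _ => false
   }
   ```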





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1540474925


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   Oh I see. I think we can still use `StaticInvoke`, but the java function should include the cast code as well. Let me think more about how we can reuse the `Cast` expression to do casting.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541767305


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   I have done some experiments with the `StaticInvoke` approach. Suppose I have encapsulated the `VariantGet` implementation into the following function:
   ```
   case object VariantGetCodegen {
     def variantGet(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {...}
   }
   ```
   and make `VariantGet` a `RuntimeReplaceable` expression with a replacement of `StaticInvoke` that invokes `VariantGetCodegen.variantGet`. It still won't directly work, because the codegen logic of `StaticInvoke` assumes the Java return type of the invoked method directly matches the expression's return type, but the return type of `VariantGetCodegen.variantGet` is `Any`.
   
   In order to make it work, I have to create a wrapper for each return type, like:
   
   ```
   case object VariantGetCodegen {
     def variantGetByte(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Byte =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[Byte]
     def variantGetShort(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Short =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[Short]
     def variantGetStruct(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): InternalRow =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[InternalRow]
     ...
   }
   ```
   
   and pick one method according to the return type. It is very cumbersome and doesn't really avoid any boxing/unboxing costs.
   
   On the other hand, if we have this `VariantGetCodegen.variantGet` method, it is reasonably easy to write the codegen by hand. I just need to cast the return value of this method into the desired type. The whole point of using `StaticInvoke` is to simplify the implementation, but I think it actually makes the implementation much more complex.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544573846


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   I cannot find an efficient way to do so because of decimal types. Each decimal type with a distinct precision and scale is a different type to the `Cast` expression. There are ~740 decimal types, so a type mapping would contain at least `740 * 740 = 5.5e5` entries. It would also need to be a per-expression mapping, because `VariantGet` expressions can have different `timeZoneId`s.
   
   I feel that the type check is robust enough. We have validated that `dataType` can only be one of a subset of atomic types. It is also easy to see that `Literal(input)` can only be one of a subset of atomic types, so the cast should be legal.
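
   As a rough sanity check on the numbers above, a small sketch (assuming Spark's usual bounds of 1 <= precision <= 38 and 0 <= scale <= precision):

   ```scala
   // Count the distinct DecimalType(precision, scale) combinations under the assumed bounds.
   val decimalTypeCount = (1 to 38).map(p => p + 1).sum      // 779, the same order as the ~740 cited above
   val pairwiseEntries = decimalTypeCount.toLong * decimalTypeCount
   println(s"$decimalTypeCount decimal types, $pairwiseEntries pairwise mapping entries") // ~6e5
   ```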





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1538615582


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,

Review Comment:
   ```suggestion
       targetType: DataType,
   ```





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1542993030


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   Oh, the dynamic return type is indeed an issue. Your proposal LGTM: we can still put a large chunk of the implementation code in a static method and then call it in both the interpreted and codegen versions.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on PR #45708:
URL: https://github.com/apache/spark/pull/45708#issuecomment-2025604435

   @cloud-fan I have resolved conflicts. Please take another look, thanks!




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1540169393


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   There is an issue with this approach: we cannot know the return type of `StaticInvoke(...)` in advance. It is possible that there are two variant rows with different content types and they can both be cast into the same type. E.g., both `variant_get(parse_json('1'), '$', 'int')` and `variant_get(parse_json('"1"'), '$', 'int')` should return `1`. If the `StaticInvoke(...)` must have a return type, it can only be the variant type. We still need the current `VariantGet.cast` function as the internal implementation of `Cast(Variant, targetType)`, so it doesn't really simplify anything.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546630168


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null

Review Comment:
   It is not really exhaustive due to the `if v.getType == Type.OBJECT/ARRAY` guards. For example, if the variant type is `LONG`, we should return null. Similarly, if the variant type is `ARRAY` and the `path` segment is an object field access, we should also return null.
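   
   To make the two mismatch cases concrete, here is the same match from the diff with only comments added (an annotated excerpt, not a proposed change):
   
   ```
   v = path match {
     // object value + key segment: descend into the field
     case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
     // array value + index segment: descend into the element
     case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
     // any other combination, e.g. a LONG with any segment, or an ARRAY with a key segment
     case _ => null
   }
   ```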





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546054662


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType.isInstanceOf[StringType]) {
+              UTF8String.fromString(v.toJson)
+            } else {
+              invalidCast()
+            }
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce

Review Comment:
   This looks like a bug in ANSI-mode `Cast`. cc @srielau
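   
   For reference, a minimal self-contained sketch of the raw overflow behavior the workaround guards against (plain JDK arithmetic; the constant is hard-coded here):
   
   ```
   val MICROS_PER_SECOND = 1000000L
   val seconds = Long.MaxValue / 2
   val wrapped = seconds * MICROS_PER_SECOND                    // silently wraps around to a meaningless value
   val checked = Math.multiplyExact(seconds, MICROS_PER_SECOND) // throws java.lang.ArithmeticException
   ```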





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546049712


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback

Review Comment:
   It's very easy to add codegen now. Can we do it in this PR?





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541443869


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   There is yet another reason against using `StaticInvoke`. The path parameter must be a literal, and I can make use of this requirement to avoid repeated path parsing. However, I cannot find a way to do similar caching with `StaticInvoke`.
   
   Using `StaticInvoke` won't simplify the current implementation. It could indeed simplify things if we wanted native codegen instead of depending on `CodegenFallback`, but I think that is an optional optimization we can do in the future, once we can manually write the codegen for `VariantGet`.
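   
   To make the caching point concrete, this is the relevant bit of the diff with a comment on why the literal requirement matters (an annotated excerpt, not a change):
   
   ```
   // Because `path` is checked to be foldable, it can be evaluated and parsed once per
   // expression instance here, instead of re-running VariantPathParser.parse for every row,
   // which is what a per-row static call would have to do.
   @transient private lazy val parsedPath = {
     val pathValue = path.eval().toString
     VariantPathParser.parse(pathValue).getOrElse {
       throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
     }
   }
   ```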





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on PR #45708:
URL: https://github.com/apache/spark/pull/45708#issuecomment-2019117063

   @cloud-fan I haven't changed https://github.com/apache/spark/pull/45708/files#diff-9e7a4d9777eb424f4453b1ece9618eb916ea4b1e312d5e300b1b29b657ced562R305. There are two reasons:
   - Although we should never hit this error inside Spark, the variant sub-library is intended to be used outside of Spark too, and other users may hit it (e.g., calling `getLong` without ensuring `getType` is `LONG`). So it shouldn't be an internal error.
   - Throwing a `SparkException` requires changing all the Java method signatures that use this function, which is a non-trivial change.
   
   Please let me know whether it makes sense. Thanks!




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544181368


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |

Review Comment:
   ```suggestion
       case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
   ```





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan closed pull request #45708: [SPARK-47551][SQL] Add variant_get expression.
URL: https://github.com/apache/spark/pull/45708




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546622028


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |

Review Comment:
   Done. As a side note, `TimestampNTZType` doesn't have an overflow issue similar to `TimestampType`, because casting numeric types into `TimestampNTZType` is not allowed.
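   
   A quick way to check that claim, assuming `Cast.canAnsiCast` (mentioned elsewhere in this thread) reflects the allowed ANSI casts:
   
   ```
   import org.apache.spark.sql.catalyst.expressions.Cast
   import org.apache.spark.sql.types.{LongType, TimestampNTZType, TimestampType}
   
   Cast.canAnsiCast(LongType, TimestampType)    // expected: true, hence the manual overflow check
   Cast.canAnsiCast(LongType, TimestampNTZType) // expected: false, numeric inputs are rejected up front
   ```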





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546633520


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input valid (e.g, cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>

Review Comment:
   I think there will be a lot of duplication. The new code would be something like:
   ```
   case _: StringType =>
     variantType match {
       case Type.OBJECT | Type.ARRAY => ...
       case Type.BOOLEAN => ...
       case Type.LONG => ...
       ...
     }
   case _: AtomicType =>
     variantType match {
       case Type.BOOLEAN => ...
       case Type.LONG => ...
       ...
     }
   ```
   
   However, I have another way to reduce the indentation: we can have an early return in the `case Type.OBJECT | Type.ARRAY` case.
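   
   Roughly this shape (a sketch only, reusing the names from the diff; the elided cases stay as they are):
   
   ```
   case _: AtomicType =>
     if (variantType == Type.OBJECT || variantType == Type.ARRAY) {
       // Early return: only a string target accepts objects/arrays; anything else is invalid.
       return if (dataType.isInstanceOf[StringType]) {
         UTF8String.fromString(v.toJson)
       } else {
         invalidCast()
       }
     }
     val input = variantType match {
       case Type.BOOLEAN => v.getBoolean
       case Type.LONG => v.getLong
       // ... remaining scalar cases unchanged ...
     }
     // ... cast `input` to `dataType` as before ...
   ```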





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544604853


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   I don't think so. The target type can be boolean while the variant value is an int, which isn't a valid cast in ANSI mode, but we won't see any failure if we evaluate the `Cast` expression directly.
   
   I suggest we check `Cast.canAnsiCast` and fail earlier.
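   
   Something along these lines, perhaps (an untested sketch against the quoted code; the null check at the end is only how I'd expect the TRY result to be handled):
   
   ```
   // Validate the type combination first, then evaluate the cast, so an unsupported pair
   // (e.g. a variant LONG cast to BooleanType) goes through invalidCast() explicitly
   // instead of an unchecked Cast evaluation.
   val inputLiteral = Literal(input)
   if (!Cast.canAnsiCast(inputLiteral.dataType, dataType)) {
     invalidCast()
   } else {
     val result = Cast(inputLiteral, dataType, zoneId, EvalMode.TRY).eval()
     if (result == null) invalidCast() else result
   }
   ```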





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544183331


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   It's risky to evaluate `Cast` on the fly, as we do not apply any analysis checks to it. Can we define the allowed type mapping before creating `Cast`?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to bypass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   It's risky to evaluate `Cast` on the fly, as we do not apply any analysis checks for this `Cast`. Can we define the allowed type mapping before creating `Cast`?
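   For illustration, one possible shape of such a pre-defined mapping (a sketch under assumptions, not code from this PR; `scalarCastAllowed` is a made-up helper, and the real pair list would have to mirror whatever cast rules this expression is meant to support) could be:
   ```
   import org.apache.spark.sql.types._

   // Sketch: decide the allowed (source, target) pairs up front, so the
   // dynamically built Cast is only evaluated for combinations that would
   // pass analysis anyway; everything else goes straight to invalidCast().
   def scalarCastAllowed(from: DataType, to: DataType): Boolean = (from, to) match {
     case (_, StringType) => true              // every variant scalar has a string form
     case (StringType, _: AtomicType) => true  // parsing may still fail at runtime
     case (BooleanType, _: NumericType) => true
     case (LongType, _: NumericType) => true
     case (DoubleType, _: NumericType) => true
     case (_: DecimalType, _: NumericType) => true
     case (LongType, TimestampType) => true
     case (_: DecimalType, TimestampType) => true
     case _ => false
   }
   ```
   With a check like this in place, `Cast` would only ever be constructed for pairs that are known to be valid.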





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544183331


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This case is only intended
+              // to by pass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, EvalMode.TRY).eval()

Review Comment:
   It's risky to evaluate Cast on the fly, as we do not apply any analysis checks for this Cast. Can we define the allowed type mapping before creating `Cast`?





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1538614839


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).

Review Comment:
   This doesn't look like a good API design; shall we use Scala's `Either`?
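   For reference, the `Either`-based shape this points at (which the updated diffs quoted earlier in this thread adopt) would look roughly like:
   ```
   // Left = object key access, Right = array index access.
   type PathSegment = Either[String, Int]

   // e.g. a path such as "$.a[2].b" would parse to:
   val segments: Seq[PathSegment] = Seq(Left("a"), Right(2), Left("b"))
   ```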





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541963979


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   In case you are interested, I have a draft of the manual codegen version. I think I can add it in a follow-up PR. Personally, I don't see any red flags in the code, and it is much better than the `StaticInvoke` approach.
   
   ```
     protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
       val childCode = child.genCode(ctx)
       val tmp = ctx.freshVariable("tmp", classOf[Object])
       val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
       val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
       val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
       val code = code"""
         ${childCode.code}
         boolean ${ev.isNull} = ${childCode.isNull};
         ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         if (!${ev.isNull}) {
           Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
             ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneIdArg);
           if ($tmp == null) {
             ${ev.isNull} = true;
           } else {
             ${ev.value} = (${CodeGenerator.boxedType(dataType)})$tmp;
           }
         }
       """
       ev.copy(code = code)
     }
   ```





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541513308


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   I didn't mean writing everything by hand. Essentially, we create a method that implements `VariantGet`, and the class only needs some boilerplate code to call this method (similar to the code in `StaticInvoke` itself).
   
   There is still another reason why I don't like `StaticInvoke`. In the future, I will write some optimizer rules on `VariantGet` (e.g., to push it down through a join), which is why I added a new `TreePattern`, `VARIANT_GET`, in this PR. The optimizer rules will run after `RuntimeReplaceable` expressions are replaced, so by then the expression will have become a `StaticInvoke` that no longer carries this tree pattern, and the rules can no longer use it to prune expressions. Plus, matching against `StaticInvoke` is also more complex than matching against `VariantGet`.
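   To illustrate the pruning concern (a sketch only: the rule name and the rewrite are made up, and it assumes the standard tree-pattern pruning helpers), such a rule could look roughly like this:
   ```
   import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
   import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
   import org.apache.spark.sql.catalyst.rules.Rule
   import org.apache.spark.sql.catalyst.trees.TreePattern.VARIANT_GET

   // Hypothetical rule: the VARIANT_GET pattern bit lets it skip whole subtrees
   // that contain no VariantGet. That bit is lost once the expression has been
   // replaced by a StaticInvoke.
   object PushDownVariantGet extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan =
       plan.transformAllExpressionsWithPruning(_.containsPattern(VARIANT_GET)) {
         case v: VariantGet => v // a real rule would rewrite or relocate `v` here
       }
   }
   ```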





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541767305


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   I have done some experiments with the `StaticInvoke` approach. Suppose I have encapsulated the `VariantGet` implementation into the following function:
   ```
   case object VariantGetCodegen {
     def variantGet(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {...}
   }
   ```
   and make `VariantGet` a `RuntimeReplaceable` expression whose replacement is a `StaticInvoke` that invokes `VariantGetCodegen.variantGet`. It still won't directly work, because the codegen logic of `StaticInvoke` assumes the declared return type of the invoked method matches the expression's data type, but the return type of `VariantGetCodegen.variantGet` is `Any`.
   
   In order to make it work, I have to create a wrapper for each return type, like:
   
   ```
   case object VariantGetCodegen {
     def variantGetByte(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Byte =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[Byte]
     def variantGetShort(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Short =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[Short]
     def variantGetStruct(input: VariantVal, parsedPath: Array[PathSegment],
                    dataType: DataType, failOnError: Boolean, zoneId: Option[String]): InternalRow =
       variantGet(input, parsedPath, dataType, failOnError, zoneId).asInstanceOf[InternalRow]
     ...
   }
   ```
   
   and pick one method according to the return type. It is very cumbersome and doesn't really avoid any boxing/unboxing costs.
   
   On the other hand, if we have this `VariantGetCodegen.variantGet` method, it is reasonably easy to write the codegen by hand. I just need to cast the return value of this method into the desired type. The whole point of using `StaticInvoke` is to simplify the implementation, but I think it actually makes the implementation much more complex.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544181632


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else invalidCast()

Review Comment:
   ```suggestion
               if (dataType.isInstanceOf[StringType]) UTF8String.fromString(v.toJson) else invalidCast()
   ```





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541443869


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   There is yet another reason against using `StaticInvoke`. The path parameter must be a literal, and I can make use of this requirement to avoid repeated path parsing. However, I cannot see how to do similar caching with `StaticInvoke`.
   
   Using `StaticInvoke` won't simplify the current implementation. It can indeed simplify the implementation if we want to support native codegen rather than depending on `CodegenFallback`. I think that is an optional optimization we can do in the future, but I would prefer manually writing the codegen for `VariantGet`.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541456694


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and casts it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   You can pass anything to `StaticInvoke`, including arbitrary Java objects, by using `Literal` with `ObjectType`. I'm against writing codegen by hand, as it's hard to debug and error-prone (it may become inconsistent with the interpreted implementation).
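   As a rough sketch of that pattern applied here, a fragment that would sit inside `VariantGet` (the helper object `VariantGetImpl`, the exact argument list, and the assumption that the expression mixes in `RuntimeReplaceable` are all illustrative, not the final code):
   ```
   import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
   import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
   import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType}

   // Non-SQL values (the pre-parsed path, the target type, the zone id) are wrapped
   // as ObjectType literals, so StaticInvoke receives them as ordinary child
   // expressions and nothing is re-parsed per row.
   override lazy val replacement: Expression = StaticInvoke(
     VariantGetImpl.getClass, // hypothetical helper object holding variantGet()
     dataType,
     "variantGet",
     Seq(
       child,
       Literal(parsedPath, ObjectType(classOf[Array[VariantPathParser.PathSegment]])),
       Literal(dataType, ObjectType(classOf[DataType])),
       Literal(failOnError, BooleanType),
       Literal(timeZoneId, ObjectType(classOf[Option[String]]))))
   ```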





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on PR #45708:
URL: https://github.com/apache/spark/pull/45708#issuecomment-2030808085

   @cloud-fan Could you help merge it? Thanks!




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546050591


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null

Review Comment:
   We can remove this `case _` branch; the pattern match is already exhaustive.
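   
   A sketch of how the wildcard could be dropped (moving the type checks into the branch bodies keeps the match over `Either` exhaustive):
   ```
   v = path match {
     case scala.util.Left(key) =>
       if (v.getType == Type.OBJECT) v.getFieldByKey(key) else null
     case scala.util.Right(index) =>
       if (v.getType == Type.ARRAY) v.getElementAtIndex(index) else null
   }
   ```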





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1538622950


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   Some ideas to simplify the implementation:
   
   Add Java util methods for the implementation, e.g.:
   ```
   public Object getValueFromVariant(VariantVal variant, String path) {
     ...
   }
   ```
   
   We can add more overloads to avoid boxing
   ```
   public boolean getBooleanFromVariant(VariantVal variant, String path) {
     ...
   }
   ```
   
   Then, `VariantGet` can extend `RuntimeReplaceable` and use `StaticInvoke` to call the Java util methods, and also use `Cast` to do the cast work:
   ```
   case class VariantGet ...
     lazy val replacement: Expression = Cast(StaticInvoke(...), targetType, ansiEnabled = failOnError)
   ```





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1540176300


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).

Review Comment:
   Good point, changed to `Either` instead.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541232922


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   BTW, we can probably wait for https://github.com/apache/spark/pull/45714 to add the codegen utils for the variant type.





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541513308


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   I didn't mean writing everything by hand. Essentially, we create a method that implements `VariantGet`, and the class only needs some boilerplate code to call this method (similar to the code in `StaticInvoke` itself).
   
   There is still another reason why I don't like `StaticInvoke`. In the future, I will write some optimizer rules on `VariantGet` (e.g., to push it down through a join). This is why I added a new `TreePattern`, `VARIANT_GET`, in this PR. The optimizer rule will run after the `RuntimeReplaceable` expression has been replaced, so the expression will have become a `StaticInvoke` and will no longer carry this tree pattern, and the rule can no longer use the pattern to prune expressions. Plus, matching against `StaticInvoke` is also more complex than matching against `VariantGet`.
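   
   For illustration, a rough sketch of how such a rule could rely on the tree pattern for pruning (the rule name and the rewrite logic are placeholders):
   ```
   object PushDownVariantGet extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan =
       plan.transformWithPruning(_.containsPattern(VARIANT_GET)) {
         // Placeholder: a real rule would rewrite plans containing VariantGet here,
         // e.g. push the extraction below a join.
         case p => p
       }
   }
   ```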





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "chenhao-db (via GitHub)" <gi...@apache.org>.
chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546623105


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback

Review Comment:
   Done. The current tests should be enough because `checkEvaluation` checks the codegen path.
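   
   A rough sketch of such a test (the literals and setup are assumptions about the suite):
   ```
   // ExpressionEvalHelper.checkEvaluation evaluates the expression both with and without codegen.
   checkEvaluation(
     VariantGet(ParseJson(Literal("""{"a": 1}""")), Literal("$.a"), IntegerType, failOnError = true),
     1)
   ```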





Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on PR #45708:
URL: https://github.com/apache/spark/pull/45708#issuecomment-2030916612

   thanks, merging to master!




Re: [PR] [SPARK-47551][SQL] Add variant_get expression. [spark]

Posted by "cloud-fan (via GitHub)" <gi...@apache.org>.
cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1546051987


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,320 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson =
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON path.
+ * @param targetType The target data type to cast into. Any non-nullable annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId = Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | _: StringType | BinaryType | TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input value (e.g., cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>

Review Comment:
   nit: we can split it into two case matches to reduce the indentation.
   ```
   case _: StringType =>
     variantType match {
     ...
     }
   case _: AtomicType =>
     ...
   ```


