You are viewing a plain-text version of this content; the canonical (HTML) version, with its original hyperlink, is available on the mailing-list archive page.
Posted to commits@spark.apache.org by jo...@apache.org on 2017/01/09 23:18:02 UTC
spark git commit: [SPARK-18952][BACKPORT] Regex strings not properly
escaped in codegen for aggregations
Repository: spark
Updated Branches:
refs/heads/branch-2.1 80a3e13e5 -> 3b6ac323b
[SPARK-18952][BACKPORT] Regex strings not properly escaped in codegen for aggregations
## What changes were proposed in this pull request?
Backport for #16361 to 2.1 branch.
## How was this patch tested?
Unit tests
Author: Burak Yavuz <br...@gmail.com>
Closes #16518 from brkyvz/reg-break-2.1.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b6ac323
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b6ac323
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b6ac323
Branch: refs/heads/branch-2.1
Commit: 3b6ac323b16f8f6d79ee7bac6e7a57f841897d96
Parents: 80a3e13
Author: Burak Yavuz <br...@gmail.com>
Authored: Mon Jan 9 15:17:59 2017 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Mon Jan 9 15:17:59 2017 -0800
----------------------------------------------------------------------
.../execution/aggregate/RowBasedHashMapGenerator.scala | 12 +++++++-----
.../aggregate/VectorizedHashMapGenerator.scala | 12 +++++++-----
.../org/apache/spark/sql/DataFrameAggregateSuite.scala | 9 +++++++++
3 files changed, 23 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3b6ac323/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index a77e178..1b6e6d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
- protected def initializeAggregateHashMap(): String = {
+ override protected def initializeAggregateHashMap(): String = {
val generatedKeySchema: String =
s"new org.apache.spark.sql.types.StructType()" +
groupingKeySchema.map { key =>
+ val keyName = ctx.addReferenceObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedValueSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
+ val keyName = ctx.addReferenceObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
http://git-wip-us.apache.org/repos/asf/spark/blob/3b6ac323/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7418df9..586328a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
- protected def initializeAggregateHashMap(): String = {
+ override protected def initializeAggregateHashMap(): String = {
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
+ val keyName = ctx.addReferenceObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedAggBufferSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
+ val keyName = ctx.addReferenceObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
http://git-wip-us.apache.org/repos/asf/spark/blob/3b6ac323/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6451759..7853b22fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+ val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
+
+ checkAnswer(
+ df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
+ Row("some", 1) :: Nil
+ )
+ }
+
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org