You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2018/04/26 01:14:36 UTC
spark git commit: [SPARK-23849][SQL] Tests for samplingRatio of json datasource
Repository: spark
Updated Branches:
refs/heads/master 95a651339 -> 3f1e999d3
[SPARK-23849][SQL] Tests for samplingRatio of json datasource
## What changes were proposed in this pull request?
Added the `samplingRatio` option to the `json()` method of the PySpark `DataFrameReader`. Improved the existing tests for the Scala API according to the review of the PR: https://github.com/apache/spark/pull/20959
## How was this patch tested?
Added a new test for PySpark, updated two existing tests according to the reviews of https://github.com/apache/spark/pull/20959, and added a new negative test.
Author: Maxim Gekk <ma...@databricks.com>
Closes #21056 from MaxGekk/json-sampling.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f1e999d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f1e999d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f1e999d
Branch: refs/heads/master
Commit: 3f1e999d3d215bb3b867bcd83ec5c799448ec730
Parents: 95a6513
Author: Maxim Gekk <ma...@databricks.com>
Authored: Thu Apr 26 09:14:24 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Thu Apr 26 09:14:24 2018 +0800
----------------------------------------------------------------------
python/pyspark/sql/readwriter.py | 7 ++-
python/pyspark/sql/tests.py | 8 +++
.../org/apache/spark/sql/DataFrameReader.scala | 2 +
.../execution/datasources/json/JsonSuite.scala | 63 +++++++++++---------
.../datasources/json/TestJsonData.scala | 12 ++++
5 files changed, 61 insertions(+), 31 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 6bd79bc..df176c5 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ class DataFrameReader(OptionUtils):
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None, lineSep=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -239,6 +239,8 @@ class DataFrameReader(OptionUtils):
including tab and line feed characters) or not.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ :param samplingRatio: defines fraction of input JSON objects used for schema inferring.
+ If None is set, it uses the default value, ``1.0``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -256,7 +258,8 @@ class DataFrameReader(OptionUtils):
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
+ samplingRatio=samplingRatio)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4e99c8e..98fa1b5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3018,6 +3018,14 @@ class SQLTests(ReusedSQLTestCase):
df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
[Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
+ def test_json_sampling_ratio(self):
+ rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+ .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
+ schema = self.spark.read.option('inferSchema', True) \
+ .option('samplingRatio', 0.5) \
+ .json(rdd).schema
+ self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
+
class HiveSparkSubmitTests(SparkSubmitTests):
http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d640fdc..b44552f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* per file</li>
* <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
* that should be used for parsing.</li>
+ * <li>`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
+ * for schema inferring.</li>
* </ul>
*
* @since 2.0.0
http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 70aee56..a58dff8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
- test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") {
- val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
- 57, 62, 68, 72)
- withTempPath { path =>
- val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath),
- StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW)
- for (i <- 0 until 100) {
- if (predefinedSample.contains(i)) {
- writer.write(s"""{"f1":${i.toString}}""" + "\n")
- } else {
- writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n")
- }
- }
- writer.close()
+ test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") {
+ // Set default values for the DataSource parameters to make sure
+ // that whole test file is mapped to only one partition. This will guarantee
+ // reliable sampling of the input file.
+ withSQLConf(
+ "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString,
+ "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString
+ )(withTempPath { path =>
+ val ds = sampledTestData.coalesce(1)
+ ds.write.text(path.getAbsolutePath)
+ val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath)
+
+ assert(readback.schema == new StructType().add("f1", LongType))
+ })
+ }
- val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath)
- assert(ds.schema == new StructType().add("f1", LongType))
- }
+ test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") {
+ val ds = sampledTestData.coalesce(1)
+ val readback = spark.read.option("samplingRatio", 0.1).json(ds)
+
+ assert(readback.schema == new StructType().add("f1", LongType))
}
- test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") {
- val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i =>
- val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
- 57, 62, 68, 72)
- if (predefinedSample.contains(i)) {
- s"""{"f1":${i.toString}}""" + "\n"
- } else {
- s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n"
- }
- }.toDS()
- val ds = spark.read.option("samplingRatio", 0.1).json(dstr)
+ test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") {
+ val ds = spark.range(0, 100, 1, 1).map(_.toString)
+
+ val errorMsg0 = intercept[IllegalArgumentException] {
+ spark.read.option("samplingRatio", -1).json(ds)
+ }.getMessage
+ assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0"))
+
+ val errorMsg1 = intercept[IllegalArgumentException] {
+ spark.read.option("samplingRatio", 0).json(ds)
+ }.getMessage
+ assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0"))
- assert(ds.schema == new StructType().add("f1", LongType))
+ val sampled = spark.read.option("samplingRatio", 1.0).json(ds)
+ assert(sampled.count() == ds.count())
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 13084ba..6e9559e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -233,4 +233,16 @@ private[json] trait TestJsonData {
spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING)
def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING)
+
+ def sampledTestData: Dataset[String] = {
+ spark.range(0, 100, 1).map { index =>
+ val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
+ 57, 62, 68, 72)
+ if (predefinedSample.contains(index)) {
+ s"""{"f1":${index.toString}}"""
+ } else {
+ s"""{"f1":${(index.toDouble + 0.1).toString}}"""
+ }
+ }(Encoders.STRING)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org