You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2018/04/26 01:14:36 UTC

spark git commit: [SPARK-23849][SQL] Tests for samplingRatio of json datasource

Repository: spark
Updated Branches:
  refs/heads/master 95a651339 -> 3f1e999d3


[SPARK-23849][SQL] Tests for samplingRatio of json datasource

## What changes were proposed in this pull request?

Added the `samplingRatio` option to the `json()` method of PySpark's `DataFrameReader` and documented it in the Scala `DataFrameReader` Scaladoc. Improved the existing Scala API tests according to the review of PR https://github.com/apache/spark/pull/20959.

## How was this patch tested?

Added a new test for PySpark, updated two existing tests according to the review of https://github.com/apache/spark/pull/20959, and added a new negative test.

Author: Maxim Gekk <ma...@databricks.com>

Closes #21056 from MaxGekk/json-sampling.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f1e999d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f1e999d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f1e999d

Branch: refs/heads/master
Commit: 3f1e999d3d215bb3b867bcd83ec5c799448ec730
Parents: 95a6513
Author: Maxim Gekk <ma...@databricks.com>
Authored: Thu Apr 26 09:14:24 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Thu Apr 26 09:14:24 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py                |  7 ++-
 python/pyspark/sql/tests.py                     |  8 +++
 .../org/apache/spark/sql/DataFrameReader.scala  |  2 +
 .../execution/datasources/json/JsonSuite.scala  | 63 +++++++++++---------
 .../datasources/json/TestJsonData.scala         | 12 ++++
 5 files changed, 61 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 6bd79bc..df176c5 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ class DataFrameReader(OptionUtils):
              allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
              allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
              mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None, allowUnquotedControlChars=None, lineSep=None):
+             multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None):
         """
         Loads JSON files and returns the results as a :class:`DataFrame`.
 
@@ -239,6 +239,8 @@ class DataFrameReader(OptionUtils):
                                           including tab and line feed characters) or not.
         :param lineSep: defines the line separator that should be used for parsing. If None is
                         set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+        :param samplingRatio: defines fraction of input JSON objects used for schema inferring.
+                              If None is set, it uses the default value, ``1.0``.
 
         >>> df1 = spark.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
@@ -256,7 +258,8 @@ class DataFrameReader(OptionUtils):
             allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
             mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
             timestampFormat=timestampFormat, multiLine=multiLine,
-            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
+            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
+            samplingRatio=samplingRatio)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:

http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4e99c8e..98fa1b5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3018,6 +3018,14 @@ class SQLTests(ReusedSQLTestCase):
             df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
             [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
 
+    def test_json_sampling_ratio(self):
+        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+            .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
+        schema = self.spark.read.option('inferSchema', True) \
+            .option('samplingRatio', 0.5) \
+            .json(rdd).schema
+        self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d640fdc..b44552f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * per file</li>
    * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
    * that should be used for parsing.</li>
+   * <li>`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
+   * for schema inferring.</li>
    * </ul>
    *
    * @since 2.0.0

http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 70aee56..a58dff8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     }
   }
 
-  test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") {
-    val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
-      57, 62, 68, 72)
-    withTempPath { path =>
-      val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath),
-        StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW)
-      for (i <- 0 until 100) {
-        if (predefinedSample.contains(i)) {
-          writer.write(s"""{"f1":${i.toString}}""" + "\n")
-        } else {
-          writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n")
-        }
-      }
-      writer.close()
+  test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") {
+    // Set default values for the DataSource parameters to make sure
+    // that whole test file is mapped to only one partition. This will guarantee
+    // reliable sampling of the input file.
+    withSQLConf(
+      "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString,
+      "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString
+    )(withTempPath { path =>
+      val ds = sampledTestData.coalesce(1)
+      ds.write.text(path.getAbsolutePath)
+      val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath)
+
+      assert(readback.schema == new StructType().add("f1", LongType))
+    })
+  }
 
-      val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath)
-      assert(ds.schema == new StructType().add("f1", LongType))
-    }
+  test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") {
+    val ds = sampledTestData.coalesce(1)
+    val readback = spark.read.option("samplingRatio", 0.1).json(ds)
+
+    assert(readback.schema == new StructType().add("f1", LongType))
   }
 
-  test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") {
-    val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i =>
-      val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
-        57, 62, 68, 72)
-      if (predefinedSample.contains(i)) {
-        s"""{"f1":${i.toString}}""" + "\n"
-      } else {
-        s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n"
-      }
-    }.toDS()
-    val ds = spark.read.option("samplingRatio", 0.1).json(dstr)
+  test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") {
+    val ds = spark.range(0, 100, 1, 1).map(_.toString)
+
+    val errorMsg0 = intercept[IllegalArgumentException] {
+      spark.read.option("samplingRatio", -1).json(ds)
+    }.getMessage
+    assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0"))
+
+    val errorMsg1 = intercept[IllegalArgumentException] {
+      spark.read.option("samplingRatio", 0).json(ds)
+    }.getMessage
+    assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0"))
 
-    assert(ds.schema == new StructType().add("f1", LongType))
+    val sampled = spark.read.option("samplingRatio", 1.0).json(ds)
+    assert(sampled.count() == ds.count())
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3f1e999d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 13084ba..6e9559e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -233,4 +233,16 @@ private[json] trait TestJsonData {
     spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING)
 
   def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING)
+
+  def sampledTestData: Dataset[String] = {
+    spark.range(0, 100, 1).map { index =>
+      val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
+        57, 62, 68, 72)
+      if (predefinedSample.contains(index)) {
+        s"""{"f1":${index.toString}}"""
+      } else {
+        s"""{"f1":${(index.toDouble + 0.1).toString}}"""
+      }
+    }(Encoders.STRING)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org