You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/06/22 22:51:30 UTC

spark git commit: [SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode

Repository: spark
Updated Branches:
  refs/heads/master da7bbb943 -> 5ab9fcfb0


[SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode

https://issues.apache.org/jira/browse/SPARK-8532

This PR makes two changes. First, it fixes a bug where the save actions (i.e. `save/saveAsTable/json/parquet/jdbc`) always overrode any mode previously set through `mode()`. Second, it adds a `partitionBy` argument to `save/saveAsTable/parquet`.

Author: Yin Huai <yh...@databricks.com>

Closes #6937 from yhuai/SPARK-8532 and squashes the following commits:

f972d5d [Yin Huai] davies's comment.
d37abd2 [Yin Huai] style.
d21290a [Yin Huai] Python doc.
889eb25 [Yin Huai] Minor refactoring and add partitionBy to save, saveAsTable, and parquet.
7fbc24b [Yin Huai] Use None instead of "error" as the default value of mode since JVM-side already uses "error" as the default value.
d696dff [Yin Huai] Python style.
88eb6c4 [Yin Huai] If mode is "error", do not call mode method.
c40c461 [Yin Huai] Regression test.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ab9fcfb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ab9fcfb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ab9fcfb

Branch: refs/heads/master
Commit: 5ab9fcfb01a0ad2f6c103f67c1a785d3b49e33f0
Parents: da7bbb9
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Jun 22 13:51:23 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Jun 22 13:51:23 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py | 30 +++++++++++++++++++-----------
 python/pyspark/sql/tests.py      | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ab9fcfb/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f036644..1b7bc0f 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -218,7 +218,10 @@ class DataFrameWriter(object):
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite = self._jwrite.mode(saveMode)
+        # At the JVM side, the default value of mode is already set to "error".
+        # So, if the given saveMode is None, we will not call JVM-side's mode method.
+        if saveMode is not None:
+            self._jwrite = self._jwrite.mode(saveMode)
         return self
 
     @since(1.4)
@@ -253,11 +256,12 @@ class DataFrameWriter(object):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        if len(cols) > 0:
+            self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
         return self
 
     @since(1.4)
-    def save(self, path=None, format=None, mode="error", **options):
+    def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
@@ -272,11 +276,12 @@ class DataFrameWriter(object):
             * ``overwrite``: Overwrite existing data.
             * ``ignore``: Silently ignore this operation if data already exists.
             * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         if path is None:
@@ -296,7 +301,7 @@ class DataFrameWriter(object):
         self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
 
     @since(1.4)
-    def saveAsTable(self, name, format=None, mode="error", **options):
+    def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
         """Saves the content of the :class:`DataFrame` as the specified table.
 
         In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ class DataFrameWriter(object):
         :param name: the table name
         :param format: the format used to save
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         self._jwrite.saveAsTable(name)
 
     @since(1.4)
-    def json(self, path, mode="error"):
+    def json(self, path, mode=None):
         """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -333,10 +339,10 @@ class DataFrameWriter(object):
 
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).json(path)
+        self.mode(mode)._jwrite.json(path)
 
     @since(1.4)
-    def parquet(self, path, mode="error"):
+    def parquet(self, path, mode=None, partitionBy=()):
         """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -346,13 +352,15 @@ class DataFrameWriter(object):
             * ``overwrite``: Overwrite existing data.
             * ``ignore``: Silently ignore this operation if data already exists.
             * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
 
         >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).parquet(path)
+        self.partitionBy(partitionBy).mode(mode)
+        self._jwrite.parquet(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode="error", properties={}):
+    def jdbc(self, url, table, mode=None, properties={}):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\

http://git-wip-us.apache.org/repos/asf/spark/blob/5ab9fcfb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5fbb7d..13f4556 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -539,6 +539,38 @@ class SQLTests(ReusedPySparkTestCase):
 
         shutil.rmtree(tmpPath)
 
+    def test_save_and_load_builder(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+                .format("json").save(path=tmpPath)
+        actual =\
+            self.sqlCtx.read.format("json")\
+                            .load(path=tmpPath, noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org