Posted to commits@spark.apache.org by li...@apache.org on 2016/04/11 11:14:29 UTC

spark git commit: [SPARK-14372][SQL] Dataset.randomSplit() needs a Java version

Repository: spark
Updated Branches:
  refs/heads/master 1a0cca1fc -> e82d95bf6


[SPARK-14372][SQL] Dataset.randomSplit() needs a Java version

## What changes were proposed in this pull request?

1. Added a randomSplitAsList() method to Dataset so the random-split functionality is callable from Java,
   per https://issues.apache.org/jira/browse/SPARK-14372 (see the usage sketch below).
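
A minimal usage sketch of the new Java API (illustrative only, not part of the patch). It mirrors the JavaDatasetSuite test further down and assumes an existing SQLContext named `context`:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;

// Build a small Dataset<String>; "context" is assumed to be an existing SQLContext.
List<String> data = Arrays.asList("hello", "world", "from", "spark");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());

// Split into three parts; weights are normalized if they don't sum to 1,
// and the seed makes the split reproducible.
List<Dataset<String>> splits = ds.randomSplitAsList(new double[] {0.6, 0.2, 0.2}, 42L);
Dataset<String> first = splits.get(0);
```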

## How was this patch tested?

Unit test added to JavaDatasetSuite (testRandomSplit).

Author: Rekha Joshi <re...@gmail.com>
Author: Joshi <re...@gmail.com>

Closes #12184 from rekhajoshm/SPARK-14372.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e82d95bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e82d95bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e82d95bf

Branch: refs/heads/master
Commit: e82d95bf63f57cefa02dc545ceb451ecdeedce28
Parents: 1a0cca1
Author: Rekha Joshi <re...@gmail.com>
Authored: Mon Apr 11 17:13:30 2016 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Mon Apr 11 17:13:30 2016 +0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 17 ++++++++++++++++-
 .../org/apache/spark/sql/JavaDatasetSuite.java     | 10 ++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e82d95bf/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2f6d8d1..e216945 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import java.io.CharArrayWriter
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -1493,6 +1492,8 @@ class Dataset[T] private[sql](
    * @param weights weights for splits, will be normalized if they don't sum to 1.
    * @param seed Seed for sampling.
    *
+   * For Java API, use [[randomSplitAsList]].
+   *
    * @group typedrel
    * @since 2.0.0
    */
@@ -1511,6 +1512,20 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Returns a Java list that contains randomly split [[Dataset]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+    val values = randomSplit(weights, seed)
+    java.util.Arrays.asList(values : _*)
+  }
+
+  /**
    * Randomly splits this [[Dataset]] with the provided weights.
    *
    * @param weights weights for splits, will be normalized if they don't sum to 1.
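
An editorial note on the implementation above (not part of the commit message): randomSplitAsList delegates to the existing randomSplit and wraps the resulting array with java.util.Arrays.asList, so Java callers receive a fixed-size List backed by that array. A hedged sketch of consuming it, assuming an existing Dataset<String> named `ds`:

```java
// Sketch only; "ds" is assumed to be an existing Dataset<String>.
List<Dataset<String>> splits = ds.randomSplitAsList(new double[] {0.8, 0.2}, 11L);
Dataset<String> train = splits.get(0);
Dataset<String> test = splits.get(1);
// Arrays.asList returns a fixed-size list: get() and iteration work as usual,
// but add()/remove() would throw UnsupportedOperationException.
```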

http://git-wip-us.apache.org/repos/asf/spark/blob/e82d95bf/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f26c57b..5abd62c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(data, ds.collectAsList());
   }
 
+  @Test
+  public void testRandomSplit() {
+    List<String> data = Arrays.asList("hello", "world", "from", "spark");
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    double[] arraySplit = {1, 2, 3};
+
+    List<Dataset<String>> randomSplit =  ds.randomSplitAsList(arraySplit, 1);
+    Assert.assertEquals("wrong number of splits", randomSplit.size(), 3);
+  }
+
   /**
    * For testing error messages when creating an encoder on a private class. This is done
    * here since we cannot create truly private classes in Scala.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org