You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/11/17 13:37:50 UTC

spark git commit: [SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings

Repository: spark
Updated Branches:
  refs/heads/master 49b6f456a -> de77c6775


[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings

## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other ad hoc approaches to parse Spark
version strings. These should use VersionUtils instead. This PR replaces the custom
regexes in KMeansModel and PCAModel with VersionUtils.majorVersion to obtain the Spark
major version.

## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vi...@intel.com>

Closes #15055 from VinceShieh/SPARK-17462.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/de77c677
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/de77c677
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/de77c677

Branch: refs/heads/master
Commit: de77c67750dc868d75d6af173c3820b75a9fe4b7
Parents: 49b6f45
Author: VinceShieh <vi...@intel.com>
Authored: Thu Nov 17 13:37:42 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Nov 17 13:37:42 2016 +0000

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++----
 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala     | 6 ++----
 2 files changed, 4 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/de77c677/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b..26505b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/de77c677/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 444006f..1e49352 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org