Repository: spark
Updated Branches:
  refs/heads/branch-2.1 4fcecb4cf -> 42777b1b3


[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings

## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other approaches to parse Spark
versions. Those should be fixed to use VersionUtils. This PR replaces the custom
regexes with VersionUtils to get Spark version numbers.

## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh vincent.xieintel.com

Author: VinceShieh <[email protected]>

Closes #15055 from VinceShieh/SPARK-17462.

(cherry picked from commit de77c67750dc868d75d6af173c3820b75a9fe4b7)
Signed-off-by: Sean Owen <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42777b1b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42777b1b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42777b1b

Branch: refs/heads/branch-2.1
Commit: 42777b1b3c10d3945494e27f1dedd43f2f836361
Parents: 4fcecb4
Author: VinceShieh <[email protected]>
Authored: Thu Nov 17 13:37:42 2016 +0000
Committer: Sean Owen <[email protected]>
Committed: Thu Nov 17 13:37:53 2016 +0000

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++----
 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala     | 6 ++----
 2 files changed, 4 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42777b1b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b..26505b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/42777b1b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 444006f..1e49352 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to