This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 9d770d3 [SPARK-30929][ML] ML, GraphX 3.0 QA: API: New Scala APIs, docs
9d770d3 is described below
commit 9d770d38fdf4a733e7bd12baf04530adafa2be4a
Author: Huaxin Gao <[email protected]>
AuthorDate: Mon Mar 9 09:11:21 2020 -0500
[SPARK-30929][ML] ML, GraphX 3.0 QA: API: New Scala APIs, docs
### What changes were proposed in this pull request?
Audited the new ML Scala APIs introduced in 3.0 and fixed the issues that were found.
### Why are the changes needed?
### Does this PR introduce any user-facing change?
Yes. Some documentation changes (added missing `@Since` annotations and parameter docs).
### How was this patch tested?
Existing tests
Closes #27818 from huaxingao/spark-30929.
Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
(cherry picked from commit b6b0343e3e90f0421cac277ed5ae8fb15b278d4e)
Signed-off-by: Sean Owen <[email protected]>
---
.../spark/ml/classification/FMClassifier.scala | 2 +-
.../MultilabelClassificationEvaluator.scala | 20 ++++++++++++++++++--
.../spark/ml/evaluation/RankingEvaluator.scala | 16 ++++++++++++++--
.../org/apache/spark/ml/feature/RobustScaler.scala | 10 +++++-----
.../spark/ml/regression/AFTSurvivalRegression.scala | 4 ++--
.../org/apache/spark/ml/regression/FMRegressor.scala | 2 +-
.../scala/org/apache/spark/ml/tree/treeParams.scala | 3 +++
7 files changed, 44 insertions(+), 13 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index 3057d51..a4d2427 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -187,7 +187,7 @@ class FMClassifier @Since("3.0.0") (
@Since("3.0.0")
def setSeed(value: Long): this.type = set(seed, value)
- override protected[spark] def train(
+ override protected def train(
dataset: Dataset[_]
): FMClassificationModel = instrumented { instr =>
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala
index 5216c40..a8db545 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala
@@ -34,12 +34,13 @@ import org.apache.spark.sql.types._
*/
@Since("3.0.0")
@Experimental
-class MultilabelClassificationEvaluator (override val uid: String)
+class MultilabelClassificationEvaluator @Since("3.0.0") (@Since("3.0.0")
override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol
with DefaultParamsWritable {
import MultilabelClassificationEvaluator.supportedMetricNames
+ @Since("3.0.0")
def this() = this(Identifiable.randomUID("mlcEval"))
/**
@@ -49,6 +50,7 @@ class MultilabelClassificationEvaluator (override val uid:
String)
* `"microF1Measure"`)
* @group param
*/
+ @Since("3.0.0")
final val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(supportedMetricNames)
new Param(this, "metricName", "metric name in evaluation " +
@@ -56,13 +58,21 @@ class MultilabelClassificationEvaluator (override val uid:
String)
}
/** @group getParam */
+ @Since("3.0.0")
def getMetricName: String = $(metricName)
/** @group setParam */
+ @Since("3.0.0")
def setMetricName(value: String): this.type = set(metricName, value)
setDefault(metricName -> "f1Measure")
+ /**
+ * param for the class whose metric will be computed in
`"precisionByLabel"`, `"recallByLabel"`,
+ * `"f1MeasureByLabel"`.
+ * @group param
+ */
+ @Since("3.0.0")
final val metricLabel: DoubleParam = new DoubleParam(this, "metricLabel",
"The class whose metric will be computed in " +
s"${supportedMetricNames.filter(_.endsWith("ByLabel")).mkString("(",
"|", ")")}. " +
@@ -70,6 +80,7 @@ class MultilabelClassificationEvaluator (override val uid:
String)
ParamValidators.gtEq(0.0))
/** @group getParam */
+ @Since("3.0.0")
def getMetricLabel: Double = $(metricLabel)
/** @group setParam */
@@ -78,12 +89,14 @@ class MultilabelClassificationEvaluator (override val uid:
String)
setDefault(metricLabel -> 0.0)
/** @group setParam */
+ @Since("3.0.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/** @group setParam */
+ @Since("3.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
-
+ @Since("3.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(predictionCol),
@@ -113,6 +126,7 @@ class MultilabelClassificationEvaluator (override val uid:
String)
}
}
+ @Since("3.0.0")
override def isLargerBetter: Boolean = {
$(metricName) match {
case "hammingLoss" => false
@@ -120,6 +134,7 @@ class MultilabelClassificationEvaluator (override val uid:
String)
}
}
+ @Since("3.0.0")
override def copy(extra: ParamMap): MultilabelClassificationEvaluator =
defaultCopy(extra)
@Since("3.0.0")
@@ -139,5 +154,6 @@ object MultilabelClassificationEvaluator
"precisionByLabel", "recallByLabel", "f1MeasureByLabel",
"microPrecision", "microRecall", "microF1Measure")
+ @Since("3.0.0")
override def load(path: String): MultilabelClassificationEvaluator =
super.load(path)
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
index 8d017eb..c5dea6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -33,11 +33,12 @@ import org.apache.spark.sql.types._
*/
@Experimental
@Since("3.0.0")
-class RankingEvaluator (override val uid: String)
+class RankingEvaluator @Since("3.0.0") (@Since("3.0.0") override val uid:
String)
extends Evaluator with HasPredictionCol with HasLabelCol with
DefaultParamsWritable {
import RankingEvaluator.supportedMetricNames
+ @Since("3.0.0")
def this() = this(Identifiable.randomUID("rankEval"))
/**
@@ -45,6 +46,7 @@ class RankingEvaluator (override val uid: String)
* `"meanAveragePrecisionAtK"`, `"precisionAtK"`, `"ndcgAtK"`, `"recallAtK"`)
* @group param
*/
+ @Since("3.0.0")
final val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(supportedMetricNames)
new Param(this, "metricName", "metric name in evaluation " +
@@ -52,9 +54,11 @@ class RankingEvaluator (override val uid: String)
}
/** @group getParam */
+ @Since("3.0.0")
def getMetricName: String = $(metricName)
/** @group setParam */
+ @Since("3.0.0")
def setMetricName(value: String): this.type = set(metricName, value)
setDefault(metricName -> "meanAveragePrecision")
@@ -64,6 +68,7 @@ class RankingEvaluator (override val uid: String)
* `"ndcgAtK"`, `"recallAtK"`. Must be > 0. The default value is 10.
* @group param
*/
+ @Since("3.0.0")
final val k = new IntParam(this, "k",
"The ranking position value used in " +
s"${supportedMetricNames.filter(_.endsWith("AtK")).mkString("(", "|",
")")} " +
@@ -71,20 +76,24 @@ class RankingEvaluator (override val uid: String)
ParamValidators.gt(0))
/** @group getParam */
+ @Since("3.0.0")
def getK: Int = $(k)
/** @group setParam */
+ @Since("3.0.0")
def setK(value: Int): this.type = set(k, value)
setDefault(k -> 10)
/** @group setParam */
+ @Since("3.0.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/** @group setParam */
+ @Since("3.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
-
+ @Since("3.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(predictionCol),
@@ -107,8 +116,10 @@ class RankingEvaluator (override val uid: String)
}
}
+ @Since("3.0.0")
override def isLargerBetter: Boolean = true
+ @Since("3.0.0")
override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
@Since("3.0.0")
@@ -124,5 +135,6 @@ object RankingEvaluator extends
DefaultParamsReadable[RankingEvaluator] {
private val supportedMetricNames = Array("meanAveragePrecision",
"meanAveragePrecisionAtK", "precisionAtK", "ndcgAtK", "recallAtK")
+ @Since("3.0.0")
override def load(path: String): RankingEvaluator = super.load(path)
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 2a1204a..bd9be77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -120,7 +120,7 @@ private[feature] trait RobustScalerParams extends Params
with HasInputCol with H
* Note that NaN values are ignored in the computation of medians and ranges.
*/
@Since("3.0.0")
-class RobustScaler (override val uid: String)
+class RobustScaler @Since("3.0.0") (@Since("3.0.0") override val uid: String)
extends Estimator[RobustScalerModel] with RobustScalerParams with
DefaultParamsWritable {
import RobustScaler._
@@ -186,7 +186,7 @@ class RobustScaler (override val uid: String)
object RobustScaler extends DefaultParamsReadable[RobustScaler] {
// compute QuantileSummaries for each feature
- private[spark] def computeSummaries(
+ private[ml] def computeSummaries(
vectors: RDD[Vector],
numFeatures: Int,
relativeError: Double): RDD[(Int, QuantileSummaries)] = {
@@ -229,9 +229,9 @@ object RobustScaler extends
DefaultParamsReadable[RobustScaler] {
*/
@Since("3.0.0")
class RobustScalerModel private[ml] (
- override val uid: String,
- val range: Vector,
- val median: Vector)
+ @Since("3.0.0") override val uid: String,
+ @Since("3.0.0") val range: Vector,
+ @Since("3.0.0") val median: Vector)
extends Model[RobustScalerModel] with RobustScalerParams with MLWritable {
import RobustScalerModel._
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 8c95d25..2da65a0 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -194,8 +194,8 @@ class AFTSurvivalRegression @Since("1.6.0")
(@Since("1.6.0") override val uid: S
}
}
- @Since("3.0.0")
- override def train(dataset: Dataset[_]): AFTSurvivalRegressionModel =
instrumented { instr =>
+ override protected def train(
+ dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr
=>
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index a612448..b017a1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -410,7 +410,7 @@ class FMRegressor @Since("3.0.0") (
@Since("3.0.0")
def setSeed(value: Long): this.type = set(seed, value)
- override protected[spark] def train(
+ override protected def train(
dataset: Dataset[_]
): FMRegressionModel = instrumented { instr =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 26a639e..a273cd7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -47,6 +47,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
* (default = "")
* @group param
*/
+ @Since("3.0.0")
final val leafCol: Param[String] =
new Param[String](this, "leafCol", "Leaf indices column name. " +
"Predicted leaf index of each instance in each tree by preorder")
@@ -139,9 +140,11 @@ private[ml] trait DecisionTreeParams extends
PredictorParams
cacheNodeIds -> false, checkpointInterval -> 10)
/** @group setParam */
+ @Since("3.0.0")
final def setLeafCol(value: String): this.type = set(leafCol, value)
/** @group getParam */
+ @Since("3.0.0")
final def getLeafCol: String = $(leafCol)
/** @group getParam */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]