GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/5967#discussion_r29956414
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---
@@ -44,45 +53,151 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
 @AlphaComponent
 class LogisticRegression
   extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
-  with LogisticRegressionParams {
+  with LogisticRegressionParams with Logging {

-  /** @group setParam */
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   * @group setParam
+   */
   def setRegParam(value: Double): this.type = set(regParam, value)
+  setDefault(regParam -> 0.0)

-  /** @group setParam */
+  /**
+   * Set the ElasticNet mixing parameter.
+   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+   * Default is 0.0 which is an L2 penalty.
+   * @group setParam
+   */
+  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+  setDefault(elasticNetParam -> 0.0)
+
+  /**
+   * Set the maximal number of iterations.
+   * Default is 100.
+   * @group setParam
+   */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Default is 1E-6.
+   * @group setParam
+   */
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)

   /** @group setParam */
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)

   /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)
+  setDefault(threshold -> 0.5)

   override protected def train(dataset: DataFrame): LogisticRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
-    val oldDataset = extractLabeledPoints(dataset)
+    val instances = extractLabeledPoints(dataset).map {
+      case LabeledPoint(label: Double, features: Vector) => (label, features)
+    }
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) {
-      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val (summarizer, labelSummarizer) = instances.treeAggregate(
+      (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( {
+        case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
+          (label: Double, features: Vector)) =>
+          (summarizer.add(features), labelSummarizer.add(label))
+      }, {
+        case ((summarizer1: MultivariateOnlineSummarizer, labelSummarizer1: MultiClassSummarizer),
+          (summarizer2: MultivariateOnlineSummarizer, labelSummarizer2: MultiClassSummarizer)) =>
+          (summarizer1.merge(summarizer2), labelSummarizer1.merge(labelSummarizer2))
+      })
+
+    val histogram = labelSummarizer.histogram
+    val numInvalid = labelSummarizer.countInvalid
+    val numClasses = histogram.length
+    val numFeatures = summarizer.mean.size
+
+    if (numInvalid != 0) {
+      logError("Classification labels should be in {0 to " + (numClasses - 1) + "}. " +
+        "Found " + numInvalid + " invalid labels.")
+      throw new SparkException("Input validation failed.")
+    }
+
+    if (numClasses > 2) {
+      logError("Currently, LogisticRegression with ElasticNet in ML package only supports " +
+        "binary classification. Found " + numClasses + " in the input dataset.")
+      throw new SparkException("Input validation failed.")
     }
-    // Train model
-    val lr = new LogisticRegressionWithLBFGS()
-      .setIntercept($(fitIntercept))
-    lr.optimizer
-      .setRegParam($(regParam))
-      .setNumIterations($(maxIter))
-    val oldModel = lr.run(oldDataset)
-    val lrm = new LogisticRegressionModel(this, oldModel.weights, oldModel.intercept)
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+    val regParamL1 = $(elasticNetParam) * $(regParam)
+    val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
-    if (handlePersistence) {
-      oldDataset.unpersist()
+
+    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+      featuresStd, featuresMean, regParamL2)
+
+    val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
+      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+    } else {
+      // Remove the L1 penalization on the intercept
+      def regParamL1Fun = (index: Int) => {
+        if (index == numFeatures) 0.0 else regParamL1
+      }
+      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
     }
-    copyValues(lrm)
+
+    val initialWeightsWithIntercept =
+      Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
+
+    // TODO: Compute the initial intercept based on the histogram.
+    if ($(fitIntercept)) initialWeightsWithIntercept.toArray(numFeatures) = 1.0
+
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      initialWeightsWithIntercept.toBreeze.toDenseVector)
+
+    var state = states.next()
+    val lossHistory = mutable.ArrayBuilder.make[Double]
+
+    while (states.hasNext) {
+      lossHistory += state.value
+      state = states.next()
+    }
+    lossHistory += state.value
+
+    // The weights are trained in the scaled space; we're converting them back to
+    // the original space.
+    val weightsWithIntercept = {
+      val rawWeights = state.x.toArray.clone()
+      var i = 0
+      // Note that the intercept in scaled space and original space is the same;
+      // as a result, no scaling is needed.
+      while (i < numFeatures) {
+        rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
--- End diff --
`{ .. }` is not necessary
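For reference, the line in question rescales each trained weight from the standardized feature space back to the original space. Below is a minimal standalone sketch of how the loop could read with the braces dropped in favor of plain parentheses. The `rescale` wrapper is added here only to make the snippet self-contained, and the `i += 1` increment falls just past the quoted excerpt; this is an illustration of the suggestion, not the merged code:

    // Rescales weights trained on standardized features back to the original
    // feature space. A feature with zero standard deviation is constant, so its
    // weight is zeroed out rather than divided by zero. The intercept, if
    // present, is stored past the last feature and is deliberately left
    // untouched: per the patch's own comment, it is the same in both spaces.
    def rescale(rawWeights: Array[Double], featuresStd: Array[Double]): Unit = {
      var i = 0
      while (i < featuresStd.length) {
        rawWeights(i) *= (if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0)
        i += 1
      }
    }

Note that the if-expression still has to form a single operand on the right of `*=`, so some delimiter remains; parentheses (or hoisting the scale factor into a local `val`) would take the place of the braces.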