Repository: spark Updated Branches: refs/heads/master 4b0c1edfd -> 32218307e
[SPARK-4372][MLLIB] Make LR and SVM's default parameters consistent in Scala and Python The current default regParam is 1.0 and regType is claimed to be none in Python (but actually it is l2), while regParam = 0.0 and regType is L2 in Scala. We should make the default values consistent. This PR sets the default regType to L2 and regParam to 0.01. Note that the default regParam value in LIBLINEAR (and hence scikit-learn) is 1.0. However, we use average loss instead of total loss in our formulation. Hence regParam=1.0 is definitely too heavy. In LinearRegression, we set regParam=0.0 and regType=None, because we have separate classes for Lasso and Ridge, both of which use regParam=0.01 as the default. davies atalwalkar Author: Xiangrui Meng <[email protected]> Closes #3232 from mengxr/SPARK-4372 and squashes the following commits: 9979837 [Xiangrui Meng] update Ridge/Lasso to use default regParam 0.01 cast input arguments d3ba096 [Xiangrui Meng] change 'none' back to None 1909a6e [Xiangrui Meng] change default regParam to 0.01 and regType to L2 in LR and SVM Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32218307 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32218307 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32218307 Branch: refs/heads/master Commit: 32218307edc6de2b08d5f7a0db6d566081d27197 Parents: 4b0c1ed Author: Xiangrui Meng <[email protected]> Authored: Thu Nov 13 13:54:16 2014 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Thu Nov 13 13:54:16 2014 -0800 ---------------------------------------------------------------------- .../examples/mllib/BinaryClassification.scala | 2 +- .../spark/examples/mllib/LinearRegression.scala | 2 +- .../spark/mllib/api/python/PythonMLLibAPI.scala | 34 ++++++++++-------- .../classification/LogisticRegression.scala | 12 ++++--- .../apache/spark/mllib/classification/SVM.scala | 10 +++--- 
.../apache/spark/mllib/regression/Lasso.scala | 6 ++-- .../mllib/regression/RidgeRegression.scala | 8 ++--- .../LogisticRegressionSuite.scala | 28 ++++++++------- python/pyspark/mllib/classification.py | 36 +++++++++++--------- python/pyspark/mllib/regression.py | 36 ++++++++++---------- 10 files changed, 95 insertions(+), 79 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1edd243..a113653 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] + regParam: Double = 0.01) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index e1f9622..6815b1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, 
regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] + regParam: Double = 0.01) extends AbstractParams[Params] val defaultParams = Params() http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 70d7138..c8476a5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,22 +28,22 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ -import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.impurity._ +import 
org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -103,9 +103,11 @@ class PythonMLLibAPI extends Serializable { lrAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { lrAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") + } else if (regType == null) { + lrAlg.optimizer.setUpdater(new SimpleUpdater) + } else { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: ['l1', 'l2', None].") } trainRegressionModel( lrAlg, @@ -180,9 +182,11 @@ class PythonMLLibAPI extends Serializable { SVMAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { SVMAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { + } else if (regType == null) { + SVMAlg.optimizer.setUpdater(new SimpleUpdater) + } else { throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") + + " Can only be initialized using the following string values: ['l1', 'l2', None].") } trainRegressionModel( SVMAlg, @@ -213,9 +217,11 @@ class PythonMLLibAPI extends Serializable { LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { + } else if (regType == null) { + LogRegAlg.optimizer.setUpdater(new SimpleUpdater) + } else { throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
- + " Can only be initialized using the following string values: [l1, l2, none].") + + " Can only be initialized using the following string values: ['l1', 'l2', None].") } trainRegressionModel( LogRegAlg, @@ -250,7 +256,7 @@ class PythonMLLibAPI extends Serializable { .setInitializationMode(initializationMode) // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. .disableUncachedWarning() - return kMeansAlg.run(data.rdd) + kMeansAlg.run(data.rdd) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 84d3c7c..18b95f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -71,9 +71,10 @@ class LogisticRegressionModel ( } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * + * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By + * default L2 regularization is used, which can be changed via + * [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1}. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. 
*/ class LogisticRegressionWithSGD private ( @@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a LogisticRegression object with default parameters + * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, + * numIterations: 100, regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 0.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b..ab9515b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -72,7 +72,8 @@ class SVMModel ( } /** - * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 + * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ class SVMWithSGD private ( @@ -92,9 +93,10 @@ class SVMWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a SVM object with default parameters + * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParam: 0.01, miniBatchFraction: 1.0}. 
*/ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) @@ -185,6 +187,6 @@ object SVMWithSGD { * @return a SVMModel which has the weights and offset from training. */ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index cb0d39e..f9791c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -67,9 +67,9 @@ class LassoWithSGD private ( /** * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. 
*/ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LassoModel(weights, intercept) @@ -161,6 +161,6 @@ object LassoWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a826deb..c8cad77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private ( /** * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. 
*/ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new RidgeRegressionModel(weights, intercept) @@ -143,7 +143,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 1.0) + train(input, numIterations, stepSize, regParam, 0.01) } /** @@ -158,6 +158,6 @@ object RidgeRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 6c1c784..4e81299 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val testRDD = sc.parallelize(testData, 2) testRDD.cache() val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(20) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(20) val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val 
validationRDD = sc.parallelize(validationData, 2) @@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) - assert(model.weights(0) ~== model.weights(0) relTol 0.01) - assert(model.intercept ~== model.intercept relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M // Use half as many iterations as the previous test. val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(10) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(10) val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.01) - assert(model.intercept ~== 1.97 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.02) - assert(model.intercept ~== 1.97 relTol 0.02) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/python/pyspark/mllib/classification.py 
---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 5d90ddd..b654813 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -76,7 +76,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False): """ Train a logistic regression model on the given data. @@ -87,16 +87,16 @@ class LogisticRegressionWithSGD(object): :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater - - "none" for no regularizer + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -104,8 +104,9 @@ class LogisticRegressionWithSGD(object): are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), regType, + bool(intercept)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -145,8 +146,8 @@ class SVMModel(LinearModel): class SVMWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False): + def train(cls, data, iterations=100, step=1.0, regParam=0.01, + miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): """ Train a support vector machine on the given data. @@ -154,7 +155,7 @@ class SVMWithSGD(object): :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). @@ -162,11 +163,11 @@ class SVMWithSGD(object): our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -174,8 +175,9 @@ class SVMWithSGD(object): are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i, regType, intercept) + return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i, regType, + bool(intercept)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) http://git-wip-us.apache.org/repos/asf/spark/blob/32218307/python/pyspark/mllib/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 66e25a4..f4f5e61 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -138,7 +138,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False): """ Train a linear regression model on the given data. @@ -149,16 +149,16 @@ class LinearRegressionWithSGD(object): :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.0). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater, - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization (lasso), + - "l2" for using L2 regularization (ridge), + - None for no regularization - (default: "none") + (default: None) @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -166,11 +166,11 @@ class LinearRegressionWithSGD(object): are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), + regType, bool(intercept)) - return _regression_train_wrapper(train, LinearRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -209,12 +209,13 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -254,15 +255,14 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) - return _regression_train_wrapper(train, RidgeRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, data, 
initialWeights) def _test(): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
