Repository: spark
Updated Branches:
  refs/heads/branch-1.5 fe05142f5 -> 49085b56c

[SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param

Added ml-guide Python Example: Estimator, Transformer, and Param
/docs/_site/ml-guide.html

Author: Rosstin <[email protected]>

Closes #8081 from Rosstin/SPARK-8965.

(cherry picked from commit 7a539ef3b1792764f866fa88c84c78ad59903f21)
Signed-off-by: Joseph K. Bradley <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/49085b56
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/49085b56
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/49085b56

Branch: refs/heads/branch-1.5
Commit: 49085b56c10a2d05345b343277ddf19b502aee9c
Parents: fe05142
Author: Rosstin <[email protected]>
Authored: Thu Aug 13 09:18:39 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Thu Aug 13 09:18:50 2015 -0700

----------------------------------------------------------------------
 docs/ml-guide.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/49085b56/docs/ml-guide.md
----------------------------------------------------------------------
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index b6ca50e..a03ab43 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -355,6 +355,74 @@ jsc.stop();
 {% endhighlight %}
 </div>
 
+<div data-lang="python">
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.param import Param, Params
+from pyspark.sql import Row, SQLContext
+
+sc = SparkContext(appName="SimpleParamsExample")
+sqlContext = SQLContext(sc)
+
+# Prepare training data.
+# We use LabeledPoint.
+# Spark SQL can convert RDDs of LabeledPoints into DataFrames.
+training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]),
+                           LabeledPoint(0.0, [2.0, 1.0, -1.0]),
+                           LabeledPoint(0.0, [2.0, 1.3, 1.0]),
+                           LabeledPoint(1.0, [0.0, 1.2, -0.5])])
+
+# Create a LogisticRegression instance. This instance is an Estimator.
+lr = LogisticRegression(maxIter=10, regParam=0.01)
+# Print out the parameters, documentation, and any default values.
+print "LogisticRegression parameters:\n" + lr.explainParams() + "\n"
+
+# Learn a LogisticRegression model. This uses the parameters stored in lr.
+model1 = lr.fit(training.toDF())
+
+# Since model1 is a Model (i.e., a transformer produced by an Estimator),
+# we can view the parameters it used during fit().
+# This prints the parameter (name: value) pairs, where names are unique IDs for this
+# LogisticRegression instance.
+print "Model 1 was fit using parameters: "
+print model1.extractParamMap()
+
+# We may alternatively specify parameters using a Python dictionary as a paramMap
+paramMap = {lr.maxIter: 20}
+paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter.
+paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params.
+
+# You can combine paramMaps, which are python dictionaries.
+paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name
+paramMapCombined = paramMap.copy()
+paramMapCombined.update(paramMap2)
+
+# Now learn a new model using the paramMapCombined parameters.
+# paramMapCombined overrides all parameters set earlier via lr.set* methods.
+model2 = lr.fit(training.toDF(), paramMapCombined)
+print "Model 2 was fit using parameters: "
+print model2.extractParamMap()
+
+# Prepare test data
+test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]),
+                       LabeledPoint(0.0, [ 3.0, 2.0, -0.1]),
+                       LabeledPoint(1.0, [ 0.0, 2.2, -1.5])])
+
+# Make predictions on test data using the Transformer.transform() method.
+# LogisticRegression.transform will only use the 'features' column.
+# Note that model2.transform() outputs a "myProbability" column instead of the usual
+# 'probability' column since we renamed the lr.probabilityCol parameter previously.
+prediction = model2.transform(test.toDF())
+selected = prediction.select("features", "label", "myProbability", "prediction")
+for row in selected.collect():
+    print row
+
+sc.stop()
+{% endhighlight %}
+</div>
+
 </div>
 
 ## Example: Pipeline
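
Note on the paramMap handling in the added example: the param maps are ordinary Python dictionaries, so the overwrite and combine steps follow plain `dict` semantics. The sketch below is not part of the commit; it only illustrates those semantics, and it assumes string keys as stand-ins for the `Param` objects (`lr.maxIter`, `lr.regParam`, and so on) so that it runs without Spark.

```python
# Stand-in keys (assumption): the real example uses Param objects such as
# lr.maxIter as dictionary keys; plain strings are used here instead.
param_map = {"maxIter": 20}
param_map["maxIter"] = 30                                # overwrite a single entry
param_map.update({"regParam": 0.1, "threshold": 0.55})   # add several entries at once

param_map2 = {"probabilityCol": "myProbability"}

# copy() returns a new (shallow) dictionary, so the update() below
# changes only the combined map and leaves param_map as it was.
param_map_combined = param_map.copy()
param_map_combined.update(param_map2)

print(sorted(param_map.items()))
print(sorted(param_map_combined.items()))
```

Copying before updating keeps the first map reusable for a later `fit()` call with different overrides, which appears to be why the example builds `paramMapCombined` rather than updating `paramMap` in place.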
