zentol closed pull request #6425: [FLINK-9664][Doc] fixing documentation in ML quick start
URL: https://github.com/apache/flink/pull/6425
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

diff --git a/docs/dev/libs/ml/quickstart.md b/docs/dev/libs/ml/quickstart.md
index ea6f8049755..e056b28b505 100644
--- a/docs/dev/libs/ml/quickstart.md
+++ b/docs/dev/libs/ml/quickstart.md
@@ -129,15 +129,14 @@ and the [test set here](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/b
 This is an astroparticle binary classification dataset, used by Hsu et al. [[3]](#hsu) in their
 practical Support Vector Machine (SVM) guide. It contains 4 numerical features, and the class label.
 
-We can simply import the dataset then using:
+We can simply import the dataset using:
 
 {% highlight scala %}
 
 import org.apache.flink.ml.MLUtils
 
-val astroTrain: DataSet[LabeledVector] = MLUtils.readLibSVM(env, "/path/to/svmguide1")
-val astroTest: DataSet[(Vector, Double)] = MLUtils.readLibSVM(env, "/path/to/svmguide1.t")
-      .map(x => (x.vector, x.label))
+val astroTrainLibSVM: DataSet[LabeledVector] = MLUtils.readLibSVM(env, "/path/to/svmguide1")
+val astroTestLibSVM: DataSet[LabeledVector] = MLUtils.readLibSVM(env, "/path/to/svmguide1.t")
 
 {% endhighlight %}
 
@@ -146,7 +145,23 @@ create a classifier.
 
 ## Classification
 
-Once we have imported the dataset we can train a `Predictor` such as a linear SVM classifier.
+After importing the training and test dataset, they need to be prepared for the classification.
+Since Flink SVM only supports threshold binary values of `+1.0` and `-1.0`, a conversion is
+needed after loading the LibSVM dataset because it is labelled using `1`s and `0`s.
+
+A conversion can be done using a simple normalizer mapping function:
+
+{% highlight scala %}
+
+def normalizer : LabeledVector => LabeledVector = {
+  lv => LabeledVector(if (lv.label > 0.0) 1.0 else -1.0, lv.vector)
+}
+val astroTrain: DataSet[LabeledVector] = astroTrainLibSVM.map(normalizer)
+val astroTest: DataSet[(Vector, Double)] = astroTestLibSVM.map(normalizer).map(x => (x.vector, x.label))
+
+{% endhighlight %}
+
+Once we have converted the dataset we can train a `Predictor` such as a linear SVM classifier.
 We can set a number of parameters for the classifier. Here we set the `Blocks` parameter,
 which is used to split the input by the underlying CoCoA algorithm [[2]](#jaggi) uses.
 The regularization parameter determines the amount of $l_2$ regularization applied, which is used
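The label rule used by the `normalizer` in the diff can be checked without a Flink runtime. The sketch below is a hypothetical, dependency-free stand-in. It applies the same `if (label > 0.0) 1.0 else -1.0` rule to plain `Double` labels instead of FlinkML `LabeledVector`s. The names `LabelNormalizerSketch` and `normalizeLabel` are illustrative only and are not part of FlinkML.

```scala
// Hypothetical stand-alone sketch: the same 0/1 -> -1.0/+1.0 label rule as the
// `normalizer` in the diff, applied to plain Double labels (no Flink needed).
object LabelNormalizerSketch {
  // Positive labels map to +1.0; everything else (e.g. LibSVM's 0) maps to -1.0.
  def normalizeLabel(label: Double): Double =
    if (label > 0.0) 1.0 else -1.0

  def main(args: Array[String]): Unit = {
    // Labels as they would appear in a 0/1-labelled LibSVM file.
    val libSvmLabels = Seq(1.0, 0.0, 1.0, 0.0)
    val svmLabels = libSvmLabels.map(normalizeLabel)
    println(svmLabels.mkString(", ")) // prints "1.0, -1.0, 1.0, -1.0"
  }
}
```

Because the rule only tests `label > 0.0`, it is idempotent: data that is already labelled `+1.0`/`-1.0` passes through unchanged.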
