This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit d36fe8feb043a010e45f75dff9d7d7f21aa37fc7 Author: zhangzp <[email protected]> AuthorDate: Fri Dec 17 17:36:41 2021 +0800 [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression This closes #28. --- .../ml/common/feature/LabeledPointWithWeight.java | 32 +++++++++++++++-- .../logisticregression/LogisticGradient.java | 17 ++++----- .../logisticregression/LogisticRegression.java | 8 ++--- .../LogisticRegressionModelData.java | 7 ++-- .../ml/classification/naivebayes/NaiveBayes.java | 4 ++- .../naivebayes/NaiveBayesModelData.java | 19 +++++----- .../ml/clustering/kmeans/KMeansModelData.java | 7 ++-- .../logisticregression/LogisticRegressionTest.java | 9 +++-- .../org/apache/flink/ml/clustering/KMeansTest.java | 41 ++++++++++------------ 9 files changed, 87 insertions(+), 57 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java index e02192f..8440bc9 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java @@ -23,15 +23,41 @@ import org.apache.flink.ml.linalg.DenseVector; /** Utility class to represent a data point that contains features, label and weight. */ public class LabeledPointWithWeight { - public final DenseVector features; + private DenseVector features; - public final double label; + private double label; - public final double weight; + private double weight; public LabeledPointWithWeight(DenseVector features, double label, double weight) { this.features = features; this.label = label; this.weight = weight; } + + public LabeledPointWithWeight() {} + + public DenseVector getFeatures() { + return features; + } + + public void setFeatures(DenseVector features) { + this.features = features; + } + + public double getLabel() { + return label; + } + + public void setLabel(double label) { + this.label = label; + } + + public double getWeight() { + return weight; + } + + public void setWeight(double weight) { + this.weight = weight; + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java index c63b72e..13f753b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java @@ -52,8 +52,8 @@ public class LogisticGradient implements Serializable { double weightSum = 0.0; double lossSum = 0.0; for (LabeledPointWithWeight dataPoint : dataPoints) { - lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient); - weightSum += dataPoint.weight; + lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient); + weightSum += dataPoint.getWeight(); } if (Double.compare(0, l2) != 0) { lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2); @@ -81,16 +81,17 @@ public class LogisticGradient implements Serializable { } private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - double dot = BLAS.dot(dataPoint.features, coefficient); - double labelScaled = 2 * dataPoint.label - 1; + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; return Math.log(1 + Math.exp(-dot * labelScaled)); } private void computeGradient( LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - double dot = BLAS.dot(dataPoint.features, coefficient); - double labelScaled = 2 * dataPoint.label - 1; - double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); - BLAS.axpy(multiplier, dataPoint.features, cumGradient); + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + double multiplier = + dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); + BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 7266610..a17269b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -114,7 +114,7 @@ public class LogisticRegression dataPoint -> { Double weight = getWeightCol() == null - ? new Double(1.0) + ? 1.0 : (Double) dataPoint.getField(getWeightCol()); Double label = (Double) dataPoint.getField(getLabelCol()); boolean isBinomial = @@ -160,9 +160,9 @@ public class LogisticRegression @Override public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) { if (dim == 0) { - dim = streamRecord.getValue().features.size(); + dim = streamRecord.getValue().getFeatures().size(); } else { - if (dim != streamRecord.getValue().features.size()) { + if (dim != streamRecord.getValue().getFeatures().size()) { throw new RuntimeException( "The training data should all have same dimensions."); } @@ -390,7 +390,7 @@ public class LogisticRegression } @Override - public void onIterationTerminated(Context context, Collector collector) { + public void onIterationTerminated(Context context, Collector<double[]> collector) { trainDataState.clear(); coefficientState.clear(); feedbackBufferState.clear(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java index aae66fb..774c19e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -44,12 +44,14 @@ import java.io.OutputStream; */ public class LogisticRegressionModelData { - public final DenseVector coefficient; + public DenseVector coefficient; public LogisticRegressionModelData(DenseVector coefficient) { this.coefficient = coefficient; } + public LogisticRegressionModelData() {} + /** * Converts the table model to a data stream. * @@ -59,7 +61,8 @@ public class LogisticRegressionModelData { public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); - return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0)); + return tEnv.toDataStream(modelData) + .map(x -> new LogisticRegressionModelData((DenseVector) x.getField(0))); } /** Data encoder for {@link LogisticRegressionModel}. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java index aefb44d..7a3cc3d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java @@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -339,7 +340,8 @@ public class NaiveBayes piArray[i] = Math.log(weightSum + smoothing) - piLog; } - NaiveBayesModelData modelData = new NaiveBayesModelData(theta, piArray, labels); + NaiveBayesModelData modelData = + new NaiveBayesModelData(theta, Vectors.dense(piArray), Vectors.dense(labels)); collector.collect(modelData); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java index fee3b35..a03141d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java @@ -29,7 +29,6 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; @@ -54,17 +53,13 @@ public class NaiveBayesModelData { * Log of class conditional probabilities, whose dimension is C (number of classes) by D (number * of features). */ - public final Map<Double, Double>[][] theta; + public Map<Double, Double>[][] theta; /** Log of class priors, whose dimension is C (number of classes). */ - public final DenseVector piArray; + public DenseVector piArray; /** Value of labels. */ - public final DenseVector labels; - - public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] piArray, double[] labels) { - this(theta, Vectors.dense(piArray), Vectors.dense(labels)); - } + public DenseVector labels; public NaiveBayesModelData( Map<Double, Double>[][] theta, DenseVector piArray, DenseVector labels) { @@ -73,6 +68,8 @@ public class NaiveBayesModelData { this.labels = labels; } + public NaiveBayesModelData() {} + /** * Converts the table model to a data stream. * @@ -85,7 +82,11 @@ public class NaiveBayesModelData { return tEnv.toDataStream(modelData) .map( (MapFunction<Row, NaiveBayesModelData>) - row -> (NaiveBayesModelData) row.getField("f0")); + row -> + new NaiveBayesModelData( + (Map<Double, Double>[][]) row.getField(0), + (DenseVector) row.getField(1), + (DenseVector) row.getField(2))); } /** Data encoder for the {@link NaiveBayesModelData}. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java index 4bbf345..af0733d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java @@ -45,12 +45,14 @@ import java.io.OutputStream; */ public class KMeansModelData { - public final DenseVector[] centroids; + public DenseVector[] centroids; public KMeansModelData(DenseVector[] centroids) { this.centroids = centroids; } + public KMeansModelData() {} + /** * Converts the table model to a data stream. * @@ -60,7 +62,8 @@ public class KMeansModelData { public static DataStream<KMeansModelData> getModelDataStream(Table modelData) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); - return tEnv.toDataStream(modelData).map(x -> (KMeansModelData) x.getField(0)); + return tEnv.toDataStream(modelData) + .map(x -> new KMeansModelData((DenseVector[]) x.getField(0))); } /** Data encoder for {@link KMeansModelData}. */ diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java index db57bd9..e7dc036 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java @@ -228,7 +228,7 @@ public class LogisticRegressionTest { LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath()); assertEquals( - Collections.singletonList("f0"), + Collections.singletonList("coefficient"), model.getModelData()[0].getResolvedSchema().getColumnNames()); Table output = model.transform(binomialDataTable)[0]; verifyPredictionResult( @@ -242,11 +242,10 @@ public class LogisticRegressionTest { public void testGetModelData() throws Exception { LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); - List<Row> collectedModelData = - IteratorUtils.toList( - tEnv.toDataStream(model.getModelData()[0]).executeAndCollect()); LogisticRegressionModelData modelData = - (LogisticRegressionModelData) collectedModelData.get(0).getField(0); + LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect() + .next(); assertNotNull(modelData); assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java index 9f613e1..fe42829 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java @@ -60,7 +60,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -/** Tests KMeans and KMeansModel. */ +/** Tests {@link KMeans} and {@link KMeansModel}. */ public class KMeansTest extends AbstractTestBase { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); @@ -150,7 +150,7 @@ public class KMeansTest extends AbstractTestBase { } @Test - public void testFeaturePredictionParam() throws Exception { + public void testFeaturePredictionParam() { Table input = dataTable.as("test_feature"); KMeans kmeans = new KMeans().setFeaturesCol("test_feature").setPredictionCol("test_prediction"); @@ -166,7 +166,7 @@ public class KMeansTest extends AbstractTestBase { } @Test - public void testFewerDistinctPointsThanCluster() throws Exception { + public void testFewerDistinctPointsThanCluster() { List<DenseVector> data = Arrays.asList( Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1)); @@ -185,7 +185,7 @@ public class KMeansTest extends AbstractTestBase { } @Test - public void testFitAndPredict() throws Exception { + public void testFitAndPredict() { KMeans kmeans = new KMeans().setMaxIter(2).setK(2); KMeansModel model = kmeans.fit(dataTable); Table output = model.transform(dataTable)[0]; @@ -201,18 +201,14 @@ public class KMeansTest extends AbstractTestBase { @Test public void testSaveLoadAndPredict() throws Exception { KMeans kmeans = new KMeans().setMaxIter(2).setK(2); - KMeans loadedKmeans = StageTestUtils.saveAndReload(env, kmeans, tempFolder.newFolder().getAbsolutePath()); - KMeansModel model = loadedKmeans.fit(dataTable); - KMeansModel loadedModel = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath()); Table output = loadedModel.transform(dataTable)[0]; - assertEquals( - Arrays.asList("f0"), + Collections.singletonList("centroids"), loadedModel.getModelData()[0].getResolvedSchema().getColumnNames()); assertEquals( Arrays.asList("features", "prediction"), @@ -226,16 +222,17 @@ public class KMeansTest extends AbstractTestBase { @Test public void testGetModelData() throws Exception { KMeans kmeans = new KMeans().setMaxIter(2).setK(2); - KMeansModel modelA = kmeans.fit(dataTable); - Table modelData = modelA.getModelData()[0]; - - DataStream<KMeansModelData> output = - tEnv.toDataStream(modelData).map(row -> (KMeansModelData) row.getField("f0")); - - assertEquals(Arrays.asList("f0"), modelData.getResolvedSchema().getColumnNames()); - List<KMeansModelData> kMeansModelData = IteratorUtils.toList(output.executeAndCollect()); - DenseVector[] centroids = kMeansModelData.get(0).centroids; - assertEquals(1, kMeansModelData.size()); + KMeansModel model = kmeans.fit(dataTable); + assertEquals( + Collections.singletonList("centroids"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + + DataStream<KMeansModelData> modelData = + KMeansModelData.getModelDataStream(model.getModelData()[0]); + List<KMeansModelData> collectedModelData = + IteratorUtils.toList(modelData.executeAndCollect()); + assertEquals(1, collectedModelData.size()); + DenseVector[] centroids = collectedModelData.get(0).centroids; assertEquals(2, centroids.length); Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0))); assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5); @@ -243,12 +240,10 @@ public class KMeansTest extends AbstractTestBase { } @Test - public void testSetModelData() throws Exception { + public void testSetModelData() { KMeans kmeans = new KMeans().setMaxIter(2).setK(2); KMeansModel modelA = kmeans.fit(dataTable); - Table modelData = modelA.getModelData()[0]; - - KMeansModel modelB = new KMeansModel().setModelData(modelData); + KMeansModel modelB = new KMeansModel().setModelData(modelA.getModelData()); ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); Table output = modelB.transform(dataTable)[0];
