This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 2015daf  [FLINK-26263] Check data size in LogisticRegression
2015daf is described below

commit 2015dafc6bad65d3ce9b5e8ae6d1ae9b567e910c
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Mon Feb 21 09:44:42 2022 +0800

    [FLINK-26263] Check data size in LogisticRegression
    
    This closes #63.
---
 .../logisticregression/LogisticRegression.java               |  3 +++
 .../flink/ml/classification/LogisticRegressionTest.java      | 12 ++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 602b9c6..587f225 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -363,6 +363,9 @@ public class LogisticRegression
         public void onEpochWatermarkIncremented(
                 int epochWatermark, Context context, Collector<double[]> collector)
                 throws Exception {
+            if (!trainDataState.get().iterator().hasNext()) {
+                return;
+            }
             if (epochWatermark == 0) {
                 coefficient = new DenseVector(feedbackBuffer);
                 coefficientDim = coefficient.size();
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 2c73564..ea6d80a 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -281,4 +281,16 @@ public class LogisticRegressionTest {
                     e.getCause().getCause().getMessage());
         }
     }
+
+    @Test
+    public void testMoreSubtaskThanData() throws Exception {
+        env.setParallelism(12);
+        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+        Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
+        verifyPredictionResult(
+                output,
+                logisticRegression.getFeaturesCol(),
+                logisticRegression.getPredictionCol(),
+                logisticRegression.getRawPredictionCol());
+    }
 }

Reply via email to