This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 2015daf [FLINK-26263] Check data size in LogisticRegression
2015daf is described below
commit 2015dafc6bad65d3ce9b5e8ae6d1ae9b567e910c
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Mon Feb 21 09:44:42 2022 +0800
[FLINK-26263] Check data size in LogisticRegression
This closes #63.
---
.../logisticregression/LogisticRegression.java | 3 +++
.../flink/ml/classification/LogisticRegressionTest.java | 12 ++++++++++++
2 files changed, 15 insertions(+)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 602b9c6..587f225 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -363,6 +363,9 @@ public class LogisticRegression
public void onEpochWatermarkIncremented(
int epochWatermark, Context context, Collector<double[]>
collector)
throws Exception {
+ if (!trainDataState.get().iterator().hasNext()) {
+ return;
+ }
if (epochWatermark == 0) {
coefficient = new DenseVector(feedbackBuffer);
coefficientDim = coefficient.size();
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 2c73564..ea6d80a 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -281,4 +281,16 @@ public class LogisticRegressionTest {
e.getCause().getCause().getMessage());
}
}
+
+ @Test
+ public void testMoreSubtaskThanData() throws Exception {
+ env.setParallelism(12);
+ LogisticRegression logisticRegression = new
LogisticRegression().setWeightCol("weight");
+ Table output =
logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
+ verifyPredictionResult(
+ output,
+ logisticRegression.getFeaturesCol(),
+ logisticRegression.getPredictionCol(),
+ logisticRegression.getRawPredictionCol());
+ }
}