lindong28 commented on code in PR #219:
URL: https://github.com/apache/flink-ml/pull/219#discussion_r1138192339
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -149,22 +152,14 @@ public Row map(Row dataPoint) {
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
coefficient = modelData.coefficient;
}
- DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense();
- Row predictionResult = predictOneDataPoint(features, coefficient);
- return Row.join(dataPoint, predictionResult);
- }
- }
+ Vector features = (Vector) dataPoint.getField(featuresCol);
+
+ LogisticRegressionModelServable servable =
Review Comment:
Would it be more performant to initialize the
`LogisticRegressionModelServable` only once and re-use it for predictions?
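A minimal sketch of the reuse idea, assuming hypothetical stand-in classes (`CountingServable` and `ReusingPredictFunction` are not part of Flink ML, and the prediction rule is a simplified illustration). The servable is created lazily on the first call and reused for every later prediction:

```java
/** Hypothetical stand-in for a servable; counts how many instances are built. */
class CountingServable {
    static int instancesCreated = 0;
    private final double[] coefficient;

    CountingServable(double[] coefficient) {
        this.coefficient = coefficient;
        instancesCreated++;
    }

    /** Returns 1.0 if the dot product with the coefficient is non-negative, else 0.0. */
    double predict(double[] features) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        return dot >= 0 ? 1.0 : 0.0;
    }
}

/** Hypothetical map function that builds the servable once and reuses it. */
class ReusingPredictFunction {
    private final double[] coefficient;
    private CountingServable servable; // created on first use, then reused

    ReusingPredictFunction(double[] coefficient) {
        this.coefficient = coefficient;
    }

    double map(double[] features) {
        if (servable == null) {
            servable = new CountingServable(coefficient);
        }
        return servable.predict(features);
    }
}
```

Whether this is measurably faster depends on how expensive servable construction is; in the real code that cost includes populating the parameter map with defaults.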
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -149,22 +152,14 @@ public Row map(Row dataPoint) {
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
coefficient = modelData.coefficient;
}
- DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense();
- Row predictionResult = predictOneDataPoint(features, coefficient);
- return Row.join(dataPoint, predictionResult);
- }
- }
+ Vector features = (Vector) dataPoint.getField(featuresCol);
+
+ LogisticRegressionModelServable servable =
+ new LogisticRegressionModelServable(
Review Comment:
Can we set model data using `LogisticRegressionModelServable#setModelData`
instead of introducing the
`LogisticRegressionModelServable(LogisticRegressionModelData)` constructor?
Also, should we forward all parameters from `LogisticRegressionModelParams`
to this servable instance?
Note that servables such as `MinMaxScalerModelServable` might rely on
parameters (e.g. `MinMaxScalerParams#getMin`) to perform online
transformation. It would be useful to forward parameters from model to
servable consistently and to figure out the required infrastructure in this PR.
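One possible shape for parameter forwarding, as a rough sketch only. `SketchParams`, `SketchModel`, `SketchServable`, and `toServable` are hypothetical names, and parameters are keyed by plain strings here rather than by Flink ML's `Param<?>` objects:

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical base type that holds parameters by name, with defaults. */
class SketchParams {
    final Map<String, Object> paramMap = new HashMap<>();

    SketchParams() {
        paramMap.put("featuresCol", "features");
        paramMap.put("predictionCol", "prediction");
    }
}

/** Hypothetical model that forwards all of its parameters to a new servable. */
class SketchModel extends SketchParams {
    SketchServable toServable() {
        SketchServable servable = new SketchServable();
        // Overwrite the servable's defaults with the model's current values.
        servable.paramMap.putAll(paramMap);
        return servable;
    }
}

/** Hypothetical servable sharing the same parameter set as the model. */
class SketchServable extends SketchParams {}
```

Copying the whole map, rather than selected entries, means a servable that depends on a non-default parameter sees the same value the model was configured with.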
##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.types.BasicType;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** A Servable which can be used to classifies data in online inference. */
+public class LogisticRegressionModelServable
+ implements ModelServable<LogisticRegressionModelServable>,
+ LogisticRegressionModelParams<LogisticRegressionModelServable> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ private LogisticRegressionModelData modelData;
+
+ public LogisticRegressionModelServable() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ public LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+ this();
+ this.modelData = modelData;
+ }
+
+ @Override
+ public DataFrame transform(DataFrame input) {
+ List<Double> predictionResults = new ArrayList<>();
+ List<DenseVector> rawPredictionResults = new ArrayList<>();
+
+ int featuresColIndex = input.getIndex(getFeaturesCol());
+ for (Row row : input.collect()) {
+ Vector features = (Vector) row.get(featuresColIndex);
+ Tuple2<Double, DenseVector> dataPoint = transform(features);
+ predictionResults.add(dataPoint.f0);
+ rawPredictionResults.add(dataPoint.f1);
+ }
+
+ input.addColumn(getPredictionCol(), DataTypes.DOUBLE, predictionResults);
+ input.addColumn(
+ getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), rawPredictionResults);
+
+ return input;
+ }
+
+ public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
+ throws IOException {
+ Preconditions.checkArgument(modelDataInputs.length == 1);
+
+ DataInputViewStreamWrapper inputViewStreamWrapper =
+ new DataInputViewStreamWrapper(modelDataInputs[0]);
+
+ DenseVectorSerializer serializer = new DenseVectorSerializer();
Review Comment:
It seems that the logic for deserializing model data is duplicated between
`LogisticRegressionModelServable#load` and
`LogisticRegressionModelServable#setModelData`.
It might also be more readable to put the method that deserializes the model
data in the same class as the method that serializes it.
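A sketch of keeping both directions in one class, using only `java.io`. `SketchModelData`, `encode`, and `decode` are hypothetical names, and the wire format below (length, doubles, version) is an assumption for illustration, not the format produced by `DenseVectorSerializer`:

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical model data class that owns both its encoder and its decoder. */
class SketchModelData {
    final double[] coefficient;
    final long modelVersion;

    SketchModelData(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }

    /** Writes the coefficient length, the coefficient values, then the version. */
    void encode(OutputStream out) throws IOException {
        DataOutputStream data = new DataOutputStream(out);
        data.writeInt(coefficient.length);
        for (double value : coefficient) {
            data.writeDouble(value);
        }
        data.writeLong(modelVersion);
        data.flush();
    }

    /** Reads the fields back in exactly the order encode() wrote them. */
    static SketchModelData decode(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in);
        int length = data.readInt();
        double[] coefficient = new double[length];
        for (int i = 0; i < length; i++) {
            coefficient[i] = data.readDouble();
        }
        long modelVersion = data.readLong();
        return new SketchModelData(coefficient, modelVersion);
    }
}
```

With the two methods side by side, a change to the write order is hard to make without noticing the matching read order, and callers such as `load` and `setModelData` can share the single decoder.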
##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.types.BasicType;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** A Servable which can be used to classifies data in online inference. */
+public class LogisticRegressionModelServable
+ implements ModelServable<LogisticRegressionModelServable>,
+ LogisticRegressionModelParams<LogisticRegressionModelServable> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ private LogisticRegressionModelData modelData;
+
+ public LogisticRegressionModelServable() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ public LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+ this();
Review Comment:
It seems confusing that the instance created using this constructor does
not have `paramMap` initialized.
##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -282,6 +287,50 @@ public void testSetModelData() throws Exception {
logisticRegression.getRawPredictionCol());
}
+ @Test
+ public void testSaveLoadServableAndPredict() throws Exception {
+ LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+ LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+
+ LogisticRegressionModelServable servable =
+ saveAndLoadServable(
+ tEnv,
+ model,
+ tempFolder.newFolder().getAbsolutePath(),
+ LogisticRegressionModel::loadServable);
+
+ assertEquals("features", servable.getFeaturesCol());
+ assertEquals("prediction", servable.getPredictionCol());
+ assertEquals("rawPrediction", servable.getRawPredictionCol());
+
+ DataFrame output = servable.transform(LogisticRegressionModelServableTest.PREDICT_DATA);
+ LogisticRegressionModelServableTest.verifyPredictionResult(
+ output,
+ servable.getFeaturesCol(),
+ servable.getPredictionCol(),
+ servable.getRawPredictionCol());
+ }
+
+ @Test
+ public void testSetModelDataToServable() throws Exception {
+ LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+ LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+ List<LogisticRegressionModelData> modelData =
Review Comment:
The code requires users to deal with `LogisticRegressionModel` explicitly,
which seems different from what we want.
How about we add the method `public static DataStream<byte[]>
LogisticRegressionModelDataUtil#getModelDataByteStream(Table modelData)` and
do the following:
```
byte[] modelData =
    (byte[]) IteratorUtils.toList(
            LogisticRegressionModelDataUtil.getModelDataByteStream(model.getModelData()[0])
                .executeAndCollect())
        .get(0);
LogisticRegressionModelServable servable = new LogisticRegressionModelServable();
servable.setModelData(new ByteArrayInputStream(modelData));
```
This approach would also not require `LogisticRegressionModelData#serialize`
to be a public method.
##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.types.BasicType;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** A Servable which can be used to classifies data in online inference. */
+public class LogisticRegressionModelServable
+ implements ModelServable<LogisticRegressionModelServable>,
+ LogisticRegressionModelParams<LogisticRegressionModelServable> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ private LogisticRegressionModelData modelData;
+
+ public LogisticRegressionModelServable() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ public LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+ this();
+ this.modelData = modelData;
+ }
+
+ @Override
+ public DataFrame transform(DataFrame input) {
+ List<Double> predictionResults = new ArrayList<>();
+ List<DenseVector> rawPredictionResults = new ArrayList<>();
+
+ int featuresColIndex = input.getIndex(getFeaturesCol());
+ for (Row row : input.collect()) {
+ Vector features = (Vector) row.get(featuresColIndex);
+ Tuple2<Double, DenseVector> dataPoint = transform(features);
+ predictionResults.add(dataPoint.f0);
+ rawPredictionResults.add(dataPoint.f1);
+ }
+
+ input.addColumn(getPredictionCol(), DataTypes.DOUBLE, predictionResults);
+ input.addColumn(
+ getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), rawPredictionResults);
+
+ return input;
+ }
+
+ public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
+ throws IOException {
+ Preconditions.checkArgument(modelDataInputs.length == 1);
+
+ DataInputViewStreamWrapper inputViewStreamWrapper =
+ new DataInputViewStreamWrapper(modelDataInputs[0]);
+
+ DenseVectorSerializer serializer = new DenseVectorSerializer();
+
+ DenseVector coefficient = serializer.deserialize(inputViewStreamWrapper);
+ long modelVersion = inputViewStreamWrapper.readLong();
+ modelData = new LogisticRegressionModelData(coefficient, modelVersion);
+
+ return this;
+ }
+
+ public static LogisticRegressionModelServable load(String path) throws IOException {
+ LogisticRegressionModelServable servable =
+ ServableReadWriteUtils.loadServableParam(
+ path, LogisticRegressionModelServable.class);
+
+ try (InputStream fsDataInputStream = ServableReadWriteUtils.loadModelData(path)) {
+ DataInputViewStreamWrapper dataInputViewStreamWrapper =
+ new DataInputViewStreamWrapper(fsDataInputStream);
+ DenseVectorSerializer serializer = new DenseVectorSerializer();
+ DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper);
+ long modelVersion = dataInputViewStreamWrapper.readLong();
+
+ servable.modelData = new LogisticRegressionModelData(coefficient, modelVersion);
Review Comment:
Would it be simpler to set model data via `servable#setModelData`?
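A sketch of what that delegation could look like, assuming a hypothetical `DelegatingServable` whose model data is reduced to a single version number and whose `load` takes a file path directly (the real method resolves the stream through `ServableReadWriteUtils`):

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical servable in which load() reuses setModelData() for parsing. */
class DelegatingServable {
    private long modelVersion = -1L;

    /** The single place where the model-data stream is parsed. */
    DelegatingServable setModelData(InputStream... modelDataInputs) throws IOException {
        if (modelDataInputs.length != 1) {
            throw new IllegalArgumentException("Expected exactly one model data stream");
        }
        DataInputStream data = new DataInputStream(modelDataInputs[0]);
        modelVersion = data.readLong();
        return this;
    }

    long getModelVersion() {
        return modelVersion;
    }

    /** Opens the file and hands the stream to setModelData instead of parsing again. */
    static DelegatingServable load(Path modelDataFile) throws IOException {
        DelegatingServable servable = new DelegatingServable();
        try (InputStream in = Files.newInputStream(modelDataFile)) {
            servable.setModelData(in);
        }
        return servable;
    }
}
```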