lindong28 commented on code in PR #196:
URL: https://github.com/apache/flink-ml/pull/196#discussion_r1070769249


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/metrics/MLMetrics.java:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.metrics;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * A collection class for handling metrics in Flink ML.
+ *
+ * <p>All metrics of Flink ML are registered under group "ml", which is a child group of {@link
+ * org.apache.flink.metrics.groups.OperatorMetricGroup}. Metrics related to model data will be
+ * registered in the group "ml.model".
+ *
+ * <p>For example, the timestamp of the current model data will be reported in metric:
+ * "{some_parent_groups}.operator.ml.model.timestamp". And the version of the current model data
+ * will be reported in metric: "{some_parent_groups}.operator.ml.model.version".
+ */
+@PublicEvolving

Review Comment:
   Since we have not opened a FLIP for these metrics, should we mark it `@Experimental` for now?
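   
   If we do, the swap would be small, something like this sketch (using Flink's existing `org.apache.flink.annotation.Experimental` annotation):
   
   ```java
   import org.apache.flink.annotation.Experimental;
   
   /** A collection class for handling metrics in Flink ML. */
   @Experimental
   public class MLMetrics {
       // constants such as ML_GROUP, ML_MODEL_GROUP, TIMESTAMP, VERSION stay unchanged
   }
   ```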



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasModelVersionCol.java:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared model version column param. */
+public interface HasModelVersionCol<T> extends WithParams<T> {
+    Param<String> MODEL_VERSION_COL =
+            new StringParam(
+                    "modelVersionCol",
+                    "The version of the model data that the input data is 
predicted with.",

Review Comment:
   Would it be useful to also specify the value type (e.g. long) of this column?
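   
   For example, the description string could call out the type along these lines (the exact wording here is just a sketch):
   
   ```java
   Param<String> MODEL_VERSION_COL =
           new StringParam(
                   "modelVersionCol",
                   "The name of the output column that holds the version (a long value) of the "
                           + "model data that the input data is predicted with.",
                   null);
   ```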



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxAllowedModelDelayMs.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.LongParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared max allowed model delay in milliseconds param. */
+public interface HasMaxAllowedModelDelayMs<T> extends WithParams<T> {
+    Param<Long> MAX_ALLOWED_MODEL_DELAY_MS =
+            new LongParam(
+                    "maxAllowedModelDelayMs",
+                    "The maximum difference between the timestamps of the 
input record and model data when "

Review Comment:
   It seems more readable to make the following change:
   
   `model data when using the model data to predict that input record`
   
   -> 
   
   `the model data that is used to make predictions on that input record`



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java:
##########
@@ -206,7 +206,7 @@ public void testSaveLoadAndPredict() throws Exception {
         model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
 
         assertEquals(
-                Arrays.asList("mean", "std"),
+                Arrays.asList("mean", "std", "version", "timestamp"),

Review Comment:
   Even though we updated the model data emitted by StandardScalerModel to include these columns, that is an unintended implementation detail and should be avoided, since we didn't expect `StandardScalerModel` to have a concept of version.
   
   It might be more intuitive to verify that `model data contains those columns`. What do you think?
   
   Same for other changes in this test class that cannot be explained by the semantics of the StandardScaler algorithm.
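   
   As a sketch of what I mean, assuming the model data table is reachable via `getModelData()` as usual:
   
   ```java
   // Check that the model data table contains the version/timestamp columns,
   // instead of hard-coding the full expected column list in this test.
   List<String> modelDataColumns =
           model.getModelData()[0].getResolvedSchema().getColumnNames();
   assertTrue(modelDataColumns.containsAll(Arrays.asList("version", "timestamp")));
   ```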



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java:
##########
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.metrics.MLMetrics;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */
+public class OnlineStandardScalerModel
+        implements Model<OnlineStandardScalerModel>,
+                OnlineStandardScalerModelParams<OnlineStandardScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineStandardScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        String modelVersionCol = getModelVersionCol();
+
+        TypeInformation<?>[] outputTypes;
+        String[] outputNames;
+        if (modelVersionCol == null) {
+            outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE);
+            outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol());
+        } else {
+            outputTypes =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG);
+            outputNames =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol);
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames);
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                StandardScalerModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictionOperator",
+                                outputTypeInfo,
+                                new PredictionOperator(
+                                        inputTypeInfo,
+                                        getInputCol(),
+                                        getWithMean(),
+                                        getWithStd(),
+                                        getMaxAllowedModelDelayMs(),
+                                        getModelVersionCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static class PredictionOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String inputCol;
+
+        private final boolean withMean;
+
+        private final boolean withStd;
+
+        private final long maxAllowedModelDelayMs;
+
+        private final String modelVersionCol;
+
+        private ListState<StreamRecord> bufferedPointsState;
+
+        private ListState<StandardScalerModelData> modelDataState;
+
+        /** Model data for inference. */
+        private StandardScalerModelData modelData;
+
+        private DenseVector mean;
+
+        /** Inverse of standard deviation. */
+        private DenseVector scale;
+
+        private long modelVersion;
+
+        private long modelTimeStamp;
+
+        public PredictionOperator(
+                RowTypeInfo inputTypeInfo,
+                String inputCol,
+                boolean withMean,
+                boolean withStd,
+                long maxAllowedModelDelayMs,
+                String modelVersionCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.inputCol = inputCol;
+            this.withMean = withMean;
+            this.withStd = withStd;
+            this.maxAllowedModelDelayMs = maxAllowedModelDelayMs;
+            this.modelVersionCol = modelVersionCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<StreamRecord>(
+                                            "bufferedPoints",
+                                            new StreamElementSerializer(
+                                                    inputTypeInfo.createSerializer(
+                                                            getExecutionConfig()))));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelData",
+                                            TypeInformation.of(StandardScalerModelData.class)));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (modelData != null) {
+                modelDataState.clear();
+                modelDataState.add(modelData);
+            }
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+            MetricGroup mlModelMetricGroup =
+                    getRuntimeContext()
+                            .getMetricGroup()
+                            .addGroup(MLMetrics.ML_GROUP)
+                            .addGroup(MLMetrics.ML_MODEL_GROUP);
+            mlModelMetricGroup.gauge(MLMetrics.TIMESTAMP, (Gauge<Long>) () -> modelTimeStamp);
+            mlModelMetricGroup.gauge(MLMetrics.VERSION, (Gauge<Long>) () -> modelVersion);
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> dataPoint) throws Exception {
+            if (dataPoint.getTimestamp() <= modelTimeStamp + maxAllowedModelDelayMs
+                    && mean != null) {
+                doPrediction(dataPoint);
+            } else {
+                bufferedPointsState.add(dataPoint);
+            }
+        }
+
+        @Override
+        public void processElement2(StreamRecord<StandardScalerModelData> streamRecord)
+                throws Exception {
+            modelData = streamRecord.getValue();
+
+            modelTimeStamp = modelData.timestamp;
+            modelVersion = modelData.version;
+            mean = modelData.mean;
+            DenseVector std = modelData.std;
+
+            if (withStd) {
+                scale = std;
+                double[] scaleValues = scale.values;
+                for (int i = 0; i < scaleValues.length; i++) {
+                    scaleValues[i] = scaleValues[i] == 0 ? 0 : 1 / scaleValues[i];
+                }
+            }
+
+            // Does prediction on the cached data.
+            List<StreamRecord> unprocessedElements = new ArrayList<>();
+            for (StreamRecord dataPoint : bufferedPointsState.get()) {
+                if (dataPoint.getTimestamp() <= modelTimeStamp + maxAllowedModelDelayMs) {
+                    doPrediction(dataPoint);
+                } else {
+                    unprocessedElements.add(dataPoint);
+                }
+            }
+            bufferedPointsState.clear();
+            if (unprocessedElements.size() > 0) {
+                bufferedPointsState.update(unprocessedElements);

Review Comment:
   As a minor optimization, can we only clear/update the state if at least one element is predicted in the above for loop?
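   
   In other words, something along these lines for the tail of the loop (just a sketch):
   
   ```java
   boolean anyPredicted = false;
   List<StreamRecord> unprocessedElements = new ArrayList<>();
   for (StreamRecord dataPoint : bufferedPointsState.get()) {
       if (dataPoint.getTimestamp() <= modelTimeStamp + maxAllowedModelDelayMs) {
           doPrediction(dataPoint);
           anyPredicted = true;
       } else {
           unprocessedElements.add(dataPoint);
       }
   }
   // Only touch the operator state if the loop above consumed at least one element.
   if (anyPredicted) {
       bufferedPointsState.update(unprocessedElements);
   }
   ```
   
   Since `ListState#update()` replaces the whole contents, the separate `clear()` call would also become unnecessary.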



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxAllowedModelDelayMs.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.LongParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared max allowed model delay in milliseconds param. */
+public interface HasMaxAllowedModelDelayMs<T> extends WithParams<T> {
+    Param<Long> MAX_ALLOWED_MODEL_DELAY_MS =
+            new LongParam(
+                    "maxAllowedModelDelayMs",
+                    "The maximum difference between the timestamps of the 
input record and model data when "

Review Comment:
   `maximum difference` -> `maximum difference allowed`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java:
##########
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.metrics.MLMetrics;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */
+public class OnlineStandardScalerModel
+        implements Model<OnlineStandardScalerModel>,
+                OnlineStandardScalerModelParams<OnlineStandardScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineStandardScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        String modelVersionCol = getModelVersionCol();
+
+        TypeInformation<?>[] outputTypes;
+        String[] outputNames;
+        if (modelVersionCol == null) {
+            outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE);
+            outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol());
+        } else {
+            outputTypes =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG);
+            outputNames =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol);
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames);
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                StandardScalerModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictionOperator",
+                                outputTypeInfo,
+                                new PredictionOperator(
+                                        inputTypeInfo,
+                                        getInputCol(),
+                                        getWithMean(),
+                                        getWithStd(),
+                                        getMaxAllowedModelDelayMs(),
+                                        getModelVersionCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static class PredictionOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String inputCol;
+
+        private final boolean withMean;
+
+        private final boolean withStd;
+
+        private final long maxAllowedModelDelayMs;
+
+        private final String modelVersionCol;
+
+        private ListState<StreamRecord> bufferedPointsState;
+
+        private ListState<StandardScalerModelData> modelDataState;
+
+        /** Model data for inference. */
+        private StandardScalerModelData modelData;
+
+        private DenseVector mean;
+
+        /** Inverse of standard deviation. */
+        private DenseVector scale;
+
+        private long modelVersion;
+
+        private long modelTimeStamp;
+
+        public PredictionOperator(
+                RowTypeInfo inputTypeInfo,
+                String inputCol,
+                boolean withMean,
+                boolean withStd,
+                long maxAllowedModelDelayMs,
+                String modelVersionCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.inputCol = inputCol;
+            this.withMean = withMean;
+            this.withStd = withStd;
+            this.maxAllowedModelDelayMs = maxAllowedModelDelayMs;
+            this.modelVersionCol = modelVersionCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<StreamRecord>(
+                                            "bufferedPoints",
+                                            new StreamElementSerializer(
+                                                    inputTypeInfo.createSerializer(
+                                                            getExecutionConfig()))));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelData",
+                                            TypeInformation.of(StandardScalerModelData.class)));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (modelData != null) {
+                modelDataState.clear();
+                modelDataState.add(modelData);
+            }
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+            MetricGroup mlModelMetricGroup =
+                    getRuntimeContext()
+                            .getMetricGroup()
+                            .addGroup(MLMetrics.ML_GROUP)
+                            .addGroup(MLMetrics.ML_MODEL_GROUP);
+            mlModelMetricGroup.gauge(MLMetrics.TIMESTAMP, (Gauge<Long>) () -> modelTimeStamp);
+            mlModelMetricGroup.gauge(MLMetrics.VERSION, (Gauge<Long>) () -> modelVersion);
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> dataPoint) throws Exception {
+            if (dataPoint.getTimestamp() <= modelTimeStamp + maxAllowedModelDelayMs
+                    && mean != null) {
+                doPrediction(dataPoint);
+            } else {
+                bufferedPointsState.add(dataPoint);

Review Comment:
   Given that we only update `modelDataState` in `snapshotState()`, I suppose it is because updating `modelDataState` every time we receive model data is more expensive than updating the in-memory `modelData`.
   
   Should we follow the same approach for `dataPoint`, for consistency and performance?
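   
   Concretely, the operator could keep the buffer in a plain field and only sync it into `bufferedPointsState` on checkpoints, e.g. (a sketch; the field name is illustrative):
   
   ```java
   // In-memory buffer; flushed into bufferedPointsState only at checkpoint
   // time, mirroring how modelData/modelDataState are handled.
   private final List<StreamRecord> bufferedPoints = new ArrayList<>();
   
   @Override
   public void snapshotState(StateSnapshotContext context) throws Exception {
       super.snapshotState(context);
       bufferedPointsState.update(new ArrayList<>(bufferedPoints));
       // ... existing modelDataState handling ...
   }
   ```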



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java:
##########
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.metrics.MLMetrics;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */
+public class OnlineStandardScalerModel
+        implements Model<OnlineStandardScalerModel>,
+                OnlineStandardScalerModelParams<OnlineStandardScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineStandardScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        String modelVersionCol = getModelVersionCol();
+
+        TypeInformation<?>[] outputTypes;
+        String[] outputNames;
+        if (modelVersionCol == null) {
+            outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE);
+            outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol());
+        } else {
+            outputTypes =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG);
+            outputNames =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol);
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames);
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                StandardScalerModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictionOperator",
+                                outputTypeInfo,
+                                new PredictionOperator(
+                                        inputTypeInfo,
+                                        getInputCol(),
+                                        getWithMean(),
+                                        getWithStd(),
+                                        getMaxAllowedModelDelayMs(),
+                                        getModelVersionCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static class PredictionOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String inputCol;
+
+        private final boolean withMean;
+
+        private final boolean withStd;
+
+        private final long maxAllowedModelDelayMs;
+
+        private final String modelVersionCol;
+
+        private ListState<StreamRecord> bufferedPointsState;
+
+        private ListState<StandardScalerModelData> modelDataState;
+
+        /** Model data for inference. */
+        private StandardScalerModelData modelData;
+
+        private DenseVector mean;
+
+        /** Inverse of standard deviation. */
+        private DenseVector scale;
+
+        private long modelVersion;
+
+        private long modelTimeStamp;
+
+        public PredictionOperator(
+                RowTypeInfo inputTypeInfo,
+                String inputCol,
+                boolean withMean,
+                boolean withStd,
+                long maxAllowedModelDelayMs,
+                String modelVersionCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.inputCol = inputCol;
+            this.withMean = withMean;
+            this.withStd = withStd;
+            this.maxAllowedModelDelayMs = maxAllowedModelDelayMs;
+            this.modelVersionCol = modelVersionCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<StreamRecord>(
+                                            "bufferedPoints",
+                                            new StreamElementSerializer(
+                                                    inputTypeInfo.createSerializer(
+                                                            getExecutionConfig()))));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelData",
+                                            TypeInformation.of(StandardScalerModelData.class)));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (modelData != null) {
+                modelDataState.clear();
+                modelDataState.add(modelData);
+            }
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+            MetricGroup mlModelMetricGroup =
+                    getRuntimeContext()
+                            .getMetricGroup()
+                            .addGroup(MLMetrics.ML_GROUP)
+                            .addGroup(MLMetrics.ML_MODEL_GROUP);

Review Comment:
   Since we might have multiple models running in the same Flink job, it might be better to specify the model name to reduce the chance of conflicts.
   
   How about using `addGroup(MLMetrics.ML_MODEL_GROUP, "OnlineStandardScalerModel")` here?



##########
docs/content/docs/operators/feature/onlinestandardscaler.md:
##########
@@ -0,0 +1,260 @@
+---
+title: "OnlineStandardScaler"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/onlinestandardscaler.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## OnlineStandardScaler
+
+An Estimator which implements the online standard scaling algorithm, i.e.,
+the online version of StandardScaler.
+
+OnlineStandardScaler splits the input data by the user-specified window strategy.
+For each window, it computes the mean and standard deviation using the data seen
+so far (i.e., not only the data in the current window, but also the history data).
+The model data generated by OnlineStandardScaler is a model stream. 
+There is one model data for each window.
+
+During the inference phase (i.e., using OnlineStandardScalerModel for prediction),
+users could output the model version that is used for predicting each data point.
+Moreover,
+- When the train data and test data both contain event time, users could
+specify the maximum difference between the timestamps of the input and model data,
+which enforces using a relatively fresh model for prediction.
+- Otherwise, the prediction process always uses the current model data.
+
+
+### Input Columns
+
+| Param name | Type   | Default   | Description            |
+|:-----------|:-------|:----------|:-----------------------|
+| inputCol   | Vector | `"input"` | Features to be scaled. |
+
+### Output Columns
+
+| Param name | Type   | Default    | Description      |
+|:-----------|:-------|:-----------|:-----------------|
+| outputCol  | Vector | `"output"` | Scaled features. |

Review Comment:
   Should we add `modelVersionCol` here?
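   
   For instance, a row along these lines (assuming the column is only emitted when `modelVersionCol` is set, with `null` as the default):
   
   | Param name      | Type | Default | Description                                              |
   |:----------------|:-----|:--------|:---------------------------------------------------------|
   | modelVersionCol | Long | `null`  | The version of the model data used for this prediction. |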



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java:
##########
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.metrics.MLMetrics;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** A Model which transforms data using the model data computed by {@link OnlineStandardScaler}. */
+public class OnlineStandardScalerModel
+        implements Model<OnlineStandardScalerModel>,
+                OnlineStandardScalerModelParams<OnlineStandardScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineStandardScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        String modelVersionCol = getModelVersionCol();
+
+        TypeInformation<?>[] outputTypes;
+        String[] outputNames;
+        if (modelVersionCol == null) {
+            outputTypes = ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE);
+            outputNames = ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol());
+        } else {
+            outputTypes =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE, Types.LONG);
+            outputNames =
+                    ArrayUtils.addAll(
+                            inputTypeInfo.getFieldNames(), getOutputCol(), modelVersionCol);
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(outputTypes, outputNames);
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                StandardScalerModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictionOperator",
+                                outputTypeInfo,
+                                new PredictionOperator(
+                                        inputTypeInfo,
+                                        getInputCol(),
+                                        getWithMean(),
+                                        getWithStd(),
+                                        getMaxAllowedModelDelayMs(),
+                                        getModelVersionCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static class PredictionOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, StandardScalerModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String inputCol;
+
+        private final boolean withMean;
+
+        private final boolean withStd;
+
+        private final long maxAllowedModelDelayMs;
+
+        private final String modelVersionCol;
+
+        private ListState<StreamRecord> bufferedPointsState;
+
+        private ListState<StandardScalerModelData> modelDataState;
+
+        /** Model data for inference. */
+        private StandardScalerModelData modelData;
+
+        private DenseVector mean;
+
+        /** Inverse of standard deviation. */
+        private DenseVector scale;
+
+        private long modelVersion;
+
+        private long modelTimeStamp;
+
+        public PredictionOperator(
+                RowTypeInfo inputTypeInfo,
+                String inputCol,
+                boolean withMean,
+                boolean withStd,
+                long maxAllowedModelDelayMs,
+                String modelVersionCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.inputCol = inputCol;
+            this.withMean = withMean;
+            this.withStd = withStd;
+            this.maxAllowedModelDelayMs = maxAllowedModelDelayMs;
+            this.modelVersionCol = modelVersionCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<StreamRecord>(
+                                            "bufferedPoints",
+                                            new StreamElementSerializer(
+                                                    inputTypeInfo.createSerializer(
+                                                            getExecutionConfig()))));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelData",
+                                            TypeInformation.of(StandardScalerModelData.class)));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (modelData != null) {
+                modelDataState.clear();
+                modelDataState.add(modelData);
+            }
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+            MetricGroup mlModelMetricGroup =
+                    getRuntimeContext()
+                            .getMetricGroup()
+                            .addGroup(MLMetrics.ML_GROUP)
+                            .addGroup(MLMetrics.ML_MODEL_GROUP);
+            mlModelMetricGroup.gauge(MLMetrics.TIMESTAMP, (Gauge<Long>) () -> modelTimeStamp);
+            mlModelMetricGroup.gauge(MLMetrics.VERSION, (Gauge<Long>) () -> modelVersion);
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> dataPoint) throws Exception {
+            if (dataPoint.getTimestamp() <= modelTimeStamp + maxAllowedModelDelayMs

Review Comment:
   According to the maxAllowedModelDelayMs doc, this parameter should only be used if the input contains event time. I suppose there should be logic that differentiates between event time and system time. Can you explain where this logic is?
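   
   For reference, one way to make that distinction explicit would be to branch on `StreamRecord#hasTimestamp()` (just a sketch, not necessarily the intended design):
   
   ```java
   // Records without an event timestamp would bypass the staleness check and
   // always be predicted with the current model data.
   long eventTime = dataPoint.hasTimestamp() ? dataPoint.getTimestamp() : Long.MIN_VALUE;
   if (eventTime <= modelTimeStamp + maxAllowedModelDelayMs && mean != null) {
       doPrediction(dataPoint);
   } else {
       bufferedPointsState.add(dataPoint);
   }
   ```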



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasModelVersionCol.java:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared model version column param. */
+public interface HasModelVersionCol<T> extends WithParams<T> {
+    Param<String> MODEL_VERSION_COL =
+            new StringParam(
+                    "modelVersionCol",
+                    "The version of the model data that the input data is 
predicted with.",
+                    null);

Review Comment:
   Would it be more usable to use "version" as the default value and always output this information? The reason is that I suppose adding an extra long column should not have an observable performance impact for nearline (not batch) prediction.
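   
   That variant would just be:
   
   ```java
   Param<String> MODEL_VERSION_COL =
           new StringParam(
                   "modelVersionCol",
                   "The version of the model data that the input data is predicted with.",
                   "version");
   ```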
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
