lindong28 commented on a change in pull request #73:
URL: https://github.com/apache/flink-ml/pull/73#discussion_r837111842
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -32,9 +32,31 @@ public static double asum(DenseVector x) {
}
/** y += a * x . */
- public static void axpy(double a, DenseVector x, DenseVector y) {
+ public static void axpy(double a, Vector x, DenseVector y) {
Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
- JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+ if (x instanceof SparseVector) {
+ axpy(a, (SparseVector) x, y);
+ } else {
+ axpy(a, (DenseVector) x, y);
+ }
+ }
+
+ /** Computes the hadamard product of the two vectors (y = y \hdot x). */
+ public static void hDot(Vector x, Vector y) {
+ Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+ if (y instanceof DenseVector) {
+ if (x instanceof SparseVector) {
Review comment:
nit: it seems that we use both `instanceof DenseVector` and `instanceof SparseVector` in this class. Would it be slightly better to consistently use one of the two in this class?
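For illustration, here is what the `hDot` dispatch could look like if it consistently branched on `instanceof SparseVector`, matching the `axpy` dispatch above. This is only a sketch; the two private overloads taking a dense `x` are assumed to exist elsewhere in the PR.
```java
public static void hDot(Vector x, Vector y) {
    Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
    // Branch on SparseVector in both dimensions, mirroring the axpy(...) dispatch.
    if (y instanceof SparseVector) {
        if (x instanceof SparseVector) {
            hDot((SparseVector) x, (SparseVector) y);
        } else {
            hDot((DenseVector) x, (SparseVector) y);
        }
    } else {
        if (x instanceof SparseVector) {
            hDot((SparseVector) x, (DenseVector) y);
        } else {
            hDot((DenseVector) x, (DenseVector) y);
        }
    }
}
```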
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -89,4 +111,56 @@ public static void gemv(
y.values,
1);
}
+
+ private static void axpy(double a, DenseVector x, DenseVector y) {
+ JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+ }
+
+ private static void axpy(double a, SparseVector x, DenseVector y) {
+ for (int i = 0; i < x.indices.length; i++) {
+ int index = x.indices[i];
+ y.values[index] += a * x.values[i];
+ }
+ }
+
+ private static void hDot(SparseVector x, SparseVector y) {
+ int idx = 0;
+ int idy = 0;
+ while (idx < x.indices.length && idy < y.indices.length) {
+ int indexX = x.indices[idx];
+ while (idy < y.indices.length && y.indices[idy] < indexX) {
+ y.values[idy] = 0;
+ idy++;
+ }
+ if (idy < y.indices.length && y.indices[idy] == indexX) {
+ y.values[idy] *= x.values[idx];
+ idy++;
+ }
+ idx++;
+ }
+ }
+
+ private static void hDot(SparseVector x, DenseVector y) {
+ int idx = 0;
+ for (int i = 0; i < y.size(); i++) {
+ if (x.indices[idx] == i) {
Review comment:
Should we use the following check:
`if (idx < x.indices.length && x.indices[idx] == i)`
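i.e., roughly the following. The loop body after the guard is reconstructed from context and is an assumption, since the diff is cut off at this point:
```java
private static void hDot(SparseVector x, DenseVector y) {
    int idx = 0;
    for (int i = 0; i < y.size(); i++) {
        // Guard idx so we stop reading x.indices once all non-zero entries of
        // the sparse vector have been consumed.
        if (idx < x.indices.length && x.indices[idx] == i) {
            y.values[i] *= x.values[idx];
            idx++;
        } else {
            // x is implicitly zero at position i, so the product is zero.
            y.values[i] = 0;
        }
    }
}
```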
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
##########
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the standard scaling algorithm.
+ *
+ * <p>Standardization is a common requirement for machine learning training because many estimators may behave
+ * badly if the individual features of an input do not look like standard normally distributed data
+ * (e.g. Gaussian with 0 mean and unit variance).
+ *
+ * <p>This estimator standardizes the input features by removing the mean and scaling each dimension
+ * to unit variance.
+ */
+public class StandardScaler
+ implements Estimator<StandardScaler, StandardScalerModel>,
+ StandardScalerParams<StandardScaler> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public StandardScaler() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public StandardScalerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Tuple3<DenseVector, DenseVector, Long>> sumAndSquaredSumAndWeight =
+ tEnv.toDataStream(inputs[0])
+ .transform(
+ "computeMeta",
+ new TupleTypeInfo<>(
+ TypeInformation.of(DenseVector.class),
+ TypeInformation.of(DenseVector.class),
+ BasicTypeInfo.LONG_TYPE_INFO),
+ new ComputeMetaOperator(getFeaturesCol()));
+
+ DataStream<StandardScalerModelData> modelData =
+ sumAndSquaredSumAndWeight
+ .transform(
+ "buildModel",
+ TypeInformation.of(StandardScalerModelData.class),
+ new BuildModelOperator())
+ .setParallelism(1);
+
+ StandardScalerModel model =
+ new StandardScalerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, paramMap);
+ return model;
+ }
+
+ /**
+ * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+ */
+ private static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
+ implements OneInputStreamOperator<
+ Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
+ BoundedOneInput {
+ private ListState<DenseVector> sumState;
+ private ListState<DenseVector> squaredSumState;
+ private ListState<Long> numElementsState;
+ private DenseVector sum;
+ private DenseVector squaredSum;
+ private long numElements;
+
+ @Override
+ public void endInput() {
+ if (numElements > 0) {
+ BLAS.scal(1.0 / numElements, sum);
+ double[] mean = sum.values;
+ double[] std = squaredSum.values;
+ if (numElements > 1) {
+ for (int i = 0; i < mean.length; i++) {
+ std[i] =
+ Math.sqrt(
+ (squaredSum.values[i] - numElements * mean[i] * mean[i])
+ / (numElements - 1));
+ }
+ } else {
+ Arrays.fill(std, 0.0);
+ }
+
+ output.collect(
+ new StreamRecord<>(
+ new StandardScalerModelData(
+ Vectors.dense(mean), Vectors.dense(std))));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
+ Tuple3<DenseVector, DenseVector, Long> value = element.getValue();
+ if (sum == null) {
+ sum = value.f0;
+ squaredSum = value.f1;
+ numElements = value.f2;
+ } else {
+ BLAS.axpy(1, value.f0, sum);
+ BLAS.axpy(1, value.f1, squaredSum);
+ numElements += value.f2;
+ }
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ sumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "sumState",
TypeInformation.of(DenseVector.class)));
+ squaredSumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "squaredSumState",
+
TypeInformation.of(DenseVector.class)));
+ numElementsState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "numElementsState",
BasicTypeInfo.LONG_TYPE_INFO));
+
+ sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+ squaredSum =
+ OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+ .orElse(null);
+ numElements =
+ OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+ .orElse(0L);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ if (numElements > 0) {
+ sumState.update(Collections.singletonList(sum));
+ squaredSumState.update(Collections.singletonList(squaredSum));
+ numElementsState.update(Collections.singletonList(numElements));
+ }
+ }
+ }
+
+ /** Computes sum, squared sum and number of elements in each partition. */
+ private static class ComputeMetaOperator
+ extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
+ implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
+ BoundedOneInput {
+ private ListState<DenseVector> sumState;
+ private ListState<DenseVector> squaredSumState;
+ private ListState<Long> numElementsState;
+ private DenseVector sum;
+ private DenseVector squaredSum;
+ private long numElements;
+
+ private final String featuresCol;
+
+ public ComputeMetaOperator(String featuresCol) {
+ this.featuresCol = featuresCol;
+ }
+
+ @Override
+ public void endInput() {
+ if (numElements > 0) {
+ output.collect(new StreamRecord<>(Tuple3.of(sum, squaredSum, numElements)));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> element) throws Exception {
+ Vector feature = (Vector) element.getValue().getField(featuresCol);
+ if (sum == null) {
+ sum = new DenseVector(feature.size());
+ squaredSum = new DenseVector(feature.size());
+ }
+ BLAS.axpy(1, feature, sum);
+ BLAS.hDot(feature, feature);
+ BLAS.axpy(1, feature, squaredSum);
+ numElements++;
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ sumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "sumState",
TypeInformation.of(DenseVector.class)));
+ squaredSumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "squaredSumState",
+
TypeInformation.of(DenseVector.class)));
+ numElementsState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "numElementsState",
BasicTypeInfo.LONG_TYPE_INFO));
+
+ sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+ squaredSum =
+ OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+ .orElse(null);
+ numElements =
+ OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+ .orElse(0L);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ if (numElements > 0) {
Review comment:
This code aims to handle the scenario where `fit(...)` is called with an empty table. Under this scenario, the current implementation would generate an empty modelDataTable.
Should we instead generate a modelDataTable with one element, whose value indicates that the model data cannot be used for inference?
By emitting a model data table with this value, `model.transform(...)` could throw a proper exception to indicate that it cannot do inference. In comparison, if we generate a model data table with no value, `model.transform(...)` would have to block forever, since it will not be able to differentiate this scenario from the scenario where the source is taking a long time to read the model data.
Same for other usages of `if (numElements > 0)` in this PR.
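One possible shape for this, sketched under the assumption that `StandardScalerModelData` has a no-arg constructor and could carry a hypothetical `isValid` flag; neither is part of this PR:
```java
@Override
public void endInput() {
    if (numElements > 0) {
        // ... compute mean and std and emit the usable model data, as in the PR ...
    } else {
        // Hypothetical sentinel record: mark the model data as unusable so that
        // StandardScalerModel.transform(...) can throw a proper exception instead
        // of blocking while waiting for model data that will never arrive.
        StandardScalerModelData sentinel = new StandardScalerModelData();
        sentinel.isValid = false;
        output.collect(new StreamRecord<>(sentinel));
    }
}
```
The model side would then check the flag when it reads the broadcast model data and throw instead of transforming.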
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which transforms data using the model data computed by {@link StandardScaler}. */
+public class StandardScalerModel
+ implements Model<StandardScalerModel>, StandardScalerParams<StandardScalerModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public StandardScalerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked, rawtypes")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ final String broadcastModelKey = "broadcastModelKey";
+ DataStream<StandardScalerModelData> modelDataStream =
+ StandardScalerModelData.getModelDataStream(modelDataTable);
+
+ DataStream<Row> predictionResult =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(inputStream),
+ Collections.singletonMap(broadcastModelKey, modelDataStream),
+ inputList -> {
+ DataStream inputData = inputList.get(0);
+ return inputData.map(
+ new PredictOutputFunction(
+ broadcastModelKey,
+ getFeaturesCol(),
+ getWithMean(),
+ getWithStd()),
+ outputTypeInfo);
+ });
+
+ return new Table[] {tEnv.fromDataStream(predictionResult)};
+ }
+
+ /** A utility function used for prediction. */
+ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+ private final String broadcastModelKey;
+ private final String featuresCol;
+ private final boolean withMean;
+ private final boolean withStd;
+ private DenseVector mean;
+ private DenseVector scale;
+
+ public PredictOutputFunction(
+ String broadcastModelKey, String featuresCol, boolean withMean, boolean withStd) {
+ this.broadcastModelKey = broadcastModelKey;
+ this.featuresCol = featuresCol;
+ this.withMean = withMean;
+ this.withStd = withStd;
+ }
+
+ @Override
+ public Row map(Row dataPoint) {
+ if (mean == null) {
+ StandardScalerModelData modelData =
+ (StandardScalerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+ mean = modelData.mean;
+ DenseVector std = modelData.std;
+
+ if (withStd) {
+ scale = std;
+ double[] scaleValues = scale.values;
+ for (int i = 0; i < scaleValues.length; i++) {
+ scaleValues[i] = scaleValues[i] == 0 ? 0 : 1 / scaleValues[i];
+ }
+ }
+ }
+
+ Vector feature = (Vector) (dataPoint.getField(featuresCol));
+ Vector output;
+ if (feature instanceof DenseVector) {
Review comment:
Would it be simpler to add the method `Vector::clone()`, similar to Spark's `Vector::copy()`?
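A minimal sketch of that direction, assuming `Vector` is an interface and that `DenseVector`/`SparseVector` store `values`/`indices` arrays as in this PR (the constructor signatures below are assumptions):
```java
/** Sketch: add clone() to the Vector interface. */
public interface Vector {
    int size();

    /** Returns a deep copy of this vector, analogous to Spark's Vector::copy(). */
    Vector clone();
}

// In DenseVector:
@Override
public DenseVector clone() {
    return new DenseVector(values.clone());
}

// In SparseVector (assuming a (size, indices, values) constructor):
@Override
public SparseVector clone() {
    return new SparseVector(size(), indices.clone(), values.clone());
}
```
With that in place, `PredictOutputFunction.map(...)` could simply call `feature.clone()` and scale in place, regardless of whether the feature is dense or sparse.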
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerParams.java
##########
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
+
+/**
+ * Params for {@link StandardScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StandardScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+ Param<Boolean> WITH_MEAN =
+ new BooleanParam(
+ "withMean", "Whether centers the data with mean before
scaling.", false);
+
+ default Boolean getWithMean() {
+ return get(WITH_MEAN);
+ }
+
+ default T setWithMean(boolean withMean) {
+ return set(WITH_MEAN, withMean);
+ }
+
+ Param<Boolean> WITH_STD =
Review comment:
Could we move the variable declaration to be before the method declarations, similar to other Java classes in Flink ML?
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
##########
@@ -172,11 +173,9 @@ public Row map(Row row) {
}
}
DenseVector feature = (DenseVector) row.getField(featureCol);
- DenseVector outputVector = new DenseVector(scaleVector.size());
- for (int i = 0; i < scaleVector.size(); ++i) {
- outputVector.values[i] =
- feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
- }
+ DenseVector outputVector = feature.clone();
+ BLAS.hDot(scaleVector, outputVector);
+ BLAS.axpy(1, offsetVector, outputVector);
Review comment:
The new implementation seems to be strictly slower than the previous implementation. Should we keep the previous implementation?
The previous implementation just needs one for loop over the `outputVector`.
The new implementation needs two for loops (i.e. `clone()` and `hDot()`) over the `outputVector`, plus one `JAVA_BLAS.daxpy()` call.
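For reference, the single-pass version the comment prefers is the one in the removed lines above, i.e.:
```java
DenseVector outputVector = new DenseVector(scaleVector.size());
for (int i = 0; i < scaleVector.size(); ++i) {
    // One pass: scale and shift each component directly into the output.
    outputVector.values[i] = feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
}
```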
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
##########
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the standard scaling algorithm.
+ *
+ * <p>Standardization is a common requirement for machine learning training because many estimators may behave
+ * badly if the individual features of an input do not look like standard normally distributed data
+ * (e.g. Gaussian with 0 mean and unit variance).
+ *
+ * <p>This estimator standardizes the input features by removing the mean and scaling each dimension
+ * to unit variance.
+ */
+public class StandardScaler
+ implements Estimator<StandardScaler, StandardScalerModel>,
+ StandardScalerParams<StandardScaler> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public StandardScaler() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public StandardScalerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Tuple3<DenseVector, DenseVector, Long>> sumAndSquaredSumAndWeight =
+ tEnv.toDataStream(inputs[0])
+ .transform(
+ "computeMeta",
+ new TupleTypeInfo<>(
+ TypeInformation.of(DenseVector.class),
+ TypeInformation.of(DenseVector.class),
+ BasicTypeInfo.LONG_TYPE_INFO),
+ new ComputeMetaOperator(getFeaturesCol()));
+
+ DataStream<StandardScalerModelData> modelData =
+ sumAndSquaredSumAndWeight
+ .transform(
+ "buildModel",
+ TypeInformation.of(StandardScalerModelData.class),
+ new BuildModelOperator())
+ .setParallelism(1);
+
+ StandardScalerModel model =
+ new StandardScalerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, paramMap);
+ return model;
+ }
+
+ /**
+ * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+ */
+ private static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
+ implements OneInputStreamOperator<
+ Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
+ BoundedOneInput {
+ private ListState<DenseVector> sumState;
+ private ListState<DenseVector> squaredSumState;
+ private ListState<Long> numElementsState;
+ private DenseVector sum;
+ private DenseVector squaredSum;
+ private long numElements;
+
+ @Override
+ public void endInput() {
+ if (numElements > 0) {
+ BLAS.scal(1.0 / numElements, sum);
+ double[] mean = sum.values;
+ double[] std = squaredSum.values;
+ if (numElements > 1) {
+ for (int i = 0; i < mean.length; i++) {
+ std[i] =
+ Math.sqrt(
+ (squaredSum.values[i] - numElements * mean[i] * mean[i])
+ / (numElements - 1));
+ }
+ } else {
+ Arrays.fill(std, 0.0);
+ }
+
+ output.collect(
+ new StreamRecord<>(
+ new StandardScalerModelData(
+ Vectors.dense(mean), Vectors.dense(std))));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
+ Tuple3<DenseVector, DenseVector, Long> value = element.getValue();
+ if (sum == null) {
+ sum = value.f0;
+ squaredSum = value.f1;
+ numElements = value.f2;
+ } else {
+ BLAS.axpy(1, value.f0, sum);
+ BLAS.axpy(1, value.f1, squaredSum);
+ numElements += value.f2;
+ }
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ sumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "sumState",
TypeInformation.of(DenseVector.class)));
+ squaredSumState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "squaredSumState",
+
TypeInformation.of(DenseVector.class)));
+ numElementsState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "numElementsState",
BasicTypeInfo.LONG_TYPE_INFO));
+
+ sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+ squaredSum =
+ OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+ .orElse(null);
+ numElements =
+ OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+ .orElse(0L);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ if (numElements > 0) {
+ sumState.update(Collections.singletonList(sum));
+ squaredSumState.update(Collections.singletonList(squaredSum));
+ numElementsState.update(Collections.singletonList(numElements));
+ }
+ }
+ }
+
+ /** Computes sum, squared sum and number of elements in each partition. */
+ private static class ComputeMetaOperator
+ extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
+ implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
+ BoundedOneInput {
+ private ListState<DenseVector> sumState;
+ private ListState<DenseVector> squaredSumState;
+ private ListState<Long> numElementsState;
+ private DenseVector sum;
+ private DenseVector squaredSum;
+ private long numElements;
+
+ private final String featuresCol;
+
+ public ComputeMetaOperator(String featuresCol) {
+ this.featuresCol = featuresCol;
+ }
+
+ @Override
+ public void endInput() {
+ if (numElements > 0) {
Review comment:
Would it be simpler to remove this check?
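For reference, removing the check would make `endInput()` emit unconditionally; a sketch with the empty-partition caveat spelled out:
```java
@Override
public void endInput() {
    // Without the guard, a subtask that received no input would emit
    // Tuple3.of(null, null, 0L), so the downstream BuildModelOperator would
    // have to tolerate null sums. Whether that is acceptable ties into the
    // earlier comment about handling an empty input table.
    output.collect(new StreamRecord<>(Tuple3.of(sum, squaredSum, numElements)));
}
```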
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]