yunfengzhou-hub commented on code in PR #156:
URL: https://github.com/apache/flink-ml/pull/156#discussion_r1007512230
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -47,10 +47,16 @@
/**
* A Transformer which combines a given list of input columns into a vector column. Types of input
- * columns must be either vector or numerical value.
+ * columns must be either vector or numerical types. The operator deals with null values or records
+ * with wrong sizes according to the strategy specified by the {@link HasHandleInvalid} parameter as
+ * follows:
*
- * <p>The `keep` option of {@link HasHandleInvalid} means that we output bad rows with output column
- * set to null.
+ * <p>The `keep` option means that if the input column data is NaN, then it keeps this value and if
+ * data is null vector, then uses a NaN vector to replace it.
+ *
+ * <p>The `skip` option means that it filters out rows with invalid elements.
+ *
+ * <p>The `error` option means that it throws an error exception when meeting some invalid data.
Review Comment:
The following tags could better format the JavaDoc.
```html
<ul>
<li> AAA
<li> BBB
<li> CCC
</ul>
```
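Applied to the three options quoted above, the list layout might look like the sketch below. The class name `JavadocListExample` is a hypothetical placeholder used only to carry the comment, and the item wording is copied from the diff rather than proposed as final text.

```java
/**
 * Hypothetical placeholder class; it exists only to show the list layout.
 *
 * <p>Invalid input is handled according to the {@code handleInvalid} strategy as follows:
 *
 * <ul>
 *   <li>The `keep` option means that if the input column data is NaN, then it keeps this value
 *       and if data is null vector, then uses a NaN vector to replace it.
 *   <li>The `skip` option means that it filters out rows with invalid elements.
 *   <li>The `error` option means that it throws an error exception when meeting some invalid
 *       data.
 * </ul>
 */
final class JavadocListExample {}
```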
##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java:
##########
@@ -45,13 +45,16 @@
Review Comment:
Let's also add test cases that verify each handleInvalid strategy when the
input vectors are valid (no null or NaN values) but their sizes do not match
the sizes set in `inputSizes`.
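To make the requested cases concrete, here is a standalone sketch that does not depend on Flink ML. The class `SizeMismatchSketch` and its `assemble` method are hypothetical stand-ins for the operator, and plain `double[]` arrays stand in for vectors. The outcomes it encodes are assumptions for illustration only: a mismatched row is dropped under `skip` and rejected under `error` and `keep`. The PR itself does not state what `keep` should do on a size mismatch, so that branch in particular needs to be confirmed before turning this into real tests.

```java
import java.util.ArrayList;
import java.util.List;

/** Standalone model of the size check; not the Flink ML implementation. */
final class SizeMismatchSketch {
    static final String SKIP = "skip";

    /** Returns the rows whose length equals expectedSize, applying handleInvalid to the rest. */
    static List<double[]> assemble(List<double[]> rows, int expectedSize, String handleInvalid) {
        List<double[]> out = new ArrayList<>();
        for (double[] row : rows) {
            if (row.length == expectedSize) {
                out.add(row);
            } else if (SKIP.equals(handleInvalid)) {
                // Assumption: `skip` silently drops the mismatched row.
                continue;
            } else {
                // Assumption: `error` and `keep` both reject a row of the wrong size.
                throw new IllegalArgumentException(
                        "Expected size " + expectedSize + ", actual size " + row.length + ".");
            }
        }
        return out;
    }
}
```

A real test would feed one well-sized and one wrongly-sized vector through the operator once per strategy and check the emitted rows or the thrown exception.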
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,64 +81,96 @@ public Table[] transform(Table... inputs) {
DataStream<Row> output =
tEnv.toDataStream(inputs[0])
.flatMap(
- new AssemblerFunc(getInputCols(), getHandleInvalid()),
+ new AssemblerFunction(
+ getInputCols(), getHandleInvalid(), getInputSizes()),
outputTypeInfo);
Table outputTable = tEnv.fromDataStream(output);
return new Table[] {outputTable};
}
- private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+ private static class AssemblerFunction implements FlatMapFunction<Row, Row> {
private final String[] inputCols;
private final String handleInvalid;
+ private final Integer[] inputSizes;
+ private int vectorSize = 0;
+ private final boolean keepInvalid;
- public AssemblerFunc(String[] inputCols, String handleInvalid) {
+ public AssemblerFunction(String[] inputCols, String handleInvalid, Integer[] sizeArray) {
this.inputCols = inputCols;
this.handleInvalid = handleInvalid;
+ this.inputSizes = sizeArray;
+ for (Integer inputSize : inputSizes) {
+ vectorSize += inputSize;
+ }
+ keepInvalid = handleInvalid.equals(HasHandleInvalid.KEEP_INVALID);
}
@Override
public void flatMap(Row value, Collector<Row> out) {
int nnz = 0;
- int vectorSize = 0;
try {
- for (String inputCol : inputCols) {
- Object object = value.getField(inputCol);
- Preconditions.checkNotNull(object, "Input column value should not be null.");
+ for (int i = 0; i < inputCols.length; ++i) {
+ Object object = value.getField(inputCols[i]);
+ if (object == null) {
+ if (keepInvalid) {
+ if (inputSizes[i] > 1) {
+ DenseVector tmpVec = new DenseVector(inputSizes[i]);
+ for (int j = 0; j < inputSizes[i]; ++j) {
+ tmpVec.values[j] = Double.NaN;
+ }
+ object = tmpVec;
+ } else {
+ object = Double.NaN;
+ }
+ value.setField(inputCols[i], object);
+ } else {
+ throw new RuntimeException(
+ "Input column value is null. Please check the input data or using handleInvalid = 'keep'.");
+ }
+ }
if (object instanceof Number) {
+ if (Double.isNaN(((Number) object).doubleValue()) && !keepInvalid) {
+ throw new RuntimeException(
+ "Encountered NaN while assembling a row with handleInvalid = 'error'. Consider "
+ + "removing NaNs from dataset or using handleInvalid = 'keep' or 'skip'.");
+ }
+ checkVectorAndNumberSize(inputSizes[i], 1);
nnz += 1;
- vectorSize += 1;
} else if (object instanceof SparseVector) {
+ int localSize = ((SparseVector) object).size();
+ checkVectorAndNumberSize(inputSizes[i], localSize);
nnz += ((SparseVector) object).indices.length;
- vectorSize += ((SparseVector) object).size();
} else if (object instanceof DenseVector) {
+ int localSize = ((DenseVector) object).size();
+ checkVectorAndNumberSize(inputSizes[i], localSize);
nnz += ((DenseVector) object).size();
- vectorSize += ((DenseVector) object).size();
} else {
throw new IllegalArgumentException(
"Input type has not been supported yet.");
Review Comment:
Let's include the type of the unsupported input value in the error message.
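One possible shape for such a message is sketched below. The helper `UnsupportedTypeMessage.describe` is hypothetical and not part of the PR. The list of supported types is taken from the `instanceof` branches in the diff above.

```java
/** Hypothetical helper that builds a more descriptive error message. */
final class UnsupportedTypeMessage {
    static String describe(Object value) {
        // Report the runtime class so the user can see which column type was rejected.
        String typeName = value == null ? "null" : value.getClass().getName();
        return "Input type "
                + typeName
                + " has not been supported yet. "
                + "Supported types are Number, SparseVector and DenseVector.";
    }
}
```

In the operator this string would be passed to the existing `IllegalArgumentException`.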
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java:
##########
@@ -74,64 +80,78 @@ public Table[] transform(Table... inputs) {
DataStream<Row> output =
tEnv.toDataStream(inputs[0])
.flatMap(
- new AssemblerFunc(getInputCols(), getHandleInvalid()),
+ new AssemblerFunction(
+ getInputCols(), getHandleInvalid(), getInputSizes()),
outputTypeInfo);
Table outputTable = tEnv.fromDataStream(output);
return new Table[] {outputTable};
}
- private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+ private static class AssemblerFunction implements FlatMapFunction<Row, Row> {
private final String[] inputCols;
private final String handleInvalid;
+ private final Integer[] inputSizes;
+ private int vectorSize = 0;
+ private final boolean keepInvalid;
- public AssemblerFunc(String[] inputCols, String handleInvalid) {
+ public AssemblerFunction(String[] inputCols, String handleInvalid, Integer[] sizeArray) {
Review Comment:
Let's rename the parameter `sizeArray` to `inputSizes` to keep the naming
consistent across the file.
##########
docs/content/docs/operators/feature/vectorassembler.md:
##########
@@ -27,8 +27,13 @@ under the License.
## Vector Assembler
-Vector Assembler combines a given list of input columns into a vector column.
-Types of input columns must be either vector or numerical value.
+A Transformer which combines a given list of input columns into a vector column. Types of input
+columns must be either vector or numerical types. If the element is null or has the wrong size,
+we will process this case with {@link HasHandleInvalid} parameter as follows:
+
+ * The `keep` option means that if the input column data is NaN, then it keeps this value and if data is null vector, then uses a NaN vector to replace it.
Review Comment:
Let's wrap the lines so that each line has no more than 80 characters.
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java:
##########
@@ -21,11 +21,44 @@
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.common.param.HasInputCols;
import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
/**
* Params of {@link VectorAssembler}.
*
* @param <T> The class type of this instance.
*/
public interface VectorAssemblerParams<T>
- extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+ extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+ Param<Integer[]> INPUT_SIZES =
+ new IntArrayParam(
+ "inputSizes",
+ "Sizes of the input elements to be assembled.",
+ null,
+ sizesValidator());
+
+ default Integer[] getInputSizes() {
+ return get(INPUT_SIZES);
+ }
+
+ default T setInputSizes(Integer... value) {
+ return set(INPUT_SIZES, value);
+ }
+
+ // Checks the inputSizes parameter.
+ static ParamValidator<Integer[]> sizesValidator() {
+ return inputSizes -> {
+ if (inputSizes == null) {
+ return false;
Review Comment:
In Spark, VectorAssembler can infer the vector sizes in some cases, which
means VectorSizeHint is not a compulsory prerequisite. Let's also support those
situations in Flink ML.
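As a rough illustration of one such situation, the sketch below falls back to the observed value when no size has been declared. Everything in it is an assumption for discussion: `SizeInferenceSketch` and `resolveSize` are hypothetical names, `double[]` stands in for the vector classes, and returning -1 for a null value with no declared size is just one way to signal that the size cannot be inferred.

```java
/** Standalone sketch of falling back to the observed size; not the Flink ML or Spark code. */
final class SizeInferenceSketch {
    /**
     * Returns the declared size if present, otherwise the size inferred from the value, or -1
     * when neither is available (for example a null value with no declared size).
     */
    static int resolveSize(Integer declaredSize, Object value) {
        if (declaredSize != null) {
            return declaredSize;
        }
        if (value instanceof Number) {
            return 1; // a numeric column contributes a single element
        }
        if (value instanceof double[]) {
            return ((double[]) value).length; // stand-in for a vector's size()
        }
        return -1; // nothing to infer from
    }
}
```

With a fallback like this, the validator could accept a null `inputSizes` and only the null-value case would still require a declared size.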