This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 79d0c24 [FLINK-24955] Add Estimator and Transformer for One Hot
Encoder
79d0c24 is described below
commit 79d0c247e0cc900c9a16e045a99000423003780b
Author: Yunfeng Zhou <[email protected]>
AuthorDate: Thu Dec 2 15:32:17 2021 +0800
[FLINK-24955] Add Estimator and Transformer for One Hot Encoder
This closes #37.
---
.../org/apache/flink/ml/linalg/SparseVector.java | 168 ++++++++++++
.../java/org/apache/flink/ml/linalg/Vectors.java | 5 +
.../ml/linalg/typeinfo/DenseVectorSerializer.java | 8 +-
.../ml/linalg/typeinfo/DenseVectorTypeInfo.java | 2 +-
.../ml/linalg/typeinfo/SparseVectorSerializer.java | 151 +++++++++++
...ctorTypeInfo.java => SparseVectorTypeInfo.java} | 66 +++--
.../typeinfo/SparseVectorTypeInfoFactory.java | 40 +++
.../org/apache/flink/ml/param/ParamValidators.java | 5 +
.../java/org/apache/flink/ml/api/StageTest.java | 5 +
.../apache/flink/ml/linalg/SparseVectorTest.java | 132 +++++++++
.../flink/ml/common/param/HasHandleInvalid.java | 56 ++++
.../apache/flink/ml/common/param/HasInputCols.java | 23 +-
.../flink/ml/common/param/HasOutputCols.java | 23 +-
.../ml/feature/onehotencoder/OneHotEncoder.java | 148 +++++++++++
.../feature/onehotencoder/OneHotEncoderModel.java | 190 +++++++++++++
.../onehotencoder/OneHotEncoderModelData.java | 109 ++++++++
.../feature/onehotencoder/OneHotEncoderParams.java | 28 +-
.../apache/flink/ml/feature/OneHotEncoderTest.java | 294 +++++++++++++++++++++
18 files changed, 1393 insertions(+), 60 deletions(-)
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
new file mode 100644
index 0000000..4e683a4
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/** A sparse vector of double values. */
+@TypeInfo(SparseVectorTypeInfoFactory.class)
+public class SparseVector implements Vector {
+    public final int n;
+    public final int[] indices;
+    public final double[] values;
+
+    /**
+     * Creates a sparse vector of size {@code n} whose entry at {@code indices[i]} is {@code
+     * values[i]}.
+     *
+     * <p>The given arrays are stored without copying. If {@code indices} is not in ascending
+     * order, both arrays are reordered in place so that it is.
+     *
+     * @param n size of the vector
+     * @param indices positions of the stored entries
+     * @param values values of the stored entries, parallel to {@code indices}
+     * @throws IllegalArgumentException if the two arrays differ in length, if an index is outside
+     *     {@code [0, n)}, or if an index occurs more than once
+     */
+    public SparseVector(int n, int[] indices, double[] values) {
+        // Checked before sorting: the sort swaps both arrays in lockstep and would otherwise fail
+        // with an ArrayIndexOutOfBoundsException when values is shorter than indices.
+        Preconditions.checkArgument(
+                indices.length == values.length,
+                "Indices size and values size should be the same.");
+        this.n = n;
+        this.indices = indices;
+        this.values = values;
+        if (!isIndicesSorted()) {
+            sortIndices();
+        }
+        validateSortedData();
+    }
+
+    @Override
+    public int size() {
+        return n;
+    }
+
+    /** Returns the value stored at index {@code i}, or zero if no value is stored there. */
+    @Override
+    public double get(int i) {
+        // Binary search is valid because the constructor keeps the indices sorted.
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            return values[pos];
+        }
+        return 0.;
+    }
+
+    /** Returns a newly allocated dense array of length {@code n} holding all entries. */
+    @Override
+    public double[] toArray() {
+        double[] result = new double[n];
+        for (int i = 0; i < indices.length; i++) {
+            result[indices[i]] = values[i];
+        }
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        SparseVector that = (SparseVector) o;
+        return n == that.n
+                && Arrays.equals(indices, that.indices)
+                && Arrays.equals(values, that.values);
+    }
+
+    @Override
+    public int hashCode() {
+        int result = Objects.hash(n);
+        result = 31 * result + Arrays.hashCode(indices);
+        result = 31 * result + Arrays.hashCode(values);
+        return result;
+    }
+
+    /**
+     * Checks whether the input data is valid.
+     *
+     * <p>This function does the following checks:
+     *
+     * <ul>
+     *   <li>The indices array and values array are of the same size.
+     *   <li>Vector indices are in valid range.
+     *   <li>Vector indices are unique.
+     * </ul>
+     *
+     * <p>This function works as expected only when indices are sorted.
+     */
+    private void validateSortedData() {
+        Preconditions.checkArgument(
+                indices.length == values.length,
+                "Indices size and values size should be the same.");
+        if (this.indices.length > 0) {
+            // With sorted indices only the smallest and the largest need a range check.
+            Preconditions.checkArgument(
+                    this.indices[0] >= 0 && this.indices[this.indices.length - 1] < this.n,
+                    "Index out of bound.");
+        }
+        for (int i = 1; i < this.indices.length; i++) {
+            // Equal neighbours in a sorted array are duplicates.
+            Preconditions.checkArgument(
+                    this.indices[i] > this.indices[i - 1], "Indices duplicated.");
+        }
+    }
+
+    /** Returns whether the indices are in non-descending order. */
+    private boolean isIndicesSorted() {
+        for (int i = 1; i < this.indices.length; i++) {
+            if (this.indices[i] < this.indices[i - 1]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** Sorts the indices and values. */
+    private void sortIndices() {
+        sortImpl(this.indices, this.values, 0, this.indices.length - 1);
+    }
+
+    /**
+     * Sorts {@code indices[low..high]} in ascending order using quick sort, applying every swap to
+     * {@code values} as well so that the two arrays stay parallel.
+     */
+    private static void sortImpl(int[] indices, double[] values, int low, int high) {
+        // Unsigned shift keeps the midpoint correct even if low + high overflows.
+        int pivotPos = (low + high) >>> 1;
+        int pivot = indices[pivotPos];
+        // Park the pivot at the end; the partition loop below moves it to its final position.
+        swapIndexAndValue(indices, values, pivotPos, high);
+
+        int pos = low - 1;
+        for (int i = low; i <= high; i++) {
+            if (indices[i] <= pivot) {
+                pos++;
+                swapIndexAndValue(indices, values, pos, i);
+            }
+        }
+        // The pivot now sits at pos; recurse into partitions that hold more than one element.
+        if (high > pos + 1) {
+            sortImpl(indices, values, pos + 1, high);
+        }
+        if (pos - 1 > low) {
+            sortImpl(indices, values, low, pos - 1);
+        }
+    }
+
+    /** Swaps the entries at {@code index1} and {@code index2} in both arrays. */
+    private static void swapIndexAndValue(int[] indices, double[] values, int index1, int index2) {
+        int tempIndex = indices[index1];
+        indices[index1] = indices[index2];
+        indices[index2] = tempIndex;
+        double tempValue = values[index1];
+        values[index1] = values[index2];
+        values[index2] = tempValue;
+    }
+
+    @Override
+    public String toString() {
+        return "(" + n + ", " + Arrays.toString(indices) + ", " + Arrays.toString(values) + ")";
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
index a058755..424b27f 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
@@ -25,4 +25,9 @@ public class Vectors {
public static DenseVector dense(double... values) {
return new DenseVector(values);
}
+
+ /** Creates a sparse vector of the given size from its indices and values. */
+ public static SparseVector sparse(int size, int[] indices, double[]
values) {
+ return new SparseVector(size, indices, values);
+ }
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
index 3cbde53..153a20f 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
@@ -29,14 +29,14 @@ import org.apache.flink.ml.linalg.DenseVector;
import java.io.IOException;
import java.util.Arrays;
-/** Specialized serializer for {@code DenseVector}. */
+/** Specialized serializer for {@link DenseVector}. */
public final class DenseVectorSerializer extends
TypeSerializerSingleton<DenseVector> {
private static final long serialVersionUID = 1L;
private static final double[] EMPTY = new double[0];
- private static final DenseVectorSerializer INSTANCE = new
DenseVectorSerializer();
+ public static final DenseVectorSerializer INSTANCE = new
DenseVectorSerializer();
@Override
public boolean isImmutableType() {
@@ -84,9 +84,7 @@ public final class DenseVectorSerializer extends
TypeSerializerSingleton<DenseVe
public DenseVector deserialize(DataInputView source) throws IOException {
int len = source.readInt();
double[] values = new double[len];
- for (int i = 0; i < len; i++) {
- values[i] = source.readDouble();
- }
+ readDoubleArray(values, source, len);
return new DenseVector(values);
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
index 0239e17..765cacb 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
@@ -64,7 +64,7 @@ public class DenseVectorTypeInfo extends
TypeInformation<DenseVector> {
@Override
@SuppressWarnings("unchecked")
public TypeSerializer<DenseVector> createSerializer(ExecutionConfig
executionConfig) {
- return new DenseVectorSerializer();
+ return DenseVectorSerializer.INSTANCE;
}
//
--------------------------------------------------------------------------------------------
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java
new file mode 100644
index 0000000..2c922a9
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.SparseVector;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@link SparseVector}. */
+public final class SparseVectorSerializer extends TypeSerializerSingleton<SparseVector> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY_DOUBLE_ARRAY = new double[0];
+
+    private static final int[] EMPTY_INT_ARRAY = new int[0];
+
+    public static final SparseVectorSerializer INSTANCE = new SparseVectorSerializer();
+
+    /** Reports the serialized type as mutable. */
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    /** Creates an empty vector of size zero. */
+    @Override
+    public SparseVector createInstance() {
+        return new SparseVector(0, EMPTY_INT_ARRAY, EMPTY_DOUBLE_ARRAY);
+    }
+
+    /** Returns a copy of {@code from} that is backed by freshly allocated arrays. */
+    @Override
+    public SparseVector copy(SparseVector from) {
+        final int[] indicesCopy = Arrays.copyOf(from.indices, from.indices.length);
+        final double[] valuesCopy = Arrays.copyOf(from.values, from.values.length);
+        return new SparseVector(from.n, indicesCopy, valuesCopy);
+    }
+
+    /**
+     * Copies {@code from} into {@code reuse} when both have the same size and the same number of
+     * stored entries; otherwise returns a newly allocated copy.
+     */
+    @Override
+    public SparseVector copy(SparseVector from, SparseVector reuse) {
+        final boolean sameShape = from.values.length == reuse.values.length && from.n == reuse.n;
+        if (!sameShape) {
+            return copy(from);
+        }
+        System.arraycopy(from.values, 0, reuse.values, 0, from.values.length);
+        System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length);
+        return reuse;
+    }
+
+    /** Returns -1 because the serialized size depends on the number of stored entries. */
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    /**
+     * Writes the vector size, the number of stored entries, and then each (index, value) pair in
+     * order.
+     */
+    @Override
+    public void serialize(SparseVector vector, DataOutputView target) throws IOException {
+        if (vector == null) {
+            throw new IllegalArgumentException("The vector must not be null.");
+        }
+
+        final int[] indices = vector.indices;
+        final double[] values = vector.values;
+        target.writeInt(vector.n);
+        target.writeInt(values.length);
+        for (int pos = 0; pos < values.length; pos++) {
+            target.writeInt(indices[pos]);
+            target.writeDouble(values[pos]);
+        }
+    }
+
+    // Reads `count` interleaved (int, double) pairs from `source`, storing the ints in `indices`
+    // and the doubles in `values`.
+    private void readEntries(int[] indices, double[] values, DataInputView source, int count)
+            throws IOException {
+        int pos = 0;
+        while (pos < count) {
+            indices[pos] = source.readInt();
+            values[pos] = source.readDouble();
+            pos++;
+        }
+    }
+
+    /** Reads a vector in the layout produced by {@link #serialize}. */
+    @Override
+    public SparseVector deserialize(DataInputView source) throws IOException {
+        final int size = source.readInt();
+        final int count = source.readInt();
+        final int[] indices = new int[count];
+        final double[] values = new double[count];
+        readEntries(indices, values, source, count);
+        return new SparseVector(size, indices, values);
+    }
+
+    /**
+     * Reads a vector, filling the arrays of {@code reuse} in place when its size and number of
+     * stored entries match the incoming record, and allocating a new vector otherwise.
+     */
+    @Override
+    public SparseVector deserialize(SparseVector reuse, DataInputView source) throws IOException {
+        final int size = source.readInt();
+        final int count = source.readInt();
+        final boolean reusable = reuse.n == size && reuse.values.length == count;
+
+        final int[] indices = reusable ? reuse.indices : new int[count];
+        final double[] values = reusable ? reuse.values : new double[count];
+        readEntries(indices, values, source, count);
+        return reusable ? reuse : new SparseVector(size, indices, values);
+    }
+
+    /** Copies one serialized vector from {@code source} to {@code target} without decoding it. */
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        final int size = source.readInt();
+        final int count = source.readInt();
+
+        target.writeInt(size);
+        target.writeInt(count);
+
+        // Each stored entry is one int index followed by one double value: 4 + 8 = 12 bytes.
+        target.write(source, count * (Integer.BYTES + Double.BYTES));
+    }
+
+    @Override
+    public TypeSerializerSnapshot<SparseVector> snapshotConfiguration() {
+        return new SparseVectorSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format evolution. */
+    @SuppressWarnings("WeakerAccess")
+    public static final class SparseVectorSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<SparseVector> {
+
+        public SparseVectorSerializerSnapshot() {
+            super(() -> INSTANCE);
+        }
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
similarity index 50%
copy from
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
copy to
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
index 0239e17..06686f0 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java
@@ -7,13 +7,14 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package org.apache.flink.ml.linalg.typeinfo;
@@ -21,39 +22,35 @@ package org.apache.flink.ml.linalg.typeinfo;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
-/** A {@link TypeInformation} for the {@link DenseVector} type. */
-public class DenseVectorTypeInfo extends TypeInformation<DenseVector> {
- private static final long serialVersionUID = 1L;
-
- public static final DenseVectorTypeInfo INSTANCE = new
DenseVectorTypeInfo();
-
- public DenseVectorTypeInfo() {}
+/** A {@link TypeInformation} for the {@link SparseVector} type. */
+public class SparseVectorTypeInfo extends TypeInformation<SparseVector> {
+ public static final SparseVectorTypeInfo INSTANCE = new
SparseVectorTypeInfo();
@Override
- public int getArity() {
- return 1;
+ public boolean isBasicType() {
+ return false;
}
@Override
- public int getTotalFields() {
- return 1;
+ public boolean isTupleType() {
+ return false;
}
@Override
- public Class<DenseVector> getTypeClass() {
- return DenseVector.class;
+ public int getArity() {
+ return 3;
}
@Override
- public boolean isBasicType() {
- return false;
+ public int getTotalFields() {
+ return 3;
}
@Override
- public boolean isTupleType() {
- return false;
+ public Class<SparseVector> getTypeClass() {
+ return SparseVector.class;
}
@Override
@@ -62,30 +59,27 @@ public class DenseVectorTypeInfo extends
TypeInformation<DenseVector> {
}
@Override
- @SuppressWarnings("unchecked")
- public TypeSerializer<DenseVector> createSerializer(ExecutionConfig
executionConfig) {
- return new DenseVectorSerializer();
+ public TypeSerializer<SparseVector> createSerializer(ExecutionConfig
executionConfig) {
+ return SparseVectorSerializer.INSTANCE;
}
- //
--------------------------------------------------------------------------------------------
-
@Override
- public int hashCode() {
- return getClass().hashCode();
+ public String toString() {
+ return "SparseVectorType";
}
@Override
public boolean equals(Object obj) {
- return obj instanceof DenseVectorTypeInfo;
+ return obj instanceof SparseVectorTypeInfo;
}
@Override
- public boolean canEqual(Object obj) {
- return obj instanceof DenseVectorTypeInfo;
+ public int hashCode() {
+ return getClass().hashCode();
}
@Override
- public String toString() {
- return "DenseVectorType";
+ public boolean canEqual(Object obj) {
+ return obj instanceof SparseVectorTypeInfo;
}
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java
new file mode 100644
index 0000000..01c1036
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.SparseVector;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/**
+ * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link
+ * SparseVector}.
+ */
+public class SparseVectorTypeInfoFactory extends TypeInfoFactory<SparseVector> {
+    /**
+     * Returns the shared {@link SparseVectorTypeInfo} instance; both arguments are ignored.
+     *
+     * @param type the type for which type information is requested (unused)
+     * @param map type information of generic parameters (unused)
+     */
+    @Override
+    public TypeInformation<SparseVector> createTypeInfo(
+            Type type, Map<String, TypeInformation<?>> map) {
+        // SparseVectorTypeInfo holds no state and all of its instances compare equal, so the
+        // shared instance is used instead of allocating a new one per call.
+        return SparseVectorTypeInfo.INSTANCE;
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
index 925ccb2..e7d1436 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java
@@ -95,4 +95,9 @@ public class ParamValidators {
}
};
}
+
+ // Checks that the parameter value is a non-null array with at least one element.
+ public static <T> ParamValidator<T[]> nonEmptyArray() {
+ return value -> value != null && value.length > 0;
+ }
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index 6ac630d..df0db64 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -388,5 +388,10 @@ public class StageTest {
ParamValidator<Integer> notNull = ParamValidators.notNull();
Assert.assertTrue(notNull.validate(5));
Assert.assertFalse(notNull.validate(null));
+
+ ParamValidator<Object[]> nonEmptyArray =
ParamValidators.nonEmptyArray();
+ Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"}));
+ Assert.assertFalse(nonEmptyArray.validate(null));
+ Assert.assertFalse(nonEmptyArray.validate(new String[0]));
}
}
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
new file mode 100644
index 0000000..0e7c349
--- /dev/null
+++
b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorSerializer;
+
+import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests the behavior of {@link SparseVector}. */
+public class SparseVectorTest {
+    /** Sorted, valid input is stored as given. */
+    @Test
+    public void testConstructor() {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 3};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+
+        SparseVector vector = Vectors.sparse(n, indices, values);
+        assertEquals(n, vector.n);
+        assertArrayEquals(indices, vector.indices);
+        assertArrayEquals(values, vector.values, 1e-5);
+        assertEquals("(4, [0, 2, 3], [0.1, 0.3, 0.4])", vector.toString());
+    }
+
+    /** A repeated index is rejected with an IllegalArgumentException. */
+    @Test
+    public void testDuplicateIndex() {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 2};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+
+        try {
+            Vectors.sparse(n, indices, values);
+            Assert.fail("Expected IllegalArgumentException.");
+        } catch (Exception e) {
+            assertEquals(IllegalArgumentException.class, e.getClass());
+            assertEquals("Indices duplicated.", e.getMessage());
+        }
+    }
+
+    /** A vector without stored entries converts to an all-zero dense array. */
+    @Test
+    public void testAllZeroVector() {
+        int n = 4;
+        SparseVector vector = Vectors.sparse(n, new int[0], new double[0]);
+        // Expected value goes first so that a failure message reports the right side as actual.
+        assertArrayEquals(new double[n], vector.toArray(), 1e-5);
+    }
+
+    /** Indices given in any order end up sorted, with values reordered to match. */
+    @Test
+    public void testUnsortedIndex() {
+        SparseVector vector;
+
+        vector = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {2}, vector.indices);
+        assertArrayEquals(new double[] {0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {1, 2}, new double[] {0.2, 0.3});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {1, 2}, vector.indices);
+        assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {2, 1}, new double[] {0.3, 0.2});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {1, 2}, vector.indices);
+        assertArrayEquals(new double[] {0.2, 0.3}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {3, 2, 0}, new double[] {0.4, 0.3, 0.1});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {0, 2, 3}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5);
+
+        vector = Vectors.sparse(4, new int[] {2, 0, 3}, new double[] {0.3, 0.1, 0.4});
+        assertEquals(4, vector.n);
+        assertArrayEquals(new int[] {0, 2, 3}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.3, 0.4}, vector.values, 1e-5);
+
+        vector =
+                Vectors.sparse(
+                        7,
+                        new int[] {6, 5, 4, 3, 2, 1, 0},
+                        new double[] {0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1});
+        assertEquals(7, vector.n);
+        assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, vector.indices);
+        assertArrayEquals(new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}, vector.values, 1e-5);
+    }
+
+    /** A vector survives a serialize/deserialize round trip unchanged. */
+    @Test
+    public void testSerializer() throws IOException {
+        int n = 4;
+        int[] indices = new int[] {0, 2, 3};
+        double[] values = new double[] {0.1, 0.3, 0.4};
+        SparseVector vector = Vectors.sparse(n, indices, values);
+        SparseVectorSerializer serializer = SparseVectorSerializer.INSTANCE;
+
+        ByteArrayOutputStream bOutput = new ByteArrayOutputStream(1024);
+        DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(bOutput);
+        serializer.serialize(vector, output);
+
+        byte[] b = bOutput.toByteArray();
+        ByteArrayInputStream bInput = new ByteArrayInputStream(b);
+        DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(bInput);
+        SparseVector vector2 = serializer.deserialize(input);
+
+        assertEquals(vector.n, vector2.n);
+        assertArrayEquals(vector.indices, vector2.indices);
+        assertArrayEquals(vector.values, vector2.values, 1e-5);
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java
new file mode 100644
index 0000000..a7ea41a
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasHandleInvalid.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared handleInvalid param.
+ *
+ * <p>Supported options and the corresponding behaviors to handle invalid entries are listed as
+ * follows.
+ *
+ * <ul>
+ *   <li>error: raise an exception.
+ *   <li>skip: filter out rows with bad values.
+ * </ul>
+ */
+public interface HasHandleInvalid<T> extends WithParams<T> {
+    String ERROR_INVALID = "error";
+    String SKIP_INVALID = "skip";
+
+    // Defaults to ERROR_INVALID; only the two options above pass validation.
+    Param<String> HANDLE_INVALID =
+            new StringParam(
+                    "handleInvalid",
+                    "Strategy to handle invalid entries.",
+                    ERROR_INVALID,
+                    ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID));
+
+    /** Returns the configured strategy for handling invalid entries. */
+    default String getHandleInvalid() {
+        return get(HANDLE_INVALID);
+    }
+
+    /** Sets the strategy for handling invalid entries and returns the result of {@code set}. */
+    default T setHandleInvalid(String value) {
+        // Same form as the other shared-param setters; avoids an unchecked cast of `this`.
+        return set(HANDLE_INVALID, value);
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
similarity index 55%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
index a058755..c567de7 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasInputCols.java
@@ -16,13 +16,24 @@
* limitations under the License.
*/
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.common.param;
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
- /** Creates a dense vector from its values. */
- public static DenseVector dense(double... values) {
- return new DenseVector(values);
+/** Interface for the shared inputCols param. */
+public interface HasInputCols<T> extends WithParams<T> {
+ Param<String[]> INPUT_COLS =
+ new StringArrayParam(
+ "inputCols", "Input column names.", null,
ParamValidators.nonEmptyArray());
+
+ default String[] getInputCols() {
+ return get(INPUT_COLS);
+ }
+
+ default T setInputCols(String... value) {
+ return set(INPUT_COLS, value);
}
}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
similarity index 54%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
index a058755..947501f 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCols.java
@@ -16,13 +16,24 @@
* limitations under the License.
*/
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.common.param;
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
- /** Creates a dense vector from its values. */
- public static DenseVector dense(double... values) {
- return new DenseVector(values);
+/** Interface for the shared outputCols param. */
+public interface HasOutputCols<T> extends WithParams<T> {
+ // String-array param registered under the name "outputCols". The third constructor argument
+ // is null (presumably the default value -- confirm against StringArrayParam) and the last
+ // one is the validator returned by ParamValidators.nonEmptyArray().
+ Param<String[]> OUTPUT_COLS =
+ new StringArrayParam(
+ "outputCols", "Output column names.", null,
ParamValidators.nonEmptyArray());
+
+ /** Returns the value currently stored for {@link #OUTPUT_COLS}. */
+ default String[] getOutputCols() {
+ return get(OUTPUT_COLS);
+ }
+
+ /** Stores the given column names under {@link #OUTPUT_COLS} and returns what {@code set} returns. */
+ default T setOutputCols(String... value) {
+ return set(OUTPUT_COLS, value);
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
new file mode 100644
index 0000000..374d457
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the one-hot encoding algorithm.
+ *
+ * <p>Data of selected input columns should be indexed numbers in order for OneHotEncoder to
+ * function correctly.
+ *
+ * <p>See https://en.wikipedia.org/wiki/One-hot.
+ */
+public class OneHotEncoder
+ implements Estimator<OneHotEncoder, OneHotEncoderModel>,
+ OneHotEncoderParams<OneHotEncoder> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Creates an estimator and fills its param map via ParamUtils.initializeMapWithDefaultValues. */
+    public OneHotEncoder() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Builds a {@link OneHotEncoderModel} from a single training table.
+     *
+     * <p>Each row is expanded into (input column position, value) pairs by {@link
+     * ExtractInputColsValueFunction}; the pairs are keyed by column position and reduced by
+     * {@link FindMaxIndexFunction} to the largest value per column, which is the model data.
+     *
+     * @param inputs Exactly one table containing the configured input columns.
+     * @return A model holding the computed model data, updated with this estimator's param map.
+     * @throws IllegalArgumentException if the number of tables is not one or handleInvalid is
+     *     not {@link HasHandleInvalid#ERROR_INVALID}.
+     */
+    @Override
+    public OneHotEncoderModel fit(Table... inputs) {
+        // Name the violated precondition so a failure is diagnosable from the exception alone.
+        Preconditions.checkArgument(
+                inputs.length == 1, "Expected exactly one input table, but got %s.", inputs.length);
+        Preconditions.checkArgument(
+                getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID),
+                "Unsupported handleInvalid option %s; only %s is supported.",
+                getHandleInvalid(),
+                HasHandleInvalid.ERROR_INVALID);
+
+        final String[] inputCols = getInputCols();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Integer, Integer>> modelData =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractInputColsValueFunction(inputCols))
+                        .keyBy(columnIdAndValue -> columnIdAndValue.f0)
+                        .transform(
+                                "findMaxIndex",
+                                Types.TUPLE(Types.INT, Types.INT),
+                                new MapPartitionFunctionWrapper<>(new FindMaxIndexFunction()));
+
+        OneHotEncoderModel model =
+                new OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Writes this estimator's metadata to the given path via ReadWriteUtils.saveMetadata.
+     *
+     * @param path Target location handed to saveMetadata.
+     * @throws IOException declared for failures while saving.
+     */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Restores a OneHotEncoder from the given path via ReadWriteUtils.loadStageParam.
+     *
+     * @param env Execution environment; not read by this method.
+     * @param path Location handed to loadStageParam.
+     * @return The restored estimator.
+     * @throws IOException declared for failures while loading.
+     */
+    public static OneHotEncoder load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        final OneHotEncoder restored = ReadWriteUtils.loadStageParam(path);
+        return restored;
+    }
+
+    /** Returns the map backing this estimator's parameters (the live map, not a copy). */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /**
+     * Extracts the values of the selected input columns from each input row.
+     *
+     * <p>Input: rows of input data containing the designated input columns.
+     *
+     * <p>Output: pairs of (position of the column within inputCols, integer value of that column).
+     */
+    private static class ExtractInputColsValueFunction
+            implements FlatMapFunction<Row, Tuple2<Integer, Integer>> {
+        private final String[] inputCols;
+
+        private ExtractInputColsValueFunction(String[] inputCols) {
+            this.inputCols = inputCols;
+        }
+
+        /**
+         * Emits one (column position, value) pair per configured input column of the row.
+         *
+         * @throws IllegalArgumentException if a value has a fractional part or is negative.
+         */
+        @Override
+        public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) {
+            for (int i = 0; i < inputCols.length; i++) {
+                Number number = (Number) row.getField(inputCols[i]);
+                int index = number.intValue();
+                // Template and argument are passed separately so the message is only built when
+                // the check fails, not once for every value of every row.
+                Preconditions.checkArgument(
+                        index == number.doubleValue(),
+                        "Value %s cannot be parsed as indexed integer.",
+                        number);
+                Preconditions.checkArgument(index >= 0, "Negative value not supported.");
+                collector.collect(new Tuple2<>(i, index));
+            }
+        }
+    }
+
+    /** Reduces the received (column position, value) pairs to the largest value per column. */
+    private static class FindMaxIndexFunction
+            implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple2<Integer, Integer>> iterable,
+                Collector<Tuple2<Integer, Integer>> collector) {
+            Map<Integer, Integer> maxByColumn = new HashMap<>();
+            for (Tuple2<Integer, Integer> columnAndValue : iterable) {
+                // Keeps the larger of the stored and the incoming value for this column.
+                maxByColumn.merge(columnAndValue.f0, columnAndValue.f1, Integer::max);
+            }
+            maxByColumn.forEach((column, max) -> collector.collect(new Tuple2<>(column, max)));
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
new file mode 100644
index 0000000..447fe77
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+import java.util.function.Function;
+
+/**
+ * A Model which encodes data into one-hot format using the model data computed by {@link
+ * OneHotEncoder}.
+ */
+public class OneHotEncoderModel
+ implements Model<OneHotEncoderModel>,
OneHotEncoderParams<OneHotEncoderModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+    /** Creates a model and fills its param map via ParamUtils.initializeMapWithDefaultValues. */
+    public OneHotEncoderModel() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Appends one encoded output column per configured input column to the given table.
+     *
+     * <p>The model data is broadcast under a fixed key and consumed by {@link
+     * GenerateOutputsFunction}, which produces the additional fields of each row.
+     *
+     * @param inputs Exactly one table containing the configured input columns.
+     * @return A single-element array with the input table extended by the output columns.
+     * @throws IllegalArgumentException if handleInvalid is not {@link
+     *     HasHandleInvalid#ERROR_INVALID}, if the number of tables is not one, or if inputCols
+     *     and outputCols differ in length.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        final String[] inputCols = getInputCols();
+        final String[] outputCols = getOutputCols();
+        final boolean dropLast = getDropLast();
+        final String broadcastModelKey = "OneHotModelStream";
+
+        // Name the violated precondition so a failure is diagnosable from the exception alone.
+        Preconditions.checkArgument(
+                getHandleInvalid().equals(HasHandleInvalid.ERROR_INVALID),
+                "Unsupported handleInvalid option %s; only %s is supported.",
+                getHandleInvalid(),
+                HasHandleInvalid.ERROR_INVALID);
+        Preconditions.checkArgument(
+                inputs.length == 1, "Expected exactly one input table, but got %s.", inputs.length);
+        Preconditions.checkArgument(
+                inputCols.length == outputCols.length,
+                "The number of input columns (%s) must match the number of output columns (%s).",
+                inputCols.length,
+                outputCols.length);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        // Output schema = all input fields followed by one field per output column.
+        // NOTE(review): with this file's imports, Vector resolves to java.util.Vector, while the
+        // values written by GenerateOutputsFunction come from Vectors.sparse -- confirm that this
+        // is the intended type information.
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Collections.nCopies(
+                                                outputCols.length,
+                                                ExternalTypeInfo.of(Vector.class))
+                                        .toArray(new TypeInformation[0])),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        // The table environment is taken from the model data table, not from the input table.
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Tuple2<Integer, Integer>> modelStream =
+                OneHotEncoderModelData.getModelDataStream(modelDataTable);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(
+                            new GenerateOutputsFunction(inputCols, dropLast, broadcastModelKey),
+                            outputTypeInfo);
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastModelKey, modelStream),
+                        function);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        return new Table[] {outputTable};
+    }
+
+    /**
+     * Writes this model's metadata and its model data to the given path.
+     *
+     * @param path Target location handed to ReadWriteUtils.
+     * @throws IOException declared for failures while saving.
+     */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+
+        DataStream<Tuple2<Integer, Integer>> modelData =
+                OneHotEncoderModelData.getModelDataStream(modelDataTable);
+        ReadWriteUtils.saveModelData(modelData, path, new OneHotEncoderModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Restores a model (params and model data) previously written to the given path.
+     *
+     * @param env Execution environment used to create the table environment and read the data.
+     * @param path Location handed to ReadWriteUtils.
+     * @return The restored model with its model data table set.
+     * @throws IOException declared for failures while loading.
+     */
+    public static OneHotEncoderModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        OneHotEncoderModel model = ReadWriteUtils.loadStageParam(path);
+        OneHotEncoderModelData.ModelDataStreamFormat format =
+                new OneHotEncoderModelData.ModelDataStreamFormat();
+        DataStream<Tuple2<Integer, Integer>> modelData =
+                ReadWriteUtils.loadModelData(env, path, format);
+        Table modelDataAsTable = tEnv.fromDataStream(modelData);
+        return model.setModelData(modelDataAsTable);
+    }
+
+    /** Returns the map backing this model's parameters (the live map, not a copy). */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /**
+     * Sets the table holding this model's data.
+     *
+     * @param inputs Exactly one table.
+     * @return This model, to allow call chaining.
+     */
+    @Override
+    public OneHotEncoderModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        this.modelDataTable = inputs[0];
+        return this;
+    }
+
+    /** Returns a new single-element array wrapping the model data table. */
+    @Override
+    public Table[] getModelData() {
+        final Table[] wrapped = {modelDataTable};
+        return wrapped;
+    }
+
+    /**
+     * Appends the encoding of every selected input column to each row.
+     *
+     * <p>The model data is read from the broadcast variable registered under {@code
+     * broadcastModelKey} the first time {@link #map(Row)} runs; the per-column vector sizes
+     * derived from it are then reused for all following rows.
+     */
+    private static class GenerateOutputsFunction extends RichMapFunction<Row, Row> {
+        private final String[] inputCols;
+        private final boolean dropLast;
+        private final String broadcastModelKey;
+        private List<Tuple2<Integer, Integer>> model = null;
+        // Vector size per input column; built lazily from the model data and then reused.
+        private transient int[] categorySizes;
+
+        public GenerateOutputsFunction(
+                String[] inputCols, boolean dropLast, String broadcastModelKey) {
+            this.inputCols = inputCols;
+            this.dropLast = dropLast;
+            this.broadcastModelKey = broadcastModelKey;
+        }
+
+        /**
+         * Returns the input row joined with one sparse vector per input column.
+         *
+         * @throws IllegalArgumentException if a value has a fractional part or lies outside
+         *     [0, vector size] of its column.
+         */
+        @Override
+        public Row map(Row row) {
+            if (categorySizes == null) {
+                // Previously rebuilt for every row; the model data does not change between rows.
+                categorySizes = computeCategorySizes();
+            }
+            Row result = new Row(categorySizes.length);
+            for (int i = 0; i < categorySizes.length; i++) {
+                Number number = (Number) row.getField(inputCols[i]);
+                int idx = number.intValue();
+                // Template and arguments are passed separately so messages are only built when
+                // a check fails.
+                Preconditions.checkArgument(
+                        idx == number.doubleValue(),
+                        "Value %s cannot be parsed as indexed integer.",
+                        number);
+                int size = categorySizes[i];
+                // Reject indices that cannot be encoded up front, with a message naming the
+                // column, instead of handing them to Vectors.sparse.
+                Preconditions.checkArgument(
+                        idx >= 0 && idx <= size,
+                        "Value %s of column %s is outside the supported index range [0, %s].",
+                        number,
+                        inputCols[i],
+                        size);
+                if (idx == size) {
+                    // An index equal to the vector size is encoded as the all-zero vector.
+                    // NOTE(review): this also applies when dropLast is false, where size is the
+                    // stored max value plus one -- confirm that this is intended.
+                    result.setField(i, Vectors.sparse(size, new int[0], new double[0]));
+                } else {
+                    result.setField(i, Vectors.sparse(size, new int[] {idx}, new double[] {1.0}));
+                }
+            }
+
+            return Row.join(row, result);
+        }
+
+        /** Derives the vector size of every column from the broadcast (column, max) pairs. */
+        private int[] computeCategorySizes() {
+            if (model == null) {
+                model = getRuntimeContext().getBroadcastVariable(broadcastModelKey);
+            }
+            int[] sizes = new int[model.size()];
+            // With dropLast the size equals the stored max value, otherwise max value + 1.
+            int offset = dropLast ? 0 : 1;
+            for (Tuple2<Integer, Integer> tup : model) {
+                sizes[tup.f0] = tup.f1 + offset;
+            }
+            return sizes;
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java
new file mode 100644
index 0000000..f267784
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.onehotencoder;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link OneHotEncoderModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and
+ * classes to save/load model data.
+ */
+public class OneHotEncoderModelData {
+    /**
+     * Converts the model data table into a stream of integer pairs read from its fields "f0"
+     * and "f1".
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<Tuple2<Integer, Integer>> getModelDataStream(Table modelData) {
+        TableImpl tableImpl = (TableImpl) modelData;
+        StreamTableEnvironment tEnv = (StreamTableEnvironment) tableImpl.getTableEnvironment();
+        DataStream<Row> rows = tEnv.toDataStream(modelData);
+        // Kept as an anonymous class (not a lambda) so the generic result type stays visible.
+        return rows.map(
+                new MapFunction<Row, Tuple2<Integer, Integer>>() {
+                    @Override
+                    public Tuple2<Integer, Integer> map(Row row) {
+                        int first = (int) row.getField("f0");
+                        int second = (int) row.getField("f1");
+                        return new Tuple2<>(first, second);
+                    }
+                });
+    }
+
+    /** Encoder that writes one model data pair as two consecutive ints (f0, then f1). */
+    public static class ModelDataEncoder implements Encoder<Tuple2<Integer, Integer>> {
+        @Override
+        public void encode(Tuple2<Integer, Integer> modelData, OutputStream outputStream) {
+            final Output kryoOutput = new Output(outputStream);
+            kryoOutput.writeInt(modelData.f0);
+            kryoOutput.writeInt(modelData.f1);
+            // Push the buffered bytes to the underlying stream before returning.
+            kryoOutput.flush();
+        }
+    }
+
+    /** Stream format that reads model data pairs back as two consecutive ints (f0, then f1). */
+    public static class ModelDataStreamFormat extends SimpleStreamFormat<Tuple2<Integer, Integer>> {
+        @Override
+        public Reader<Tuple2<Integer, Integer>> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<Tuple2<Integer, Integer>>() {
+                private final Input input = new Input(stream);
+
+                /** Returns the next pair, or null once the input reports end of data. */
+                @Override
+                public Tuple2<Integer, Integer> read() {
+                    if (!input.eof()) {
+                        int first = input.readInt();
+                        int second = input.readInt();
+                        return new Tuple2<>(first, second);
+                    }
+                    return null;
+                }
+
+                /** Closes the underlying file stream. */
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<Tuple2<Integer, Integer>> getProducedType() {
+            return Types.TUPLE(Types.INT, Types.INT);
+        }
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
similarity index 51%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
index a058755..9b57159 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderParams.java
@@ -16,13 +16,29 @@
* limitations under the License.
*/
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.feature.onehotencoder;
-/** Utility methods for instantiating Vector. */
-public class Vectors {
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
- /** Creates a dense vector from its values. */
- public static DenseVector dense(double... values) {
- return new DenseVector(values);
+/**
+ * Params shared by OneHotEncoder and OneHotEncoderModel: input columns, output columns,
+ * handleInvalid and dropLast.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OneHotEncoderParams<T>
+ extends HasInputCols<T>, HasOutputCols<T>, HasHandleInvalid<T> {
+ // Boolean param registered under the name "dropLast"; the third constructor argument is
+ // true (presumably the default value -- confirm against BooleanParam).
+ Param<Boolean> DROP_LAST =
+ new BooleanParam("dropLast", "Whether to drop the last category.",
true);
+
+ /** Returns the value currently stored for {@link #DROP_LAST}. */
+ default boolean getDropLast() {
+ return get(DROP_LAST);
+ }
+
+ /** Stores the given flag under {@link #DROP_LAST} and returns what {@code set} returns. */
+ default T setDropLast(boolean value) {
+ return set(DROP_LAST, value);
}
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
new file mode 100644
index 0000000..51f9735
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
+import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests OneHotEncoder and OneHotEncoderModel. */
+public class OneHotEncoderTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainTable;
+ private Table predictTable;
+ private Map<Double, Vector>[] expectedOutput;
+ private OneHotEncoder estimator;
+
+    /** Creates the environments, the shared train/predict tables, expectations and estimator. */
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<Row> trainRows = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainRows)).as("input");
+
+        List<Row> predictRows = Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0));
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictRows)).as("input");
+
+        // Expected encodings of the single input column; 2.0 maps to the empty vector.
+        HashMap<Double, Vector> encodings = new HashMap<>();
+        encodings.put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
+        encodings.put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
+        encodings.put(2.0, Vectors.sparse(2, new int[0], new double[0]));
+        expectedOutput = new HashMap[] {encodings};
+
+        estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output");
+    }
+
+    /**
+     * Executes a given table and collects its results. Results are returned as an array of maps.
+     * Each element of the array is a map corresponding to an input column, whose key is the
+     * original value in the input column and whose value is the one-hot encoding result of that
+     * value.
+     *
+     * @param table A table to be executed and to have its result collected
+     * @param inputCols Name of the input columns
+     * @param outputCols Name of the output columns containing one-hot encoding result
+     * @return An array of maps containing the collected results for each input column
+     */
+    private static Map<Double, Vector>[] executeAndCollect(
+            Table table, String[] inputCols, String[] outputCols) {
+        Map<Double, Vector>[] maps = new HashMap[inputCols.length];
+        for (int i = 0; i < inputCols.length; i++) {
+            maps[i] = new HashMap<>();
+        }
+        // The iterator is closed on every exit path; it was previously left open.
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                for (int i = 0; i < inputCols.length; i++) {
+                    maps[i].put(
+                            ((Number) row.getField(inputCols[i])).doubleValue(),
+                            (Vector) row.getField(outputCols[i]));
+                }
+            }
+        } catch (RuntimeException e) {
+            // Unchecked failures propagate unchanged, exactly as before.
+            throw e;
+        } catch (Exception e) {
+            // Only close() declares a checked exception; keep the cause for diagnosis.
+            throw new RuntimeException("Failed to close the result iterator.", e);
+        }
+        return maps;
+    }
+
+    /** Checks the dropLast default and the setters/getters of both estimator and model. */
+    @Test
+    public void testParam() {
+        OneHotEncoder encoder = new OneHotEncoder();
+        assertTrue(encoder.getDropLast());
+
+        encoder.setInputCols("test_input");
+        encoder.setOutputCols("test_output");
+        encoder.setDropLast(false);
+
+        assertArrayEquals(new String[] {"test_input"}, encoder.getInputCols());
+        assertArrayEquals(new String[] {"test_output"}, encoder.getOutputCols());
+        assertFalse(encoder.getDropLast());
+
+        OneHotEncoderModel encoderModel = new OneHotEncoderModel();
+        assertTrue(encoderModel.getDropLast());
+
+        encoderModel.setInputCols("test_input");
+        encoderModel.setOutputCols("test_output");
+        encoderModel.setDropLast(false);
+
+        assertArrayEquals(new String[] {"test_input"}, encoderModel.getInputCols());
+        assertArrayEquals(new String[] {"test_output"}, encoderModel.getOutputCols());
+        assertFalse(encoderModel.getDropLast());
+    }
+
+    /** Fits on the shared train table and compares the predictions with the expectations. */
+    @Test
+    public void testFitAndPredict() {
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table[] outputs = model.transform(predictTable);
+        String[] inputCols = model.getInputCols();
+        String[] outputCols = model.getOutputCols();
+        Map<Double, Vector>[] actualOutput = executeAndCollect(outputs[0], inputCols, outputCols);
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    /** With dropLast disabled, every trained value gets its own position in a size-3 vector. */
+    @Test
+    public void testDropLast() {
+        estimator.setDropLast(false);
+
+        HashMap<Double, Vector> encodings = new HashMap<>();
+        encodings.put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0}));
+        encodings.put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0}));
+        encodings.put(2.0, Vectors.sparse(3, new int[] {2}, new double[] {1.0}));
+        expectedOutput = new HashMap[] {encodings};
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table[] outputs = model.transform(predictTable);
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputs[0], model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    /** Runs fit and transform on integer-typed input instead of the double-typed default data. */
+    @Test
+    public void testInputDataType() {
+        List<Row> trainRows = Arrays.asList(Row.of(0), Row.of(1), Row.of(2), Row.of(0));
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainRows)).as("input");
+
+        List<Row> predictRows = Arrays.asList(Row.of(0), Row.of(1), Row.of(2));
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictRows)).as("input");
+
+        HashMap<Double, Vector> encodings = new HashMap<>();
+        encodings.put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
+        encodings.put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
+        encodings.put(2.0, Vectors.sparse(2, new int[0], new double[0]));
+        expectedOutput = new HashMap[] {encodings};
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table[] outputs = model.transform(predictTable);
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputs[0], model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    /** fit must reject a handleInvalid option other than the error option. */
+    @Test
+    public void testNotSupportedHandleInvalidOptions() {
+        estimator.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        try {
+            estimator.fit(trainTable);
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            Class<?> thrownType = e.getClass();
+            assertEquals(IllegalArgumentException.class, thrownType);
+        }
+    }
+
+    /** A fractional training value must surface as an IllegalArgumentException on execution. */
+    @Test
+    public void testNonIndexedTrainData() {
+        List<Row> trainRows = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainRows)).as("input");
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            // Walk down to the innermost cause of the failure.
+            Throwable root = e;
+            for (Throwable cause = e.getCause(); cause != null; cause = cause.getCause()) {
+                root = cause;
+            }
+            assertEquals(IllegalArgumentException.class, root.getClass());
+            assertEquals("Value 0.5 cannot be parsed as indexed integer.", root.getMessage());
+        }
+    }
+
+    /** A fractional prediction value must surface as an IllegalArgumentException on execution. */
+    @Test
+    public void testNonIndexedPredictData() {
+        List<Row> predictRows = Arrays.asList(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictRows)).as("input");
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            // Walk down to the innermost cause of the failure.
+            Throwable root = e;
+            for (Throwable cause = e.getCause(); cause != null; cause = cause.getCause()) {
+                root = cause;
+            }
+            assertEquals(IllegalArgumentException.class, root.getClass());
+            assertEquals("Value 0.5 cannot be parsed as indexed integer.", root.getMessage());
+        }
+    }
+
+    /** Round-trips both the estimator and the fitted model through save and reload. */
+    @Test
+    public void testSaveLoad() throws Exception {
+        String estimatorPath = tempFolder.newFolder().getAbsolutePath();
+        estimator = StageTestUtils.saveAndReload(env, estimator, estimatorPath);
+
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        String modelPath = tempFolder.newFolder().getAbsolutePath();
+        model = StageTestUtils.saveAndReload(env, model, modelPath);
+
+        Table[] outputs = model.transform(predictTable);
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputs[0], model.getInputCols(), model.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+
+    /** The first model data record for the training data must be the pair (0, 2). */
+    @Test
+    public void testGetModelData() throws Exception {
+        OneHotEncoderModel model = estimator.fit(trainTable);
+        Table modelDataTable = model.getModelData()[0];
+        Tuple2<Integer, Integer> actual =
+                OneHotEncoderModelData.getModelDataStream(modelDataTable)
+                        .executeAndCollect()
+                        .next();
+        Tuple2<Integer, Integer> expected = new Tuple2<>(0, 2);
+        assertEquals(expected, actual);
+    }
+
+    /** A second model given the first model's data and params must predict identically. */
+    @Test
+    public void testSetModelData() {
+        OneHotEncoderModel modelA = estimator.fit(trainTable);
+        Table modelData = modelA.getModelData()[0];
+
+        OneHotEncoderModel modelB = new OneHotEncoderModel();
+        modelB.setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+
+        Table[] outputs = modelB.transform(predictTable);
+        Map<Double, Vector>[] actualOutput =
+                executeAndCollect(outputs[0], modelB.getInputCols(), modelB.getOutputCols());
+        assertArrayEquals(expectedOutput, actualOutput);
+    }
+}