This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch yuhe_release
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit eb8b93b4cca1dd2338347a617d77d88fd9520164
Author: zhangzp <[email protected]>
AuthorDate: Wed Nov 30 17:55:09 2022 +0800

    [FLINK-30249] Fix TableUtils.getRowTypeInfo by using ExternalTypeInfo.of()
---
 .../flink/ml/common/datastream/TableUtils.java     |  3 +-
 .../flink/ml/common/datastream/TableUtilsTest.java | 75 ++++++++++++++++++++++
 .../flink/ml/feature/binarizer/Binarizer.java      |  8 ++-
 .../org/apache/flink/ml/clustering/KMeansTest.java |  5 +-
 .../java/org/apache/flink/ml/util/TestUtils.java   |  6 +-
 5 files changed, 86 insertions(+), 11 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index 9245c77e..7e69134c 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.catalog.Column;
 import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo;
 import org.apache.flink.types.Row;
 
 /** Utility class for operations related to Table API. */
@@ -37,7 +38,7 @@ public class TableUtils {
 
         for (int i = 0; i < schema.getColumnCount(); i++) {
             Column column = schema.getColumn(i).get();
-            types[i] = TypeInformation.of(column.getDataType().getConversionClass());
+            types[i] = ExternalTypeInfo.of(column.getDataType());
             names[i] = column.getName();
         }
         return new RowTypeInfo(types, names);
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java
new file mode 100644
index 00000000..099336eb
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.sql.Timestamp;
+
+/** Tests the {@link TableUtils}. */
+public class TableUtilsTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        env = StreamExecutionEnvironment.getExecutionEnvironment();
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    @Test
+    public void testGetRowTypeInfo() {
+        Table inputTable =
+                tEnv.fromDataStream(
+                        env.fromElements(new Timestamp(0)),
+                        Schema.newBuilder().column("f0", DataTypes.TIMESTAMP()).build());
+        DataType inputTimeStampType = inputTable.getResolvedSchema().getColumnDataTypes().get(0);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), "outputCol"));
+
+        DataStream<Row> mappedOutput =
+                tEnv.toDataStream(inputTable)
+                        .map(
+                                (MapFunction<Row, Row>) row -> Row.of(row.getField(0), 1),
+                                outputTypeInfo);
+
+        DataType outputTimeStampType =
+                tEnv.fromDataStream(mappedOutput).getResolvedSchema().getColumnDataTypes().get(0);
+        Assert.assertEquals(inputTimeStampType, outputTimeStampType);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
index bdf16377..aafbf6e7 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
@@ -26,6 +26,7 @@ import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
@@ -70,11 +71,12 @@ public class Binarizer implements Transformer<Binarizer>, BinarizerParams<Binari
 
         for (int i = 0; i < inputCols.length; ++i) {
             int idx = inputTypeInfo.getFieldIndex(inputCols[i]);
-            if (inputTypeInfo.getFieldTypes()[idx] instanceof SparseVectorTypeInfo) {
+            Class<?> typeClass = inputTypeInfo.getTypeAt(idx).getTypeClass();
+            if (typeClass.equals(SparseVector.class)) {
                 outputTypes[i] = SparseVectorTypeInfo.INSTANCE;
-            } else if (inputTypeInfo.getFieldTypes()[idx] instanceof DenseVectorTypeInfo) {
+            } else if (typeClass.equals(DenseVector.class)) {
                 outputTypes[i] = DenseVectorTypeInfo.INSTANCE;
-            } else if (inputTypeInfo.getFieldTypes()[idx] instanceof VectorTypeInfo) {
+            } else if (typeClass.equals(Vector.class)) {
                 outputTypes[i] = VectorTypeInfo.INSTANCE;
             } else {
                 outputTypes[i] = Types.DOUBLE;
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 2e73a9d2..292db2fe 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -33,8 +33,6 @@ import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.test.util.AbstractTestBase;
@@ -168,8 +166,7 @@ public class KMeansTest extends AbstractTestBase {
                 Arrays.asList(
                         Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1));
 
-        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
-        Table input = tEnv.fromDataStream(env.fromCollection(data), schema).as("features");
+        Table input = tEnv.fromDataStream(env.fromCollection(data)).as("features");
 
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
index 94be85f3..09135b36 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
@@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Stage;
 import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -75,9 +75,9 @@ public class TestUtils {
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(table.getResolvedSchema());
         TypeInformation<?>[] fieldTypes = inputTypeInfo.getFieldTypes();
         for (int i = 0; i < fieldTypes.length; i++) {
-            if (fieldTypes[i].equals(DenseVectorTypeInfo.INSTANCE)) {
+            if (fieldTypes[i].getTypeClass().equals(DenseVector.class)) {
                 fieldTypes[i] = SparseVectorTypeInfo.INSTANCE;
-            } else if (fieldTypes[i].equals(Types.DOUBLE)) {
+            } else if (fieldTypes[i].getTypeClass().equals(Double.class)) {
                 fieldTypes[i] = Types.INT;
             }
         }

Reply via email to