This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new a22443c [FLINK-12408][python] Allow to define the data types in Python
a22443c is described below
commit a22443c9f083ea98f065b9631fea261968871c8d
Author: Dian Fu <[email protected]>
AuthorDate: Sun May 12 22:45:58 2019 +0800
[FLINK-12408][python] Allow to define the data types in Python
This closes #8420
---
flink-python/pyflink/table/__init__.py | 4 +-
flink-python/pyflink/table/table_descriptor.py | 6 +-
flink-python/pyflink/table/table_environment.py | 5 +-
flink-python/pyflink/table/table_source.py | 5 +-
flink-python/pyflink/table/tests/test_aggregate.py | 4 +-
flink-python/pyflink/table/tests/test_calc.py | 11 +-
.../pyflink/table/tests/test_column_operation.py | 14 +-
.../pyflink/table/tests/test_descriptor.py | 112 +-
flink-python/pyflink/table/tests/test_distinct.py | 4 +-
flink-python/pyflink/table/tests/test_join.py | 36 +-
.../pyflink/table/tests/test_print_schema.py | 4 +-
.../pyflink/table/tests/test_set_operation.py | 14 +-
flink-python/pyflink/table/tests/test_sort.py | 2 +-
.../table/tests/test_table_environment_api.py | 34 +-
flink-python/pyflink/table/tests/test_types.py | 737 +++++++++
flink-python/pyflink/table/tests/test_window.py | 8 +-
flink-python/pyflink/table/types.py | 1692 +++++++++++++++++++-
flink-python/pyflink/util/type_utils.py | 55 -
18 files changed, 2506 insertions(+), 241 deletions(-)
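
For context, this change replaces the former class-level type constants (e.g. DataTypes.INT) with factory methods (e.g. DataTypes.INT()) and exports Row and UserDefinedType from pyflink.table. A minimal sketch of the reworked usage, based on the tests touched below and assuming a pyflink build that contains this commit:

    from pyflink.table import DataTypes, Row

    # Composite types are now built from factory methods rather than constants.
    schema = DataTypes.ROW([
        DataTypes.FIELD("id", DataTypes.BIGINT()),
        DataTypes.FIELD("name", DataTypes.STRING())])

    # Row values support named fields and conversion to a dict.
    row = Row(id=1, name="flink")
    assert row.as_dict()["name"] == "flink"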
diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py
index 4ea3b3f..281647f 100644
--- a/flink-python/pyflink/table/__init__.py
+++ b/flink-python/pyflink/table/__init__.py
@@ -38,7 +38,7 @@ from pyflink.table.table_environment import (TableEnvironment, StreamTableEnviro
BatchTableEnvironment)
from pyflink.table.table_sink import TableSink, CsvTableSink
from pyflink.table.table_source import TableSource, CsvTableSource
-from pyflink.table.types import DataTypes
+from pyflink.table.types import DataTypes, UserDefinedType, Row
from pyflink.table.window import Tumble, Session, Slide, Over
from pyflink.table.table_descriptor import Rowtime, Schema, OldCsv, FileSystem
@@ -61,4 +61,6 @@ __all__ = [
'Schema',
'OldCsv',
'FileSystem',
+ 'UserDefinedType',
+ 'Row',
]
diff --git a/flink-python/pyflink/table/table_descriptor.py b/flink-python/pyflink/table/table_descriptor.py
index ed9a157..1dfbde3 100644
--- a/flink-python/pyflink/table/table_descriptor.py
+++ b/flink-python/pyflink/table/table_descriptor.py
@@ -19,9 +19,9 @@ import sys
from abc import ABCMeta
from py4j.java_gateway import get_method
+from pyflink.table.types import _to_java_type
from pyflink.java_gateway import get_gateway
-from pyflink.util.type_utils import to_java_type
if sys.version >= '3':
unicode = str
@@ -191,7 +191,7 @@ class Schema(Descriptor):
if isinstance(field_type, (str, unicode)):
self._j_schema = self._j_schema.field(field_name, field_type)
else:
- self._j_schema = self._j_schema.field(field_name, to_java_type(field_type))
+ self._j_schema = self._j_schema.field(field_name, _to_java_type(field_type))
return self
def from_origin_field(self, origin_field_name):
@@ -298,7 +298,7 @@ class OldCsv(FormatDescriptor):
if isinstance(field_type, (str, unicode)):
self._j_csv = self._j_csv.field(field_name, field_type)
else:
- self._j_csv = self._j_csv.field(field_name, to_java_type(field_type))
+ self._j_csv = self._j_csv.field(field_name, _to_java_type(field_type))
return self
def quote_character(self, quote_character):
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 95e9bed..337a0e4 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -25,7 +25,8 @@ from pyflink.table.table_descriptor import (StreamTableDescriptor, ConnectorDesc
from pyflink.java_gateway import get_gateway
from pyflink.table import Table
-from pyflink.util import type_utils, utils
+from pyflink.table.types import _to_java_type
+from pyflink.util import utils
__all__ = [
'BatchTableEnvironment',
@@ -88,7 +89,7 @@ class TableEnvironment(object):
j_field_names = utils.to_jarray(gateway.jvm.String, field_names)
j_field_types = utils.to_jarray(
gateway.jvm.TypeInformation,
- [type_utils.to_java_type(field_type) for field_type in field_types])
+ [_to_java_type(field_type) for field_type in field_types])
self._j_tenv.registerTableSink(name, j_field_names, j_field_types,
table_sink._j_table_sink)
def scan(self, *table_path):
diff --git a/flink-python/pyflink/table/table_source.py b/flink-python/pyflink/table/table_source.py
index 225abf2..30de59f 100644
--- a/flink-python/pyflink/table/table_source.py
+++ b/flink-python/pyflink/table/table_source.py
@@ -17,8 +17,7 @@
################################################################################
from pyflink.java_gateway import get_gateway
-from pyflink.table.types import DataType
-from pyflink.util import type_utils
+from pyflink.table.types import DataType, _to_java_type
from pyflink.util import utils
__all__ = ['TableSource', 'CsvTableSource']
@@ -48,7 +47,7 @@ class CsvTableSource(TableSource):
gateway = get_gateway()
j_field_names = utils.to_jarray(gateway.jvm.String, field_names)
j_field_types = utils.to_jarray(gateway.jvm.TypeInformation,
- [type_utils.to_java_type(field_type)
+ [_to_java_type(field_type)
for field_type in field_types])
super(CsvTableSource, self).__init__(
gateway.jvm.CsvTableSource(source_path, j_field_names,
j_field_types))
diff --git a/flink-python/pyflink/table/tests/test_aggregate.py b/flink-python/pyflink/table/tests/test_aggregate.py
index b37673d..b0bdf82 100644
--- a/flink-python/pyflink/table/tests/test_aggregate.py
+++ b/flink-python/pyflink/table/tests/test_aggregate.py
@@ -28,14 +28,14 @@ class StreamTableAggregateTests(PyFlinkStreamTableTestCase):
def test_group_by(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello",
"Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
diff --git a/flink-python/pyflink/table/tests/test_calc.py b/flink-python/pyflink/table/tests/test_calc.py
index 137f322..0822c91 100644
--- a/flink-python/pyflink/table/tests/test_calc.py
+++ b/flink-python/pyflink/table/tests/test_calc.py
@@ -18,8 +18,7 @@
import os
-from pyflink.table.table_source import CsvTableSource
-from pyflink.table.types import DataTypes
+from pyflink.table import CsvTableSource, DataTypes
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase
@@ -33,7 +32,7 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
f.write(lines)
f.close()
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
t_env = self.t_env
# register Orders table in table environment
t_env.register_table_source(
@@ -55,7 +54,7 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
def test_alias(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -76,7 +75,7 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
def test_where(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -97,7 +96,7 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase):
def test_filter(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
diff --git a/flink-python/pyflink/table/tests/test_column_operation.py b/flink-python/pyflink/table/tests/test_column_operation.py
index d0fbc46..5faff12 100644
--- a/flink-python/pyflink/table/tests/test_column_operation.py
+++ b/flink-python/pyflink/table/tests/test_column_operation.py
@@ -28,13 +28,13 @@ class StreamTableColumnsOperationTests(PyFlinkStreamTableTestCase):
def test_add_columns(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
- field_types = [DataTypes.INT, DataTypes.INT, DataTypes.INT]
+ field_types = [DataTypes.INT(), DataTypes.INT(), DataTypes.INT()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestAppendSink())
@@ -50,14 +50,14 @@ class StreamTableColumnsOperationTests(PyFlinkStreamTableTestCase):
def test_add_or_replace_columns(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.INT]
+ field_types = [DataTypes.INT(), DataTypes.INT()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestAppendSink())
@@ -73,7 +73,7 @@ class StreamTableColumnsOperationTests(PyFlinkStreamTableTestCase):
def test_rename_columns(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -95,14 +95,14 @@ class StreamTableColumnsOperationTests(PyFlinkStreamTableTestCase):
def test_drop_columns(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
field_names = ["b"]
- field_types = [DataTypes.STRING]
+ field_types = [DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestAppendSink())
diff --git a/flink-python/pyflink/table/tests/test_descriptor.py b/flink-python/pyflink/table/tests/test_descriptor.py
index f2c2976..4fcc355 100644
--- a/flink-python/pyflink/table/tests/test_descriptor.py
+++ b/flink-python/pyflink/table/tests/test_descriptor.py
@@ -110,8 +110,8 @@ class OldCsvDescriptorTests(PyFlinkTestCase):
def test_field(self):
csv = OldCsv()
- csv.field("a", DataTypes.LONG)
- csv.field("b", DataTypes.STRING)
+ csv.field("a", DataTypes.BIGINT())
+ csv.field("b", DataTypes.STRING())
csv.field("c", "SQL_TIMESTAMP")
properties = csv.to_properties()
@@ -216,17 +216,17 @@ class SchemaDescriptorTests(PyFlinkTestCase):
schema = Schema()
schema = schema\
- .field("int_field", DataTypes.INT)\
- .field("long_field", DataTypes.LONG)\
- .field("string_field", DataTypes.STRING)\
- .field("timestamp_field", DataTypes.TIMESTAMP)\
- .field("time_field", DataTypes.TIME)\
- .field("date_field", DataTypes.DATE)\
- .field("double_field", DataTypes.DOUBLE)\
- .field("float_field", DataTypes.FLOAT)\
- .field("byte_field", DataTypes.BYTE)\
- .field("short_field", DataTypes.SHORT)\
- .field("boolean_field", DataTypes.BOOLEAN)
+ .field("int_field", DataTypes.INT())\
+ .field("long_field", DataTypes.BIGINT())\
+ .field("string_field", DataTypes.STRING())\
+ .field("timestamp_field", DataTypes.TIMESTAMP())\
+ .field("time_field", DataTypes.TIME())\
+ .field("date_field", DataTypes.DATE())\
+ .field("double_field", DataTypes.DOUBLE())\
+ .field("float_field", DataTypes.FLOAT())\
+ .field("byte_field", DataTypes.TINYINT())\
+ .field("short_field", DataTypes.SMALLINT())\
+ .field("boolean_field", DataTypes.BOOLEAN())
properties = schema.to_properties()
expected = {'schema.0.name': 'int_field',
@@ -298,9 +298,9 @@ class SchemaDescriptorTests(PyFlinkTestCase):
schema = Schema()
schema = schema\
- .field("int_field", DataTypes.INT)\
- .field("long_field", DataTypes.LONG).from_origin_field("origin_field_a")\
- .field("string_field", DataTypes.STRING)
+ .field("int_field", DataTypes.INT())\
+ .field("long_field", DataTypes.BIGINT()).from_origin_field("origin_field_a")\
+ .field("string_field", DataTypes.STRING())
properties = schema.to_properties()
expected = {'schema.0.name': 'int_field',
@@ -316,9 +316,9 @@ class SchemaDescriptorTests(PyFlinkTestCase):
schema = Schema()
schema = schema\
- .field("int_field", DataTypes.INT)\
- .field("ptime", DataTypes.LONG).proctime()\
- .field("string_field", DataTypes.STRING)
+ .field("int_field", DataTypes.INT())\
+ .field("ptime", DataTypes.BIGINT()).proctime()\
+ .field("string_field", DataTypes.STRING())
properties = schema.to_properties()
expected = {'schema.0.name': 'int_field',
@@ -334,12 +334,12 @@ class SchemaDescriptorTests(PyFlinkTestCase):
schema = Schema()
schema = schema\
- .field("int_field", DataTypes.INT)\
- .field("long_field", DataTypes.LONG)\
- .field("rtime", DataTypes.LONG)\
+ .field("int_field", DataTypes.INT())\
+ .field("long_field", DataTypes.BIGINT())\
+ .field("rtime", DataTypes.BIGINT())\
.rowtime(
Rowtime().timestamps_from_field("long_field").watermarks_periodic_bounded(5000))\
- .field("string_field", DataTypes.STRING)
+ .field("string_field", DataTypes.STRING())
properties = schema.to_properties()
print(properties)
@@ -392,7 +392,7 @@ class AbstractTableDescriptorTests(object):
def test_register_table_sink(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -405,13 +405,13 @@ class AbstractTableDescriptorTests(object):
t_env.connect(FileSystem().path(sink_path))\
.with_format(OldCsv()
.field_delimiter(',')
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.with_schema(Schema()
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.register_table_sink("sink")
t_env.scan("source") \
.select("a + 1, b, c") \
@@ -425,7 +425,7 @@ class AbstractTableDescriptorTests(object):
def test_register_table_source(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
self.prepare_csv_source(source_path, data, field_types, field_names)
t_env = self.t_env
@@ -440,13 +440,13 @@ class AbstractTableDescriptorTests(object):
t_env.connect(FileSystem().path(source_path))\
.with_format(OldCsv()
.field_delimiter(',')
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.with_schema(Schema()
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.register_table_source("source")
t_env.scan("source") \
.select("a + 1, b, c") \
@@ -460,7 +460,7 @@ class AbstractTableDescriptorTests(object):
def test_register_table_source_and_sink(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
self.prepare_csv_source(source_path, data, field_types, field_names)
sink_path = os.path.join(self.tempdir + '/streaming2.csv')
@@ -471,24 +471,24 @@ class AbstractTableDescriptorTests(object):
t_env.connect(FileSystem().path(source_path))\
.with_format(OldCsv()
.field_delimiter(',')
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.with_schema(Schema()
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.register_table_source_and_sink("source")
t_env.connect(FileSystem().path(sink_path))\
.with_format(OldCsv()
.field_delimiter(',')
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.with_schema(Schema()
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.register_table_source_and_sink("sink")
t_env.scan("source") \
.select("a + 1, b, c") \
@@ -588,13 +588,13 @@ class StreamDescriptorEndToEndTests(PyFlinkStreamTableTestCase):
t_env.connect(FileSystem().path(sink_path))\
.with_format(OldCsv()
.field_delimiter(',')
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.with_schema(Schema()
- .field("a", DataTypes.INT)
- .field("b", DataTypes.STRING)
- .field("c", DataTypes.STRING))\
+ .field("a", DataTypes.INT())
+ .field("b", DataTypes.STRING())
+ .field("c", DataTypes.STRING()))\
.register_table_sink("sink")
t_env.scan("source") \
diff --git a/flink-python/pyflink/table/tests/test_distinct.py b/flink-python/pyflink/table/tests/test_distinct.py
index 8f5b05a..fd32a28 100644
--- a/flink-python/pyflink/table/tests/test_distinct.py
+++ b/flink-python/pyflink/table/tests/test_distinct.py
@@ -28,14 +28,14 @@ class StreamTableDistinctTests(PyFlinkStreamTableTestCase):
def test_distinct(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello",
"Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
diff --git a/flink-python/pyflink/table/tests/test_join.py b/flink-python/pyflink/table/tests/test_join.py
index eb524b6..4666b97 100644
--- a/flink-python/pyflink/table/tests/test_join.py
+++ b/flink-python/pyflink/table/tests/test_join.py
@@ -28,12 +28,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_join_without_where(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -42,7 +42,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
@@ -58,12 +58,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_join_with_where(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -72,7 +72,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
@@ -88,12 +88,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_left_outer_join_without_where(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -102,7 +102,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
@@ -118,12 +118,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_left_outer_join_with_where(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -132,7 +132,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
@@ -148,12 +148,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_right_outer_join(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (4, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -162,7 +162,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
@@ -178,12 +178,12 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
def test_full_outer_join(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
field_names2 = ["d", "e"]
- field_types2 = [DataTypes.INT, DataTypes.STRING]
+ field_types2 = [DataTypes.INT(), DataTypes.STRING()]
data2 = [(2, "Flink"), (3, "Python"), (4, "Flink")]
csv_source2 = self.prepare_csv_source(source_path2, data2,
field_types2, field_names2)
t_env = self.t_env
@@ -192,7 +192,7 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
diff --git a/flink-python/pyflink/table/tests/test_print_schema.py b/flink-python/pyflink/table/tests/test_print_schema.py
index 45fbc11..4a6309f 100644
--- a/flink-python/pyflink/table/tests/test_print_schema.py
+++ b/flink-python/pyflink/table/tests/test_print_schema.py
@@ -28,14 +28,14 @@ class StreamTableSchemaTests(PyFlinkStreamTableTestCase):
def test_print_schema(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello",
"Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
source = t_env.scan("Source")
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestRetractSink())
diff --git a/flink-python/pyflink/table/tests/test_set_operation.py b/flink-python/pyflink/table/tests/test_set_operation.py
index 151e895..62e4dfa 100644
--- a/flink-python/pyflink/table/tests/test_set_operation.py
+++ b/flink-python/pyflink/table/tests/test_set_operation.py
@@ -28,7 +28,7 @@ class StreamTableSetOperationTests(PyFlinkStreamTableTestCase):
def test_union_all(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
@@ -40,7 +40,7 @@ class StreamTableSetOperationTests(PyFlinkStreamTableTestCase):
source1 = t_env.scan("Source1")
source2 = t_env.scan("Source2")
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestAppendSink())
@@ -64,7 +64,7 @@ class BatchTableSetOperationTests(PyFlinkBatchTableTestCase):
def test_minus(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
@@ -85,7 +85,7 @@ class BatchTableSetOperationTests(PyFlinkBatchTableTestCase):
def test_minus_all(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
@@ -107,7 +107,7 @@ class BatchTableSetOperationTests(PyFlinkBatchTableTestCase):
def test_union(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
@@ -132,7 +132,7 @@ class BatchTableSetOperationTests(PyFlinkBatchTableTestCase):
def test_intersect(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (2, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
@@ -153,7 +153,7 @@ class BatchTableSetOperationTests(PyFlinkBatchTableTestCase):
def test_intersect_all(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (2, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
source_path2 = os.path.join(self.tempdir + '/streaming2.csv')
diff --git a/flink-python/pyflink/table/tests/test_sort.py b/flink-python/pyflink/table/tests/test_sort.py
index f916b93..f0a5df5 100644
--- a/flink-python/pyflink/table/tests/test_sort.py
+++ b/flink-python/pyflink/table/tests/test_sort.py
@@ -27,7 +27,7 @@ class BatchTableSortTests(PyFlinkBatchTableTestCase):
def test_order_by_offset_fetch(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b"]
- field_types = [DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING()]
data = [(1, "Hello"), (2, "Hello"), (3, "Flink"), (4, "Python")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 0abe244..4b2e7c3 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -33,13 +33,13 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_register_scan(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
t_env.register_table_source("Source", csv_source)
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
t_env.register_table_sink(
"Results",
field_names, field_types, source_sink_utils.TestAppendSink())
@@ -55,7 +55,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_register_table_source_sink(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -74,7 +74,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_from_table_source(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -93,7 +93,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_list_tables(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -113,7 +113,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_explain(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -128,7 +128,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_sql_query(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -149,7 +149,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_sql_update(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -168,7 +168,7 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase):
def test_sql_update_with_query_config(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -236,7 +236,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_register_scan(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -251,7 +251,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_register_table_source_sink(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -274,7 +274,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_from_table_source(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hi", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -288,7 +288,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_list_tables(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -310,7 +310,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_explain(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = []
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -325,7 +325,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_sql_query(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -350,7 +350,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_sql_update(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -373,7 +373,7 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):
def test_sql_update_with_query_config(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.INT, DataTypes.STRING, DataTypes.STRING]
+ field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()]
data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
diff --git a/flink-python/pyflink/table/tests/test_types.py b/flink-python/pyflink/table/tests/test_types.py
new file mode 100644
index 0000000..bf69906
--- /dev/null
+++ b/flink-python/pyflink/table/tests/test_types.py
@@ -0,0 +1,737 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import array
+import ctypes
+import datetime
+import pickle
+import sys
+import unittest
+
+from pyflink.table.types import (_infer_schema_from_data, _infer_type,
+ _array_signed_int_typecode_ctype_mappings,
+ _array_unsigned_int_typecode_ctype_mappings,
+ _array_type_mappings, _merge_type,
+ _create_type_verifier, UserDefinedType, DataTypes, Row, RowField,
+ RowType, ArrayType, BigIntType, VarCharType, MapType)
+
+
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sql_type(cls):
+ return DataTypes.ARRAY(DataTypes.DOUBLE(False))
+
+ @classmethod
+ def module(cls):
+ return 'pyflink.table.tests.test_types'
+
+ @classmethod
+ def java_udt(cls):
+ return 'org.apache.flink.table.types.python.ExamplePointUserDefinedType'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and other.x == self.x and other.y == self.y
+
+
+class PythonOnlyUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sql_type(cls):
+ return DataTypes.ARRAY(DataTypes.DOUBLE(False))
+
+ @classmethod
+ def module(cls):
+ return '__main__'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return PythonOnlyPoint(datum[0], datum[1])
+
+
+class PythonOnlyPoint(ExamplePoint):
+ """
+ An example class to demonstrate UDT in only Python
+ """
+ __UDT__ = PythonOnlyUDT()
+
+
+class TypesTests(unittest.TestCase):
+
+ def test_infer_schema(self):
+ from decimal import Decimal
+
+ class A(object):
+ def __init__(self):
+ self.a = 1
+
+ from collections import namedtuple
+ Point = namedtuple('Point', 'x y')
+
+ data = [
+ True,
+ 1,
+ "a",
+ u"a",
+ datetime.date(1970, 1, 1),
+ datetime.time(0, 0, 0),
+ datetime.datetime(1970, 1, 1, 0, 0),
+ 1.0,
+ array.array("d", [1]),
+ [1],
+ (1,),
+ Point(1.0, 5.0),
+ {"a": 1},
+ bytearray(1),
+ Decimal(1),
+ Row(a=1),
+ Row("a")(1),
+ A(),
+ ]
+
+ expected = [
+ 'BooleanType(true)',
+ 'BigIntType(true)',
+ 'VarCharType(2147483647, true)',
+ 'VarCharType(2147483647, true)',
+ 'DateType(true)',
+ 'TimeType(0, true)',
+ 'TimestampType(0, 6, true)',
+ 'DoubleType(true)',
+ "ArrayType(DoubleType(false), true)",
+ "ArrayType(BigIntType(true), true)",
+ 'RowType(RowField(_1, BigIntType(true), ...))',
+ 'RowType(RowField(x, DoubleType(true), ...),RowField(y, DoubleType(true), ...))',
+ 'MapType(VarCharType(2147483647, false), BigIntType(true), true)',
+ 'VarBinaryType(2147483647, true)',
+ 'DecimalType(38, 18, true)',
+ 'RowType(RowField(a, BigIntType(true), ...))',
+ 'RowType(RowField(a, BigIntType(true), ...))',
+ 'RowType(RowField(a, BigIntType(true), ...))',
+ ]
+
+ schema = _infer_schema_from_data([data])
+ self.assertEqual(expected, [repr(f.data_type) for f in schema.fields])
+
+ def test_infer_schema_nulltype(self):
+ elements = [Row(c1=[], c2={}, c3=None),
+ Row(c1=[Row(a=1, b='s')], c2={"key": Row(c=1.0, d="2")}, c3="")]
+ schema = _infer_schema_from_data(elements)
+ self.assertTrue(isinstance(schema, RowType))
+ self.assertEqual(3, len(schema.fields))
+
+ # first column is array
+ self.assertTrue(isinstance(schema.fields[0].data_type, ArrayType))
+
+ # element type of first column is struct
+ self.assertTrue(isinstance(schema.fields[0].data_type.element_type, RowType))
+
+
+ self.assertTrue(isinstance(schema.fields[0].data_type.element_type.fields[0].data_type,
+ BigIntType))
+ self.assertTrue(isinstance(schema.fields[0].data_type.element_type.fields[1].data_type,
+ VarCharType))
+
+ # second column is map
+ self.assertTrue(isinstance(schema.fields[1].data_type, MapType))
+ self.assertTrue(isinstance(schema.fields[1].data_type.key_type, VarCharType))
+ self.assertTrue(isinstance(schema.fields[1].data_type.value_type, RowType))
+
+ # third column is varchar
+ self.assertTrue(isinstance(schema.fields[2].data_type, VarCharType))
+
+ def test_infer_schema_not_enough_names(self):
+ schema = _infer_schema_from_data([["a", "b"]], ["col1"])
+ self.assertTrue(schema.names, ['col1', '_2'])
+
+ def test_infer_schema_fails(self):
+ with self.assertRaises(TypeError):
+ _infer_schema_from_data([[1, 1], ["x", 1]], names=["a", "b"])
+
+ def test_infer_nested_schema(self):
+ NestedRow = Row("f1", "f2")
+ data1 = [NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]
+ schema1 = _infer_schema_from_data(data1)
+ expected1 = [
+ 'ArrayType(BigIntType(true), true)',
+ 'MapType(VarCharType(2147483647, false), DoubleType(true), true)'
+ ]
+ self.assertEqual(expected1, [repr(f.data_type) for f in schema1.fields])
+
+ data2 = [NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]
+ schema2 = _infer_schema_from_data(data2)
+ expected2 = [
+ 'ArrayType(ArrayType(BigIntType(true), true), true)',
+ 'ArrayType(BigIntType(true), true)'
+ ]
+ self.assertEqual(expected2, [repr(f.data_type) for f in schema2.fields])
+
+ def test_convert_row_to_dict(self):
+ row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+ self.assertEqual(1, row.as_dict()['l'][0].a)
+ self.assertEqual(1.0, row.as_dict()['d']['key'].c)
+
+ def test_udt(self):
+ p = ExamplePoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), ExamplePointUDT())
+ _create_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+ self.assertRaises(ValueError, lambda: _create_type_verifier(ExamplePointUDT())([1.0, 2.0]))
+
+ p = PythonOnlyPoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), PythonOnlyUDT())
+ _create_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+ self.assertRaises(ValueError, lambda: _create_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
+
+ def test_nested_udt_in_df(self):
+ expected_schema = DataTypes.ROW() \
+ .add("_1", DataTypes.BIGINT()).add("_2", DataTypes.ARRAY(PythonOnlyUDT()))
+ data = (1, [PythonOnlyPoint(float(1), float(2))])
+ self.assertEqual(expected_schema, _infer_type(data))
+
+ expected_schema = DataTypes.ROW().add("_1", DataTypes.BIGINT()).add(
+ "_2", DataTypes.MAP(DataTypes.BIGINT(False), PythonOnlyUDT()))
+ p = (1, {1: PythonOnlyPoint(1, float(2))})
+ self.assertEqual(expected_schema, _infer_type(p))
+
+ def test_struct_type(self):
+ row1 = DataTypes.ROW().add("f1", DataTypes.VARCHAR(nullable=True)) \
+ .add("f2", DataTypes.VARCHAR(nullable=True))
+ row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True)),
+ DataTypes.FIELD("f2", DataTypes.VARCHAR(nullable=True), None)])
+ self.assertEqual(row1.field_names(), row2.names)
+ self.assertEqual(row1, row2)
+
+ row1 = DataTypes.ROW().add("f1", DataTypes.VARCHAR(nullable=True)) \
+ .add("f2", DataTypes.VARCHAR(nullable=True))
+ row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True))])
+ self.assertNotEqual(row1.field_names(), row2.names)
+ self.assertNotEqual(row1, row2)
+
+ row1 = (DataTypes.ROW().add(DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True)))
+ .add("f2", DataTypes.VARCHAR(nullable=True)))
+ row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True)),
+ DataTypes.FIELD("f2", DataTypes.VARCHAR(nullable=True))])
+ self.assertEqual(row1.field_names(), row2.names)
+ self.assertEqual(row1, row2)
+
+ row1 = (DataTypes.ROW().add(DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True)))
+ .add("f2", DataTypes.VARCHAR(nullable=True)))
+ row2 = DataTypes.ROW([DataTypes.FIELD("f1", DataTypes.VARCHAR(nullable=True))])
+ self.assertNotEqual(row1.field_names(), row2.names)
+ self.assertNotEqual(row1, row2)
+
+ # Catch exception raised during improper construction
+ self.assertRaises(ValueError, lambda: DataTypes.ROW().add("name"))
+
+ row1 = DataTypes.ROW().add("f1", DataTypes.VARCHAR(nullable=True)) \
+ .add("f2", DataTypes.VARCHAR(nullable=True))
+ for field in row1:
+ self.assertIsInstance(field, RowField)
+
+ row1 = DataTypes.ROW().add("f1", DataTypes.VARCHAR(nullable=True)) \
+ .add("f2", DataTypes.VARCHAR(nullable=True))
+ self.assertEqual(len(row1), 2)
+
+ row1 = DataTypes.ROW().add("f1", DataTypes.VARCHAR(nullable=True)) \
+ .add("f2", DataTypes.VARCHAR(nullable=True))
+ self.assertIs(row1["f1"], row1.fields[0])
+ self.assertIs(row1[0], row1.fields[0])
+ self.assertEqual(row1[0:1], DataTypes.ROW(row1.fields[0:1]))
+ self.assertRaises(KeyError, lambda: row1["f9"])
+ self.assertRaises(IndexError, lambda: row1[9])
+ self.assertRaises(TypeError, lambda: row1[9.9])
+
+ def test_infer_bigint_type(self):
+ longrow = [Row(f1='a', f2=100000000000000)]
+ schema = _infer_schema_from_data(longrow)
+ self.assertEqual(DataTypes.BIGINT(), schema.fields[1].data_type)
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(1))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 10))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 20))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 31 - 1))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 31))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 61))
+ self.assertEqual(DataTypes.BIGINT(), _infer_type(2 ** 71))
+
+ def test_merge_type(self):
+ self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.NULL()), DataTypes.BIGINT())
+ self.assertEqual(_merge_type(DataTypes.NULL(), DataTypes.BIGINT()), DataTypes.BIGINT())
+
+ self.assertEqual(_merge_type(DataTypes.BIGINT(), DataTypes.BIGINT()), DataTypes.BIGINT())
+
+ self.assertEqual(_merge_type(
+ DataTypes.ARRAY(DataTypes.BIGINT()),
+ DataTypes.ARRAY(DataTypes.BIGINT())
+ ), DataTypes.ARRAY(DataTypes.BIGINT()))
+ with self.assertRaises(TypeError):
+ _merge_type(DataTypes.ARRAY(DataTypes.BIGINT()), DataTypes.ARRAY(DataTypes.DOUBLE()))
+
+ self.assertEqual(_merge_type(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT()),
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())
+ ), DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT()))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT()),
+ DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT()))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT()),
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.DOUBLE()))
+
+ self.assertEqual(_merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())])
+ ), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.BIGINT()),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.DOUBLE()),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+
+ self.assertEqual(_merge_type(
+ DataTypes.ROW([DataTypes.FIELD(
+ 'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
+ DataTypes.ROW([DataTypes.FIELD(
+ 'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))])
+ ), DataTypes.ROW([DataTypes.FIELD(
+ 'f1', DataTypes.ROW([DataTypes.FIELD('f2', DataTypes.BIGINT())]))]))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ROW(
+ [DataTypes.FIELD('f2', DataTypes.BIGINT())]))]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ROW(
+ [DataTypes.FIELD('f2', DataTypes.VARCHAR())]))]))
+
+ self.assertEqual(_merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())])
+ ), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.ARRAY(DataTypes.DOUBLE())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+
+ self.assertEqual(_merge_type(
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())])
+ ), DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.VARCHAR(),
DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.VARCHAR(),
DataTypes.BIGINT())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]),
+ DataTypes.ROW([
+ DataTypes.FIELD('f1', DataTypes.MAP(DataTypes.VARCHAR(),
DataTypes.DOUBLE())),
+ DataTypes.FIELD('f2', DataTypes.VARCHAR())]))
+
+ self.assertEqual(_merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())))]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())))])
+ ), DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())))]))
+ with self.assertRaises(TypeError):
+ _merge_type(
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
+ DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.BIGINT())))]),
+ DataTypes.ROW([DataTypes.FIELD('f1', DataTypes.ARRAY(
+ DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.BIGINT())))])
+ )
+
+ def test_array_types(self):
+ # This test needs to make sure that the Flink type selected is at least
+ # as large as Python's types. This is necessary because Python's array
+ # types depend on the C implementation on the machine. Therefore there
+ # is no machine-independent correspondence between Python's array types
+ # and Flink types.
+ # See: https://docs.python.org/2/library/array.html
+
+ def assert_collect_success(typecode, value, element_type):
+ self.assertEqual(element_type,
+ str(_infer_type(array.array(typecode,
[value])).element_type))
+
+ # supported string types
+ #
+ # String types in Python's array are "u" for Py_UNICODE and "c" for char.
+ # "u" will be removed in Python 4, and "c" is not supported in Python 3.
+ supported_string_types = []
+ if sys.version_info[0] < 4:
+ supported_string_types += ['u']
+ # test unicode
+ assert_collect_success('u', u'a', 'CHAR')
+ if sys.version_info[0] < 3:
+ supported_string_types += ['c']
+ # test string
+ assert_collect_success('c', 'a', 'CHAR')
+
+ # supported float and double
+ #
+ # Test max, min, and precision for float and double, assuming IEEE 754
+ # floating-point format.
+ supported_fractional_types = ['f', 'd']
+ assert_collect_success('f', ctypes.c_float(1e+38).value, 'FLOAT')
+ assert_collect_success('f', ctypes.c_float(1e-38).value, 'FLOAT')
+ assert_collect_success('f', ctypes.c_float(1.123456).value, 'FLOAT')
+ assert_collect_success('d', sys.float_info.max, 'DOUBLE')
+ assert_collect_success('d', sys.float_info.min, 'DOUBLE')
+ assert_collect_success('d', sys.float_info.epsilon, 'DOUBLE')
+
+ def get_int_data_type(size):
+ if size <= 8:
+ return "TINYINT"
+ if size <= 16:
+ return "SMALLINT"
+ if size <= 32:
+ return "INT"
+ if size <= 64:
+ return "BIGINT"
+
+ # supported signed int types
+ #
+ # The size of C types changes with the implementation, so we need to make
+ # sure that there is no overflow error on the platform running this test.
+ supported_signed_int_types = list(
+ set(_array_signed_int_typecode_ctype_mappings.keys()).intersection(
+ set(_array_type_mappings.keys())))
+ for t in supported_signed_int_types:
+ ctype = _array_signed_int_typecode_ctype_mappings[t]
+ max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
+ assert_collect_success(t, max_val - 1,
get_int_data_type(ctypes.sizeof(ctype) * 8))
+ assert_collect_success(t, -max_val,
get_int_data_type(ctypes.sizeof(ctype) * 8))
+
+ # supported unsigned int types
+ #
+ # JVM does not have unsigned types. We need to be very careful to make
+ # sure that there is no overflow error.
+ supported_unsigned_int_types = list(
+
set(_array_unsigned_int_typecode_ctype_mappings.keys()).intersection(
+ set(_array_type_mappings.keys())))
+ for t in supported_unsigned_int_types:
+ ctype = _array_unsigned_int_typecode_ctype_mappings[t]
+ max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
+ assert_collect_success(t, max_val,
get_int_data_type(ctypes.sizeof(ctype) * 8 + 1))
+
+ # all supported types
+ #
+ # Make sure the types tested above:
+ # 1. are all supported types
+ # 2. cover all supported types
+ supported_types = (supported_string_types +
+ supported_fractional_types +
+ supported_signed_int_types +
+ supported_unsigned_int_types)
+ self.assertEqual(set(supported_types),
set(_array_type_mappings.keys()))
+
+ # all unsupported types
+ #
+ # The keys of _array_type_mappings are a complete list of all supported
+ # types, and types not in _array_type_mappings are considered unsupported.
+ # `array.typecodes` is not available in Python 2.
+ if sys.version_info[0] < 3:
+ all_types = {'c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L',
'f', 'd'}
+ else:
+ all_types = set(array.typecodes)
+ unsupported_types = all_types - set(supported_types)
+ # test unsupported types
+ for t in unsupported_types:
+ with self.assertRaises(TypeError):
+ _infer_schema_from_data([Row(myarray=array.array(t))])
+
+ def test_data_type_eq(self):
+ lt = DataTypes.BIGINT()
+ lt2 = pickle.loads(pickle.dumps(DataTypes.BIGINT()))
+ self.assertEqual(lt, lt2)
+
+ def test_decimal_type(self):
+ t1 = DataTypes.DECIMAL()
+ t2 = DataTypes.DECIMAL(10, 2)
+ self.assertTrue(t2 is not t1)
+ self.assertNotEqual(t1, t2)
+ t3 = DataTypes.DECIMAL(8)
+ self.assertNotEqual(t2, t3)
+
+ def test_datetype_equal_zero(self):
+ dt = DataTypes.DATE()
+ self.assertEqual(dt.from_sql_type(0), datetime.date(1970, 1, 1))
+
+ def test_timestamp_microsecond(self):
+ tst = DataTypes.TIMESTAMP()
+ self.assertEqual(tst.to_sql_type(datetime.datetime.max) % 1000000,
999999)
+
+ def test_empty_row(self):
+ row = Row()
+ self.assertEqual(len(row), 0)
+
+ def test_invalid_create_row(self):
+ row_class = Row("c1", "c2")
+ self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
+
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+ def test_verify_type_exception_msg(self):
+ self.assertRaises(
+ ValueError,
+ lambda: _create_type_verifier(
+ DataTypes.VARCHAR(nullable=False), name="test_name")(None))
+
+ schema = DataTypes.ROW(
+ [DataTypes.FIELD('a', DataTypes.ROW([DataTypes.FIELD('b',
DataTypes.INT())]))])
+ self.assertRaises(
+ TypeError,
+ lambda: _create_type_verifier(schema)([["data"]]))
+
+ def test_verify_type_ok_nullable(self):
+ obj = None
+ types = [DataTypes.INT(), DataTypes.FLOAT(), DataTypes.VARCHAR(),
DataTypes.ROW([])]
+ for data_type in types:
+ try:
+ _create_type_verifier(data_type)(obj)
+ except (TypeError, ValueError):
+ self.fail("verify_type(%s, %s, nullable=True)" % (obj,
data_type))
+
+ def test_verify_type_not_nullable(self):
+ import array
+ import datetime
+ import decimal
+
+ schema = DataTypes.ROW([
+ DataTypes.FIELD('s', DataTypes.VARCHAR(nullable=False)),
+ DataTypes.FIELD('i', DataTypes.INT(True))])
+
+ class MyObj:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ # obj, data_type
+ success_spec = [
+ # String
+ ("", DataTypes.VARCHAR()),
+ (u"", DataTypes.VARCHAR()),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+ # Boolean
+ (True, DataTypes.BOOLEAN()),
+
+ # TinyInt
+ (-(2 ** 7), DataTypes.TINYINT()),
+ (2 ** 7 - 1, DataTypes.TINYINT()),
+
+ # SmallInt
+ (-(2 ** 15), DataTypes.SMALLINT()),
+ (2 ** 15 - 1, DataTypes.SMALLINT()),
+
+ # Int
+ (-(2 ** 31), DataTypes.INT()),
+ (2 ** 31 - 1, DataTypes.INT()),
+
+ # BigInt
+ (2 ** 64, DataTypes.BIGINT()),
+
+ # Float & Double
+ (1.0, DataTypes.FLOAT()),
+ (1.0, DataTypes.DOUBLE()),
+
+ # Decimal
+ (decimal.Decimal("1.0"), DataTypes.DECIMAL()),
+
+ # Binary
+ (bytearray([1]), DataTypes.BINARY()),
+
+ # Date/Time/Timestamp
+ (datetime.date(2000, 1, 2), DataTypes.DATE()),
+ (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.DATE()),
+ (datetime.time(1, 1, 2), DataTypes.TIME()),
+ (datetime.datetime(2000, 1, 2, 3, 4), DataTypes.TIMESTAMP()),
+
+ # Array
+ ([], DataTypes.ARRAY(DataTypes.INT())),
+ (["1", None], DataTypes.ARRAY(DataTypes.VARCHAR(nullable=True))),
+ ([1, 2], DataTypes.ARRAY(DataTypes.INT())),
+ ((1, 2), DataTypes.ARRAY(DataTypes.INT())),
+ (array.array('h', [1, 2]), DataTypes.ARRAY(DataTypes.INT())),
+
+ # Map
+ ({}, DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.INT())),
+ ({"a": 1}, DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.INT())),
+ ({"a": None}, DataTypes.MAP(DataTypes.VARCHAR(nullable=False),
DataTypes.INT(True))),
+
+ # Struct
+ ({"s": "a", "i": 1}, schema),
+ ({"s": "a", "i": None}, schema),
+ ({"s": "a"}, schema),
+ ({"s": "a", "f": 1.0}, schema),
+ (Row(s="a", i=1), schema),
+ (Row(s="a", i=None), schema),
+ (Row(s="a", i=1, f=1.0), schema),
+ (["a", 1], schema),
+ (["a", None], schema),
+ (("a", 1), schema),
+ (MyObj(s="a", i=1), schema),
+ (MyObj(s="a", i=None), schema),
+ (MyObj(s="a"), schema),
+ ]
+
+ # obj, data_type, exception class
+ failure_spec = [
+ # Char/VarChar (match anything but None)
+ (None, DataTypes.VARCHAR(), ValueError),
+ (None, DataTypes.CHAR(), ValueError),
+
+ # VarChar (length exceeds maximum length)
+ ("abc", DataTypes.VARCHAR(), ValueError),
+ # Char (length does not match the declared length)
+ ("abc", DataTypes.CHAR(), ValueError),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+ # Boolean
+ (1, DataTypes.BOOLEAN(), TypeError),
+ ("True", DataTypes.BOOLEAN(), TypeError),
+ ([1], DataTypes.BOOLEAN(), TypeError),
+
+ # TinyInt
+ (-(2 ** 7) - 1, DataTypes.TINYINT(), ValueError),
+ (2 ** 7, DataTypes.TINYINT(), ValueError),
+ ("1", DataTypes.TINYINT(), TypeError),
+ (1.0, DataTypes.TINYINT(), TypeError),
+
+ # SmallInt
+ (-(2 ** 15) - 1, DataTypes.SMALLINT(), ValueError),
+ (2 ** 15, DataTypes.SMALLINT(), ValueError),
+
+ # Int
+ (-(2 ** 31) - 1, DataTypes.INT(), ValueError),
+ (2 ** 31, DataTypes.INT(), ValueError),
+
+ # Float & Double
+ (1, DataTypes.FLOAT(), TypeError),
+ (1, DataTypes.DOUBLE(), TypeError),
+
+ # Decimal
+ (1.0, DataTypes.DECIMAL(), TypeError),
+ (1, DataTypes.DECIMAL(), TypeError),
+ ("1.0", DataTypes.DECIMAL(), TypeError),
+
+ # Binary
+ (1, DataTypes.BINARY(), TypeError),
+ # VarBinary (length exceeds maximum length)
+ (bytearray([1, 2]), DataTypes.VARBINARY(), ValueError),
+ # Binary (length does not match the declared length)
+ (bytearray([1, 2]), DataTypes.BINARY(), ValueError),
+
+ # Date/Time/Timestamp
+ ("2000-01-02", DataTypes.DATE(), TypeError),
+ ("10:01:02", DataTypes.TIME(), TypeError),
+ (946811040, DataTypes.TIMESTAMP(), TypeError),
+
+ # Array
+ (["1", None], DataTypes.ARRAY(DataTypes.VARCHAR(nullable=False)),
ValueError),
+ ([1, "2"], DataTypes.ARRAY(DataTypes.INT()), TypeError),
+
+ # Map
+ ({"a": 1}, DataTypes.MAP(DataTypes.INT(), DataTypes.INT()),
TypeError),
+ ({"a": "1"}, DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.INT()),
TypeError),
+ ({"a": None}, DataTypes.MAP(DataTypes.VARCHAR(),
DataTypes.INT(False)), ValueError),
+
+ # Struct
+ ({"s": "a", "i": "1"}, schema, TypeError),
+ (Row(s="a"), schema, ValueError), # Row can't have missing field
+ (Row(s="a", i="1"), schema, TypeError),
+ (["a"], schema, ValueError),
+ (["a", "1"], schema, TypeError),
+ (MyObj(s="a", i="1"), schema, TypeError),
+ (MyObj(s=None, i="1"), schema, ValueError),
+ ]
+
+ # Check success cases
+ for obj, data_type in success_spec:
+ try:
+ _create_type_verifier(data_type.not_null())(obj)
+ except (TypeError, ValueError):
+ self.fail("verify_type(%s, %s, nullable=False)" % (obj,
data_type))
+
+ # Check failure cases
+ for obj, data_type, exp in failure_spec:
+ msg = "verify_type(%s, %s, nullable=False) == %s" % (obj,
data_type, exp)
+ with self.assertRaises(exp, msg=msg):
+ _create_type_verifier(data_type.not_null())(obj)
+
+
+if __name__ == "__main__":
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/flink-python/pyflink/table/tests/test_window.py
b/flink-python/pyflink/table/tests/test_window.py
index 32d5305..b37f895 100644
--- a/flink-python/pyflink/table/tests/test_window.py
+++ b/flink-python/pyflink/table/tests/test_window.py
@@ -31,7 +31,7 @@ class StreamTableWindowTests(PyFlinkStreamTableTestCase):
def test_over_window(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.LONG, DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()]
data = [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8,
"Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -51,7 +51,7 @@ class BatchTableWindowTests(PyFlinkBatchTableTestCase):
def test_tumble_window(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.LONG, DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()]
data = [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8,
"Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -68,7 +68,7 @@ class BatchTableWindowTests(PyFlinkBatchTableTestCase):
def test_slide_window(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.LONG, DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()]
data = [(1000, 1, "Hello"), (2000, 2, "Hello"), (3000, 4, "Hello"),
(4000, 8, "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
@@ -85,7 +85,7 @@ class BatchTableWindowTests(PyFlinkBatchTableTestCase):
def test_session_window(self):
source_path = os.path.join(self.tempdir + '/streaming.csv')
field_names = ["a", "b", "c"]
- field_types = [DataTypes.LONG, DataTypes.INT, DataTypes.STRING]
+ field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()]
data = [(1000, 1, "Hello"), (2000, 2, "Hello"), (4000, 4, "Hello"),
(5000, 8, "Hello")]
csv_source = self.prepare_csv_source(source_path, data, field_types,
field_names)
t_env = self.t_env
diff --git a/flink-python/pyflink/table/types.py
b/flink-python/pyflink/table/types.py
index f51e3f8..399e759 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -16,64 +16,145 @@
# limitations under the License.
################################################################################
+import calendar
+import ctypes
+import datetime
+import decimal
import sys
+import time
+from array import array
+from copy import copy
+from functools import reduce
+from threading import RLock
-if sys.version > '3':
- xrange = range
+from pyflink.util.utils import to_jarray
+from pyflink.java_gateway import get_gateway
-__all__ = ['DataTypes']
+if sys.version >= '3':
+ long = int
+ basestring = unicode = str
+
+__all__ = ['DataTypes', 'UserDefinedType', 'Row']
class DataType(object):
"""
- Base class for data types.
+ Describes the data type of a value in the table ecosystem. Instances of
this class can be used
+ to declare input and/or output types of operations.
+
+ :class:`DataType` has two responsibilities: declaring a logical type and
giving hints
+ about the physical representation of data to the optimizer. While the
logical type is mandatory,
+ hints are optional but useful at the edges to other APIs.
+
+ The logical type is independent of any physical representation and is
close to the "data type"
+ terminology of the SQL standard.
+
+ Physical hints are required at the edges of the table ecosystem. Hints
indicate the data format
+ that an implementation expects.
+
+ :param nullable: boolean, whether the type can be null (None) or not.
"""
- @classmethod
- def type_name(cls):
- return cls.__name__[:-4].lower()
+
+ def __init__(self, nullable=True):
+ self.nullable = nullable
+ self.conversion_cls = ''
+
+ def __repr__(self):
+ return '%s(%s)' % (self.__class__.__name__, str(self.nullable).lower())
+
+ def __str__(self, *args, **kwargs):
+ return self.__class__.type_name()
def __hash__(self):
- return hash(self.type_name())
+ return hash(str(self))
def __eq__(self, other):
- return self.type_name() == other.type_name()
+ return isinstance(other, self.__class__) and self.__dict__ ==
other.__dict__
def __ne__(self, other):
- return self.type_name() != other.type_name()
+ return not self.__eq__(other)
+ def not_null(self):
+ cp = copy(self)
+ cp.nullable = False
+ return cp
-class DataTypeSingleton(type):
- """
- Metaclass for DataType
- """
+ def nullable(self):
+ cp = copy(self)
+ cp.nullable = True
+ return cp
- _instances = {}
+ @classmethod
+ def type_name(cls):
+ return cls.__name__[:-4].upper()
+
+ def bridged_to(self, conversion_cls):
+ """
+ Adds a hint that data should be represented using the given class when
entering or leaving
+ the table ecosystem.
+
+ :param conversion_cls: the string representation of the conversion
class
+ """
+ self.conversion_cls = conversion_cls
+
+ def need_conversion(self):
+ """
+ Does this type need conversion between Python objects and internal SQL
+ objects?
+
+ This is used to avoid unnecessary conversion for
+ ArrayType/MultisetType/MapType/RowType.
+ """
+ return False
+
+ def to_sql_type(self, obj):
+ """
+ Converts a Python object into an internal SQL object.
+ """
+ return obj
- def __call__(cls):
- if cls not in cls._instances:
- cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
- return cls._instances[cls]
+ def from_sql_type(self, obj):
+ """
+ Converts an internal SQL object into a native Python object.
+ """
+ return obj
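
A minimal sketch of how the nullability flag above behaves, assuming the
module is importable as pyflink.table.types:

    from pyflink.table.types import BigIntType

    t = BigIntType()              # nullable by default
    assert t.nullable
    t_not_null = t.not_null()     # copy with nullable=False
    assert not t_not_null.nullable
    assert t != t_not_null        # __eq__ compares class and __dict__
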
class AtomicType(DataType):
"""
An internal type used to represent everything that is not
- null, arrays, structs, and maps.
+ arrays, rows, and maps.
+ """
+
+ def __init__(self, nullable=True):
+ super(AtomicType, self).__init__(nullable)
+
+
+class NullType(AtomicType):
+ """
+ Null type.
+
+ The data type representing None.
"""
+ def __init__(self):
+ super(NullType, self).__init__(True)
+
class NumericType(AtomicType):
"""
Numeric data types.
"""
+ def __init__(self, nullable=True):
+ super(NumericType, self).__init__(nullable)
+
class IntegralType(NumericType):
"""
Integral data types.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, nullable=True):
+ super(IntegralType, self).__init__(nullable)
class FractionalType(NumericType):
@@ -81,106 +162,1607 @@ class FractionalType(NumericType):
Fractional data types.
"""
+ def __init__(self, nullable=True):
+ super(FractionalType, self).__init__(nullable)
+
-class StringType(AtomicType):
+class CharType(AtomicType):
"""
- String data type. SQL VARCHAR
+ Char data type. SQL CHAR(n)
+
+ The serialized string representation is 'char(n)' where 'n' (default: 1) is
+ the number of characters. 'n' must have a value between 1 and 255 (both
+ inclusive).
+
+ :param length: int, the string representation length.
+ :param nullable: boolean, whether the type can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, length=1, nullable=True):
+ super(CharType, self).__init__(nullable)
+ self.length = length
+ def __repr__(self):
+ return 'CharType(%d, %s)' % (self.length, str(self.nullable).lower())
-class BooleanType(AtomicType):
+
+class VarCharType(AtomicType):
"""
- Boolean data types. SQL BOOLEAN
+ Varchar data type. SQL VARCHAR(n)
+
+ The serialized string representation is 'varchar(n)' where 'n' (default:
1) is the number of
+ characters. 'n' must have a value between 1 and 0x7fffffff (both
inclusive).
+
+ :param length: int, the maximum string representation length.
+ :param nullable: boolean, whether the type can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, length=1, nullable=True):
+ super(VarCharType, self).__init__(nullable)
+ self.length = length
+
+ def __repr__(self):
+ return "VarCharType(%d, %s)" % (self.length,
str(self.nullable).lower())
-class ByteType(IntegralType):
+class BinaryType(AtomicType):
"""
- Byte data type. SQL TINYINT
+ Binary (byte array) data type. SQL BINARY(n)
+
+ The serialized string representation is 'binary(n)' where 'n' (default: 1)
is the number of
+ bytes. 'n' must have a value between 1 and 0x7fffffff (both inclusive).
+
+ :param length: int, the number of bytes.
+ :param nullable: boolean, whether the type can be null (None) or not.
"""
+ def __init__(self, length=1, nullable=True):
+ super(BinaryType, self).__init__(nullable)
+ self.length = length
+
+ def __repr__(self):
+ return "BinaryType(%d, %s)" % (self.length, str(self.nullable).lower())
+
-class CharType(IntegralType):
+class VarBinaryType(AtomicType):
"""
- Char data type. SQL CHAR
+ Binary (byte array) data type. SQL VARBINARY(n)
+
+ The serialized string representation is 'varbinary(n)' where 'n' (default:
1) is the
+ maximum number of bytes. 'n' must have a value between 1 and 0x7fffffff
(both inclusive).
+
+ :param length: int, the maximum number of bytes.
+ :param nullable: boolean, whether the type can be null (None) or not.
"""
+ def __init__(self, length=1, nullable=True):
+ super(VarBinaryType, self).__init__(nullable)
+ self.length = length
+
+ def __repr__(self):
+ return "VarBinaryType(%d, %s)" % (self.length,
str(self.nullable).lower())
-class ShortType(IntegralType):
+
+class BooleanType(AtomicType):
"""
- Short data types. SQL SMALLINT (16bits)
+ Boolean data types. SQL BOOLEAN
+
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
+ def __init__(self, nullable=True):
+ super(BooleanType, self).__init__(nullable)
+
-class IntegerType(IntegralType):
+class TinyIntType(IntegralType):
+ """
+ Byte data type. SQL TINYINT (8bits)
+
+ :param nullable: boolean, whether the field can be null (None) or not.
+ """
+
+ def __init__(self, nullable=True):
+ super(TinyIntType, self).__init__(nullable)
+
+
+class SmallIntType(IntegralType):
+ """
+ Short data type. SQL SMALLINT (16bits)
+
+ :param nullable: boolean, whether the field can be null (None) or not.
+ """
+
+ def __init__(self, nullable=True):
+ super(SmallIntType, self).__init__(nullable)
+
+
+class IntType(IntegralType):
"""
Int data types. SQL INT (32bits)
+
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
+ def __init__(self, nullable=True):
+ super(IntType, self).__init__(nullable)
-class LongType(IntegralType):
+
+class BigIntType(IntegralType):
"""
Long data types. SQL BIGINT (64bits)
+
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
+ def __init__(self, nullable=True):
+ super(BigIntType, self).__init__(nullable)
+
class FloatType(FractionalType):
"""
Float data type. SQL FLOAT
+
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, nullable=True):
+ super(FloatType, self).__init__(nullable)
class DoubleType(FractionalType):
"""
Double data type. SQL DOUBLE
+
+ :param nullable: boolean, whether the field can be null (None) or not.
+ """
+
+ def __init__(self, nullable=True):
+ super(DoubleType, self).__init__(nullable)
+
+
+class DecimalType(FractionalType):
+ """
+ Decimal (decimal.Decimal) data type.
+
+ The DecimalType must have a fixed precision (the maximum total number of
+ digits) and scale (the number of digits to the right of the decimal
+ point). For example, (5, 2) can support values from -999.99 to 999.99.
+
+ The precision can be up to 38; the scale must be less than or equal to
+ the precision.
+
+ When creating a DecimalType, the default precision and scale are (10, 0).
+ When inferring a schema from decimal.Decimal objects, it will be
+ DecimalType(38, 18).
+
+ :param precision: the maximum total number of digits (default: 10)
+ :param scale: the number of digits on right side of dot. (default: 0)
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, precision=10, scale=0, nullable=True):
+ super(DecimalType, self).__init__(nullable)
+ assert 1 <= precision <= 38
+ assert 0 <= scale <= precision
+ self.precision = precision
+ self.scale = scale
+ self.has_precision_info = True # this is public API
+
+ def __repr__(self):
+ return "DecimalType(%d, %d, %s)" % (self.precision, self.scale,
str(self.nullable).lower())
class DateType(AtomicType):
"""
Date data type. SQL DATE
+
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, nullable=True):
+ super(DateType, self).__init__(nullable)
+
+ EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+ def need_conversion(self):
+ return True
+
+ def to_sql_type(self, d):
+ if d is not None:
+ return d.toordinal() - self.EPOCH_ORDINAL
+
+ def from_sql_type(self, v):
+ if v is not None:
+ return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
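
For illustration, dates are encoded as days since the Unix epoch, so the
following round trip holds (a sketch, assuming the module above is importable):

    import datetime
    from pyflink.table.types import DateType

    dt = DateType()
    assert dt.to_sql_type(datetime.date(1970, 1, 2)) == 1
    assert dt.from_sql_type(0) == datetime.date(1970, 1, 1)
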
class TimeType(AtomicType):
"""
Time data type. SQL TIME
+
+ The precision must be greater than or equal to 0 and less than or equal to
9.
+
+ :param precision: int, the number of digits of fractional seconds
(default: 0)
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ EPOCH_ORDINAL = calendar.timegm(time.localtime(0)) * 10**6
+
+ def __init__(self, precision=0, nullable=True):
+ super(TimeType, self).__init__(nullable)
+ assert 0 <= precision <= 9
+ self.precision = precision
+
+ def __repr__(self):
+ return "TimeType(%s, %s)" % (self.precision,
str(self.nullable).lower())
+
+ def need_conversion(self):
+ return True
+
+ def to_sql_type(self, t):
+ if t.tzinfo is not None:
+ offset = t.utcoffset()
+ offset = offset if offset else datetime.timedelta()
+ offset_microseconds =\
+ (offset.days * 86400 + offset.seconds) * 10 ** 6 +
offset.microseconds
+ else:
+ offset_microseconds = self.EPOCH_ORDINAL
+ minutes = t.hour * 60 + t.minute
+ seconds = minutes * 60 + t.second
+ return seconds * 10**6 + t.microsecond - offset_microseconds
+
+ def from_sql_type(self, t):
+ if t is not None:
+ seconds, microseconds = divmod(t, 10**6)
+ minutes, seconds = divmod(seconds, 60)
+ hours, minutes = divmod(minutes, 60)
+ return datetime.time(hours, minutes, seconds, microseconds)
+
+
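
A small sketch of the internal TIME representation: microseconds since
midnight, shifted by the local UTC offset for naive times. The decode path
is timezone-independent:

    import datetime
    from pyflink.table.types import TimeType

    assert TimeType().from_sql_type(3600 * 10 ** 6) == datetime.time(1, 0, 0)
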
+class TimestampKind(object):
+ """
+ Timestamp kind, i.e. the time attribute metadata attached to timestamps.
+ """
+ REGULAR = 0
+ ROWTIME = 1
+ PROCTIME = 2
class TimestampType(AtomicType):
"""
Timestamp data type. SQL TIMESTAMP
+
+ The precision must be greater than or equal to 0 and less than or equal to
9.
+
+ :param kind: the time attribute metadata (default: TimestampKind.REGULAR)
+ :param precision: int, the number of digits of fractional seconds
(default: 6)
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- __metaclass__ = DataTypeSingleton
+ def __init__(self, kind=TimestampKind.REGULAR, precision=6, nullable=True):
+ super(TimestampType, self).__init__(nullable)
+ assert 0 <= kind <= 2
+ assert 0 <= precision <= 9
+ self.kind = kind
+ self.precision = precision
+ def __repr__(self):
+ return "TimestampType(%s, %s, %s)" % (
+ self.kind, self.precision, str(self.nullable).lower())
-class DataTypes(object):
+ def need_conversion(self):
+ return True
+
+ def to_sql_type(self, dt):
+ if dt is not None:
+ seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+ else time.mktime(dt.timetuple()))
+ return int(seconds) * 10**6 + dt.microsecond
+
+ def from_sql_type(self, ts):
+ if ts is not None:
+ # using int to avoid precision loss in float
+ return datetime.datetime.fromtimestamp(ts //
10**6).replace(microsecond=ts % 10**6)
+
+
+class ArrayType(DataType):
+ """
+ Array data type.
+
+ :param element_type: :class:`DataType` of each element in the array.
+ :param nullable: boolean, whether the field can be null (None) or not.
+ """
+
+ def __init__(self, element_type, nullable=True):
+ """
+ >>> ArrayType(VarCharType()) == ArrayType(VarCharType())
+ True
+ >>> ArrayType(VarCharType()) == ArrayType(BigIntType())
+ False
+ """
+ assert isinstance(element_type, DataType), \
+ "element_type %s should be an instance of %s" % (element_type,
DataType)
+ super(ArrayType, self).__init__(nullable)
+ self.element_type = element_type
+
+ def __repr__(self):
+ return "ArrayType(%s, %s)" % (repr(self.element_type),
str(self.nullable).lower())
+
+ def need_conversion(self):
+ return self.element_type.need_conversion()
+
+ def to_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and [self.element_type.to_sql_type(v) for v in obj]
+
+ def from_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and [self.element_type.from_sql_type(v) for v in obj]
+
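
An illustrative sketch: element conversion only kicks in when the element type
itself needs it (e.g. DATE); otherwise the list is passed through as-is:

    import datetime
    from pyflink.table.types import ArrayType, BigIntType, DateType

    assert ArrayType(BigIntType()).to_sql_type([1, 2, 3]) == [1, 2, 3]
    assert ArrayType(DateType()).to_sql_type([datetime.date(1970, 1, 2)]) == [1]
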
+
+class MapType(DataType):
+ """
+ Map data type.
+
+ :param key_type: :class:`DataType` of the keys in the map.
+ :param value_type: :class:`DataType` of the values in the map.
+ :param nullable: boolean, whether the field can be null (None) or not.
+
+ Keys in a map data type are not allowed to be null (None).
+ """
+
+ def __init__(self, key_type, value_type, nullable=True):
+ """
+ >>> (MapType(VarCharType(nullable=False), IntType())
+ ... == MapType(VarCharType(nullable=False), IntType()))
+ True
+ >>> (MapType(VarCharType(nullable=False), IntType())
+ ... == MapType(VarCharType(nullable=False), FloatType()))
+ False
+ """
+ assert isinstance(key_type, DataType), \
+ "key_type %s should be an instance of %s" % (key_type, DataType)
+ assert isinstance(value_type, DataType), \
+ "value_type %s should be an instance of %s" % (value_type,
DataType)
+ super(MapType, self).__init__(nullable)
+ self.key_type = key_type
+ self.value_type = value_type
+
+ def __repr__(self):
+ return "MapType(%s, %s, %s)" % (
+ repr(self.key_type), repr(self.value_type),
str(self.nullable).lower())
+
+ def need_conversion(self):
+ return self.key_type.need_conversion() or
self.value_type.need_conversion()
+
+ def to_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and dict((self.key_type.to_sql_type(k),
self.value_type.to_sql_type(v))
+ for k, v in obj.items())
+
+ def from_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and dict((self.key_type.from_sql_type(k),
self.value_type.from_sql_type(v))
+ for k, v in obj.items())
+
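
A hypothetical usage sketch: keys and values are converted independently, and
only when their types require it:

    import datetime
    from pyflink.table.types import MapType, VarCharType, DateType

    m = MapType(VarCharType(), DateType())
    assert m.to_sql_type({"d": datetime.date(1970, 1, 2)}) == {"d": 1}
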
+
+class MultisetType(DataType):
+ """
+ MultisetType data type.
+
+ :param element_type: :class:`DataType` of each element in the multiset.
+ :param nullable: boolean, whether the field can be null (None) or not.
"""
- Utils for types
- """
- STRING = StringType()
- BOOLEAN = BooleanType()
- BYTE = ByteType()
- CHAR = CharType()
- SHORT = ShortType()
- INT = IntegerType()
- LONG = LongType()
- FLOAT = FloatType()
- DOUBLE = DoubleType()
- DATE = DateType()
- TIME = TimeType()
- TIMESTAMP = TimestampType()
+
+ def __init__(self, element_type, nullable=True):
+ """
+ >>> MultisetType(VarCharType()) == MultisetType(VarCharType())
+ True
+ >>> MultisetType(VarCharType()) == MultisetType(BigIntType())
+ False
+ """
+ assert isinstance(element_type, DataType), \
+ "element_type %s should be an instance of %s" % (element_type,
DataType)
+ super(MultisetType, self).__init__(nullable)
+ self.element_type = element_type
+
+ def __repr__(self):
+ return "MultisetType(%s, %s)" % (repr(self.element_type),
str(self.nullable).lower())
+
+ def need_conversion(self):
+ return self.element_type.need_conversion()
+
+ def to_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and [self.element_type.to_sql_type(v) for v in obj]
+
+ def from_sql_type(self, obj):
+ if not self.need_conversion():
+ return obj
+ return obj and [self.element_type.from_sql_type(v) for v in obj]
+
+
+class RowField(object):
+ """
+ A field in :class:`RowType`.
+
+ :param name: string, name of the field.
+ :param data_type: :class:`DataType` of the field.
+ :param description: string, description of the field.
+ """
+
+ def __init__(self, name, data_type, description=None):
+ """
+ >>> (RowField("f1", VarCharType())
+ ... == RowField("f1", VarCharType()))
+ True
+ >>> (RowField("f1", VarCharType())
+ ... == RowField("f2", VarCharType()))
+ False
+ """
+ assert isinstance(data_type, DataType), \
+ "data_type %s should be an instance of %s" % (data_type, DataType)
+ assert isinstance(name, basestring), "field name %s should be string"
% name
+ if not isinstance(name, str):
+ name = name.encode('utf-8')
+ if description is not None:
+ assert isinstance(description, basestring), \
+ "description %s should be string" % description
+ if not isinstance(description, str):
+ description = description.encode('utf-8')
+ self.name = name
+ self.data_type = data_type
+ self.description = '...' if description is None else description
+
+ def __repr__(self):
+ return "RowField(%s, %s, %s)" % (self.name, repr(self.data_type),
self.description)
+
+ def __str__(self, *args, **kwargs):
+ return "RowField(%s, %s)" % (self.name, self.data_type)
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self.__dict__ ==
other.__dict__
+
+ def need_conversion(self):
+ return self.data_type.need_conversion()
+
+ def to_sql_type(self, obj):
+ return self.data_type.to_sql_type(obj)
+
+ def from_sql_type(self, obj):
+ return self.data_type.from_sql_type(obj)
+
+
+class RowType(DataType):
+ """
+ Row type, consisting of a list of :class:`RowField`.
+
+ This is the data type representing a :class:`Row`.
+
+ Iterating a :class:`RowType` will iterate its :class:`RowField`\\s.
+ A contained :class:`RowField` can be accessed by name or position.
+
+ >>> row1 = RowType([RowField("f1", VarCharType())])
+ >>> row1["f1"]
+ RowField(f1, VarCharType(1))
+ >>> row1[0]
+ RowField(f1, VarCharType(1))
+ """
+
+ def __init__(self, fields=None, nullable=True):
+ """
+ >>> row1 = RowType([RowField("f1", VarCharType())])
+ >>> row2 = RowType([RowField("f1", VarCharType())])
+ >>> row1 == row2
+ True
+ >>> row1 = RowType([RowField("f1", VarCharType())])
+ >>> row2 = RowType([RowField("f1", VarCharType()),
+ ... RowField("f2", IntType())])
+ >>> row1 == row2
+ False
+ """
+ super(RowType, self).__init__(nullable)
+ if not fields:
+ self.fields = []
+ self.names = []
+ else:
+ self.fields = fields
+ self.names = [f.name for f in fields]
+ assert all(isinstance(f, RowField) for f in fields), \
+ "fields should be a list of RowField"
+ # Precalculated list of fields that need conversion with
+ # from_sql_type/to_sql_type functions
+ self._need_conversion = [f.need_conversion() for f in self]
+ self._need_serialize_any_field = any(self._need_conversion)
+
+ def add(self, field, data_type=None):
+ """
+ Constructs a RowType by adding new elements to it to define the
schema. The method accepts
+ either:
+
+ a) A single parameter which is a RowField object.
+ b) 2 parameters as (name, data_type). The data_type parameter must be a
+ DataType object.
+
+ >>> row1 = RowType().add("f1", VarCharType()).add("f2", VarCharType())
+ >>> row2 = RowType([RowField("f1", VarCharType()), RowField("f2",
VarCharType())])
+ >>> row1 == row2
+ True
+ >>> row1 = RowType().add(RowField("f1", VarCharType()))
+ >>> row2 = RowType([RowField("f1", VarCharType())])
+ >>> row1 == row2
+ True
+ >>> row2 = RowType([RowField("f1", VarCharType())])
+ >>> row1 == row2
+ True
+
+ :param field: Either the name of the field or a RowField object
+ :param data_type: If present, the DataType of the RowField to create
+ :return: a new updated RowType
+ """
+ if isinstance(field, RowField):
+ self.fields.append(field)
+ self.names.append(field.name)
+ else:
+ if isinstance(field, str) and data_type is None:
+ raise ValueError("Must specify DataType if passing name of
row_field to create.")
+
+ self.fields.append(RowField(field, data_type))
+ self.names.append(field)
+ # Precalculated list of fields that need conversion with
+ # from_sql_type/to_sql_type functions
+ self._need_conversion = [f.need_conversion() for f in self]
+ self._need_serialize_any_field = any(self._need_conversion)
+ return self
+
+ def __iter__(self):
+ """
+ Iterate the fields.
+ """
+ return iter(self.fields)
+
+ def __len__(self):
+ """
+ Returns the number of fields.
+ """
+ return len(self.fields)
+
+ def __getitem__(self, key):
+ """
+ Accesses fields by name or slice.
+ """
+ if isinstance(key, str):
+ for field in self:
+ if field.name == key:
+ return field
+ raise KeyError('No RowField named {0}'.format(key))
+ elif isinstance(key, int):
+ try:
+ return self.fields[key]
+ except IndexError:
+ raise IndexError('RowType index out of range')
+ elif isinstance(key, slice):
+ return RowType(self.fields[key])
+ else:
+ raise TypeError('RowType keys should be strings, integers or
slices')
+
+ def __repr__(self):
+ return "RowType(%s)" % ",".join(repr(field) for field in self)
+
+ def field_names(self):
+ """
+ Returns all field names in a list.
+
+ >>> row = RowType([RowField("f1", VarCharType())])
+ >>> row.field_names()
+ ['f1']
+ """
+ return list(self.names)
+
+ def need_conversion(self):
+ # We need to convert Row()/namedtuple into tuple()
+ return True
+
+ def to_sql_type(self, obj):
+ if obj is None:
+ return
+
+ if self._need_serialize_any_field:
+ # Only calling to_sql_type function for fields that need conversion
+ if isinstance(obj, dict):
+ return tuple(f.to_sql_type(obj.get(n)) if c else obj.get(n)
+ for n, f, c in zip(self.names, self.fields,
self._need_conversion))
+ elif isinstance(obj, (tuple, list)):
+ return tuple(f.to_sql_type(v) if c else v
+ for f, v, c in zip(self.fields, obj,
self._need_conversion))
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(f.to_sql_type(d.get(n)) if c else d.get(n)
+ for n, f, c in zip(self.names, self.fields,
self._need_conversion))
+ else:
+ raise ValueError("Unexpected tuple %r with RowType" % obj)
+ else:
+ if isinstance(obj, dict):
+ return tuple(obj.get(n) for n in self.names)
+ elif isinstance(obj, Row) and getattr(obj, "_from_dict", False):
+ return tuple(obj[n] for n in self.names)
+ elif isinstance(obj, (list, tuple)):
+ return tuple(obj)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ return tuple(d.get(n) for n in self.names)
+ else:
+ raise ValueError("Unexpected tuple %r with RowType" % obj)
+
+ def from_sql_type(self, obj):
+ if obj is None:
+ return
+ if isinstance(obj, Row):
+ # it's already converted by pickler
+ return obj
+ if self._need_serialize_any_field:
+ # Only calling from_sql_type function for fields that need
conversion
+ values = [f.from_sql_type(v) if c else v
+ for f, v, c in zip(self.fields, obj,
self._need_conversion)]
+ else:
+ values = obj
+ return _create_row(self.names, values)
+
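
A brief sketch of how a RowType flattens Python values into positional tuples
and rebuilds Row instances on the way back (field order is taken from the
schema):

    from pyflink.table.types import RowType, RowField, VarCharType, BigIntType

    schema = RowType([RowField("name", VarCharType(10)),
                      RowField("age", BigIntType())])
    assert schema.to_sql_type({"name": "Alice", "age": 11}) == ("Alice", 11)
    restored = schema.from_sql_type(("Alice", 11))
    assert restored.name == "Alice" and restored.age == 11
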
+
+class UserDefinedType(DataType):
+ """
+ User-defined type (UDT).
+
+ .. note:: WARN: Flink Internal Use Only
+ """
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+ @classmethod
+ def type_name(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sql_type(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sql_type().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def java_udt(cls):
+ """
+ The class name of the paired Java UDT (could be '', if there
+ is no corresponding one).
+ """
+ return ''
+
+ def need_conversion(self):
+ return True
+
+ @classmethod
+ def _cached_sql_type(cls):
+ """
+ Caches the sql_type() on the class, because it is heavily used in
+ `to_sql_type`.
+ """
+ if not hasattr(cls, "__cached_sql_type"):
+ cls.__cached_sql_type = cls.sql_type()
+ return cls.__cached_sql_type
+
+ def to_sql_type(self, obj):
+ if obj is not None:
+ return self._cached_sql_type().to_sql_type(self.serialize(obj))
+
+ def from_sql_type(self, obj):
+ v = self._cached_sql_type().from_sql_type(obj)
+ if v is not None:
+ return self.deserialize(v)
+
+ def serialize(self, obj):
+ """
+ Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
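
A hypothetical minimal UDT, loosely modelled on the ExamplePointUDT used in
the tests above; the names here are illustrative only and not part of this
commit:

    from pyflink.table.types import ArrayType, DoubleType, UserDefinedType

    class PointUDT(UserDefinedType):

        @classmethod
        def sql_type(cls):
            # points are stored as a DOUBLE array [x, y]
            return ArrayType(DoubleType())

        @classmethod
        def module(cls):
            return "__main__"

        def serialize(self, obj):
            return [obj.x, obj.y]

        def deserialize(self, datum):
            return Point(datum[0], datum[1])

    class Point(object):
        __UDT__ = PointUDT()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    assert PointUDT().to_sql_type(Point(1.0, 2.0)) == [1.0, 2.0]
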
+
+# Mapping Python types to Flink SQL types
+_type_mappings = {
+ bool: BooleanType(),
+ int: BigIntType(),
+ float: DoubleType(),
+ str: VarCharType(0x7fffffff),
+ bytearray: VarBinaryType(0x7fffffff),
+ decimal.Decimal: DecimalType(38, 18),
+ datetime.date: DateType(),
+ datetime.datetime: TimestampType(),
+ datetime.time: TimeType(),
+}
+
+if sys.version < "3":
+ _type_mappings.update({
+ unicode: VarCharType(0x7fffffff),
+ long: BigIntType(),
+ })
+
+# Mapping Python array types to Flink SQL types
+# We should be careful here. The size of these types in Python depends on the
+# C implementation. We need to make sure that this conversion does not lose
+# any precision. Also, the JVM only supports signed types; when converting
+# unsigned types, keep in mind that they require 1 more bit when stored as
+# signed types.
+#
+# Reference for C integer size, see:
+# ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types
<limits.h>.
+# Reference for python array typecode, see:
+# https://docs.python.org/2/library/array.html
+# https://docs.python.org/3.6/library/array.html
+# Reference for JVM's supported integral types:
+# http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1
+
+_array_signed_int_typecode_ctype_mappings = {
+ 'b': ctypes.c_byte,
+ 'h': ctypes.c_short,
+ 'i': ctypes.c_int,
+ 'l': ctypes.c_long,
+}
+
+_array_unsigned_int_typecode_ctype_mappings = {
+ 'B': ctypes.c_ubyte,
+ 'H': ctypes.c_ushort,
+ 'I': ctypes.c_uint,
+ 'L': ctypes.c_ulong
+}
+
+
+def _int_size_to_type(size):
+ """
+ Returns the data type from the size of integers.
+ """
+ if size <= 8:
+ return TinyIntType()
+ if size <= 16:
+ return SmallIntType()
+ if size <= 32:
+ return IntType()
+ if size <= 64:
+ return BigIntType()
+
+
+# The list of all supported array typecodes is stored here
+_array_type_mappings = {
+ # Warning: the actual properties of float and double in C are not
+ # specified by the C standard. On almost every system supported by both
+ # Python and the JVM, they are the IEEE 754 single-precision and IEEE 754
+ # double-precision binary floating-point formats, and we assume the same
+ # here for now.
+ 'f': FloatType(),
+ 'd': DoubleType()
+}
+
+# compute array typecode mappings for signed integer types
+for _typecode in _array_signed_int_typecode_ctype_mappings.keys():
+ size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode])
* 8
+ dt = _int_size_to_type(size)
+ if dt is not None:
+ _array_type_mappings[_typecode] = dt
+
+# compute array typecode mappings for unsigned integer types
+for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys():
+ # The JVM does not have unsigned types, so use a signed type that is at
+ # least 1 bit larger for storage
+ size =
ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1
+ dt = _int_size_to_type(size)
+ if dt is not None:
+ _array_type_mappings[_typecode] = dt
+
+# Type code 'u' in Python's array is deprecated since version 3.3, and will be
+# removed in version 4.0. See: https://docs.python.org/3/library/array.html
+if sys.version_info[0] < 4:
+ # it can be 16 bits or 32 bits depending on the platform
+ _array_type_mappings['u'] = CharType(ctypes.sizeof(ctypes.c_wchar))
+
+# Type code 'c' is only available in Python 2
+if sys.version_info[0] < 3:
+ _array_type_mappings['c'] = CharType(ctypes.sizeof(ctypes.c_char))
+
+
+def _infer_type(obj):
+ """
+ Infers the data type from obj.
+ """
+ if obj is None:
+ return NullType()
+
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
+ data_type = _type_mappings.get(type(obj))
+ if data_type is not None:
+ return data_type
+
+ if isinstance(obj, dict):
+ for key, value in obj.items():
+ if key is not None and value is not None:
+ return MapType(_infer_type(key).not_null(), _infer_type(value))
+ else:
+ return MapType(NullType(), NullType())
+ elif isinstance(obj, list):
+ for v in obj:
+ if v is not None:
+ return ArrayType(_infer_type(obj[0]))
+ else:
+ return ArrayType(NullType())
+ elif isinstance(obj, array):
+ if obj.typecode in _array_type_mappings:
+ return ArrayType(_array_type_mappings[obj.typecode].not_null())
+ else:
+ raise TypeError("not supported type: array(%s)" % obj.typecode)
+ else:
+ try:
+ return _infer_schema(obj)
+ except TypeError:
+ raise TypeError("not supported type: %s" % type(obj))
+
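
A rough illustration of the inference rules above (an internal helper, so
subject to change):

    from pyflink.table.types import _infer_type, ArrayType, BigIntType, DoubleType

    assert _infer_type(1) == BigIntType()
    assert _infer_type(1.0) == DoubleType()
    assert _infer_type([1, 2]) == ArrayType(BigIntType())
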
+
+def _infer_schema(row, names=None):
+ """
+ Infers the schema from dict/row/namedtuple/object.
+ """
+ if isinstance(row, dict): # dict
+ items = sorted(row.items())
+
+ elif isinstance(row, (tuple, list)):
+ if hasattr(row, "_fields"): # namedtuple and Row
+ items = zip(row._fields, tuple(row))
+ else:
+ if names is None:
+ names = ['_%d' % i for i in range(1, len(row) + 1)]
+ elif len(names) < len(row):
+ names.extend('_%d' % i for i in range(len(names) + 1, len(row)
+ 1))
+ items = zip(names, row)
+
+ elif hasattr(row, "__dict__"): # object
+ items = sorted(row.__dict__.items())
+
+ else:
+ raise TypeError("Can not infer schema for type: %s" % type(row))
+
+ fields = [RowField(k, _infer_type(v)) for k, v in items]
+ return RowType(fields)
+
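
A short sketch of schema inference from a dict (for dicts the fields are
sorted by name):

    from pyflink.table.types import _infer_schema

    schema = _infer_schema({"name": "Alice", "age": 11})
    assert schema.field_names() == ['age', 'name']
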
+
+def _has_nulltype(dt):
+ """
+ Returns whether there is NullType in `dt` or not.
+ """
+ if isinstance(dt, RowType):
+ return any(_has_nulltype(f.data_type) for f in dt.fields)
+ elif isinstance(dt, ArrayType) or isinstance(dt, MultisetType):
+ return _has_nulltype(dt.element_type)
+ elif isinstance(dt, MapType):
+ return _has_nulltype(dt.key_type) or _has_nulltype(dt.value_type)
+ else:
+ return isinstance(dt, NullType)
+
+
+def _merge_type(a, b, name=None):
+ if name is None:
+ def new_msg(msg):
+ return msg
+
+ def new_name(n):
+ return "field %s" % n
+ else:
+ def new_msg(msg):
+ return "%s: %s" % (name, msg)
+
+ def new_name(n):
+ return "field %s in %s" % (n, name)
+
+ if isinstance(a, NullType):
+ return b
+ elif isinstance(b, NullType):
+ return a
+ elif type(a) is not type(b):
+ # TODO: type cast (such as int -> long)
+ raise TypeError(new_msg("Can not merge type %s and %s" % (type(a),
type(b))))
+
+ # same type
+ if isinstance(a, RowType):
+ nfs = dict((f.name, f.data_type) for f in b.fields)
+ fields = [RowField(f.name, _merge_type(f.data_type, nfs.get(f.name,
None),
+ name=new_name(f.name)))
+ for f in a.fields]
+ names = set([f.name for f in fields])
+ for n in nfs:
+ if n not in names:
+ fields.append(RowField(n, nfs[n]))
+ return RowType(fields)
+
+ elif isinstance(a, ArrayType):
+ return ArrayType(_merge_type(a.element_type, b.element_type,
+ name='element in array %s' % name))
+
+ elif isinstance(a, MultisetType):
+ return MultisetType(_merge_type(a.element_type, b.element_type,
+ name='element in multiset %s' % name))
+
+ elif isinstance(a, MapType):
+ return MapType(_merge_type(a.key_type, b.key_type, name='key of map
%s' % name),
+ _merge_type(a.value_type, b.value_type, name='value of
map %s' % name))
+ else:
+ return a
+
+
+def _infer_schema_from_data(elements, names=None):
+ """
+ Infers schema from list of Row or tuple.
+
+ :param elements: list of Row or tuple
+ :param names: list of column names
+ :return: :class:`RowType`
+ """
+ if not elements:
+ raise ValueError("can not infer schema from empty data set")
+ schema = reduce(_merge_type, (_infer_schema(row, names) for row in
elements))
+ if _has_nulltype(schema):
+ raise ValueError("Some column types cannot be determined after
inferring")
+ return schema
+
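
A minimal sketch: merging the per-row schemas resolves NULLs coming from
missing values, as long as some row supplies a concrete type:

    from pyflink.table.types import Row, _infer_schema_from_data

    rows = [Row(name="Alice", age=None), Row(name="Bob", age=5)]
    assert _infer_schema_from_data(rows).field_names() == ['age', 'name']
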
+
+def _need_converter(data_type):
+ if isinstance(data_type, RowType):
+ return True
+ elif isinstance(data_type, ArrayType) or isinstance(data_type,
MultisetType):
+ return _need_converter(data_type.element_type)
+ elif isinstance(data_type, MapType):
+ return _need_converter(data_type.key_type) or
_need_converter(data_type.value_type)
+ elif isinstance(data_type, NullType):
+ return True
+ else:
+ return False
+
+
+def _create_converter(data_type):
+ """
+ Creates a converter to drop the names of fields in obj.
+ """
+ if not _need_converter(data_type):
+ return lambda x: x
+
+ if isinstance(data_type, ArrayType) or isinstance(data_type, MultisetType):
+ conv = _create_converter(data_type.element_type)
+ return lambda row: [conv(v) for v in row]
+
+ elif isinstance(data_type, MapType):
+ kconv = _create_converter(data_type.key_type)
+ vconv = _create_converter(data_type.value_type)
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
+
+ elif isinstance(data_type, NullType):
+ return lambda x: None
+
+ elif not isinstance(data_type, RowType):
+ return lambda x: x
+
+ # dataType must be RowType
+ names = [f.name for f in data_type.fields]
+ converters = [_create_converter(f.data_type) for f in data_type.fields]
+ convert_fields = any(_need_converter(f.data_type) for f in
data_type.fields)
+
+ def convert_row(obj):
+ if obj is None:
+ return
+
+ if isinstance(obj, (tuple, list)):
+ if convert_fields:
+ return tuple(conv(v) for v, conv in zip(obj, converters))
+ else:
+ return tuple(obj)
+
+ if isinstance(obj, dict):
+ d = obj
+ elif hasattr(obj, "__dict__"): # object
+ d = obj.__dict__
+ else:
+ raise TypeError("Unexpected obj type: %s" % type(obj))
+
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names,
converters)])
+ else:
+ return tuple([d.get(name) for name in names])
+
+ return convert_row
+
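
A rough usage example: the converter strips field names so that a dict (or
Row) becomes a plain positional tuple matching the schema:

    from pyflink.table.types import Row, _create_converter, _infer_schema

    schema = _infer_schema(Row(name="Alice", age=11))
    conv = _create_converter(schema)
    assert conv({"age": 11, "name": "Alice"}) == (11, "Alice")
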
+
+_python_java_types_mapping = None
+_python_java_types_mapping_lock = RLock()
+_primitive_array_element_types = {BooleanType, TinyIntType, SmallIntType,
IntType, BigIntType,
+ FloatType, DoubleType}
+
+
+def _to_java_type(data_type):
+ """
+ Converts Python type to Java type.
+ """
+
+ global _python_java_types_mapping
+ global _python_java_types_mapping_lock
+
+ gateway = get_gateway()
+ Types = gateway.jvm.org.apache.flink.table.api.Types
+
+ if _python_java_types_mapping is None:
+ with _python_java_types_mapping_lock:
+ _python_java_types_mapping = {
+ BooleanType: Types.BOOLEAN(),
+ TinyIntType: Types.BYTE(),
+ SmallIntType: Types.SHORT(),
+ IntType: Types.INT(),
+ BigIntType: Types.LONG(),
+ FloatType: Types.FLOAT(),
+ DoubleType: Types.DOUBLE(),
+ DecimalType: Types.DECIMAL(),
+ DateType: Types.SQL_DATE(),
+ TimeType: Types.SQL_TIME(),
+ TimestampType: Types.SQL_TIMESTAMP(),
+ CharType: Types.STRING(),
+ VarCharType: Types.STRING(),
+ BinaryType: Types.PRIMITIVE_ARRAY(Types.BYTE()),
+ VarBinaryType: Types.PRIMITIVE_ARRAY(Types.BYTE())
+ }
+
+ # NullType
+ if isinstance(data_type, NullType):
+ # null type is still not supported in Java
+ raise NotImplementedError
+
+ # basic types
+ elif type(data_type) in _python_java_types_mapping:
+ return _python_java_types_mapping[type(data_type)]
+
+ # ArrayType
+ elif isinstance(data_type, ArrayType):
+ if type(data_type.element_type) in _primitive_array_element_types:
+ return Types.PRIMITIVE_ARRAY(_to_java_type(data_type.element_type))
+ elif isinstance(data_type.element_type, VarCharType) or isinstance(
+ data_type.element_type, CharType):
+ return gateway.jvm.org.apache.flink.api.common.typeinfo.\
+ BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
+ else:
+ return Types.OBJECT_ARRAY(_to_java_type(data_type.element_type))
+
+ # MapType
+ elif isinstance(data_type, MapType):
+ return Types.MAP(_to_java_type(data_type.key_type),
_to_java_type(data_type.value_type))
+
+ # MultisetType
+ elif isinstance(data_type, MultisetType):
+ return Types.MULTISET(_to_java_type(data_type.element_type))
+
+ # RowType
+ elif isinstance(data_type, RowType):
+ return Types.ROW(
+ to_jarray(gateway.jvm.String, data_type.field_names()),
+ to_jarray(gateway.jvm.TypeInformation,
+ [_to_java_type(f.data_type) for f in data_type.fields]))
+
+ # UserDefinedType
+ elif isinstance(data_type, UserDefinedType):
+ if data_type.java_udt():
+ return
gateway.jvm.org.apache.flink.util.InstantiationUtil.instantiate(
+ gateway.jvm.Class.forName(data_type.java_udt()))
+ else:
+ return _to_java_type(data_type.sql_type())
+
+ else:
+ raise TypeError("Not supported type: %s" % data_type)
+
+
+def _create_row(fields, values):
+ row = Row(*values)
+ row._fields = fields
+ return row
+
+
+class Row(tuple):
+ """
+ A row in Table.
+ The fields in it can be accessed:
+
+ * like attributes (``row.key``)
+ * like dictionary values (``row[key]``)
+
+ ``key in row`` will search through row keys.
+
+ Row can be used to create a row object by using named arguments;
+ the fields will be sorted by name. It is not allowed to omit
+ a named argument to indicate that the value is None or missing; it
+ should be explicitly set to None in that case.
+
+ >>> row = Row(name="Alice", age=11)
+ >>> row
+ Row(age=11, name='Alice')
+ >>> row['name'], row['age']
+ ('Alice', 11)
+ >>> row.name, row.age
+ ('Alice', 11)
+ >>> 'name' in row
+ True
+ >>> 'wrong_key' in row
+ False
+
+ Row can also be used to create another Row like class, then it
+ could be used to create Row objects, such as
+
+ >>> Person = Row("name", "age")
+ >>> Person
+ <Row(name, age)>
+ >>> 'name' in Person
+ True
+ >>> 'wrong_key' in Person
+ False
+ >>> Person("Alice", 11)
+ Row(name='Alice', age=11)
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if args and kwargs:
+ raise ValueError("Can not use both args "
+ "and kwargs to create Row")
+ if kwargs:
+ # create row objects
+ names = sorted(kwargs.keys())
+ row = tuple.__new__(cls, [kwargs[n] for n in names])
+ row._fields = names
+ row._from_dict = True
+ return row
+
+ else:
+ # create row class or objects
+ return tuple.__new__(cls, args)
+
+ def as_dict(self, recursive=False):
+ """
+ Returns as a dict.
+
+ :param recursive: turns the nested Rows into dicts (default: False).
+
+ >>> Row(name="Alice", age=11).as_dict() == {'name': 'Alice', 'age': 11}
+ True
+ >>> row = Row(key=1, value=Row(name='a', age=2))
+ >>> row.as_dict() == {'key': 1, 'value': Row(age=2, name='a')}
+ True
+ >>> row.as_dict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
+ True
+ """
+ if not hasattr(self, "_fields"):
+ raise TypeError("Cannot convert a Row class into dict")
+
+ if recursive:
+ def conv(obj):
+ if isinstance(obj, Row):
+ return obj.as_dict(True)
+ elif isinstance(obj, list):
+ return [conv(o) for o in obj]
+ elif isinstance(obj, dict):
+ return dict((k, conv(v)) for k, v in obj.items())
+ else:
+ return obj
+
+ return dict(zip(self._fields, (conv(o) for o in self)))
+ else:
+ return dict(zip(self._fields, self))
+
+ def __contains__(self, item):
+ if hasattr(self, "_fields"):
+ return item in self._fields
+ else:
+ return super(Row, self).__contains__(item)
+
+ # let the object act like a class
+ def __call__(self, *args):
+ """
+ Creates new Row object
+ """
+ if len(args) > len(self):
+ raise ValueError("Can not create Row with fields %s, expected %d
values "
+ "but got %s" % (self, len(self), args))
+ return _create_row(self, args)
+
+ def __getitem__(self, item):
+ if isinstance(item, (int, slice)):
+ return super(Row, self).__getitem__(item)
+ try:
+ # it will be slow when it has many fields,
+ # but this will not be used in normal cases
+ idx = self._fields.index(item)
+ return super(Row, self).__getitem__(idx)
+ except IndexError:
+ raise KeyError(item)
+ except ValueError:
+ raise ValueError(item)
+
+ def __getattr__(self, item):
+ if item.startswith("_"):
+ raise AttributeError(item)
+ try:
+ # it will be slow when it has many fields,
+ # but this will not be used in normal cases
+ idx = self._fields.index(item)
+ return self[idx]
+ except IndexError:
+ raise AttributeError(item)
+ except ValueError:
+ raise AttributeError(item)
+
+ def __setattr__(self, key, value):
+ if key != '_fields' and key != "_from_dict":
+ raise Exception("Row is read-only")
+ self.__dict__[key] = value
+
+ def __reduce__(self):
+ """
+ Returns a tuple so Python knows how to pickle Row.
+ """
+ if hasattr(self, "_fields"):
+ return _create_row, (self._fields, tuple(self))
+ else:
+ return tuple.__reduce__(self)
+
+ def __repr__(self):
+ """
+ Printable representation of Row used in Python REPL.
+ """
+ if hasattr(self, "_fields"):
+ return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+ for k, v in zip(self._fields,
tuple(self)))
+ else:
+ return "<Row(%s)>" % ", ".join(self)
+
+
+_acceptable_types = {
+ BooleanType: (bool,),
+ TinyIntType: (int, long),
+ SmallIntType: (int, long),
+ IntType: (int, long),
+ BigIntType: (int, long),
+ FloatType: (float,),
+ DoubleType: (float,),
+ DecimalType: (decimal.Decimal,),
+ CharType: (str, unicode),
+ VarCharType: (str, unicode),
+ BinaryType: (bytearray,),
+ VarBinaryType: (bytearray,),
+ DateType: (datetime.date, datetime.datetime),
+ TimeType: (datetime.time,),
+ TimestampType: (datetime.datetime,),
+ ArrayType: (list, tuple, array),
+ MapType: (dict,),
+ RowType: (tuple, list, dict),
+}
+
+
+def _create_type_verifier(data_type, name=None):
+ """
+    Creates a verifier that checks the type of obj against data_type and raises a TypeError if they
+    do not match.
+
+    This verifier also checks the value of obj against data_type and raises a ValueError if it's
+    not within the allowed range, e.g. using 128 as TinyIntType will overflow. Note that Python
+    float is not checked, so it will become infinity when cast to Java float if it overflows.
+
+ >>> _create_type_verifier(RowType([]))(None)
+ >>> _create_type_verifier(VarCharType())("")
+ >>> _create_type_verifier(BigIntType())(0)
+ >>> _create_type_verifier(ArrayType(SmallIntType()))(list(range(3)))
+    >>> _create_type_verifier(ArrayType(VarCharType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ TypeError:...
+ >>> _create_type_verifier(MapType(VarCharType(), IntType()))({})
+ >>> _create_type_verifier(RowType([]))(())
+ >>> _create_type_verifier(RowType([]))([])
+    >>> _create_type_verifier(RowType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> # Check if numeric values are within the allowed range.
+ >>> _create_type_verifier(TinyIntType())(12)
+    >>> _create_type_verifier(TinyIntType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+    >>> _create_type_verifier(TinyIntType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _create_type_verifier(
+    ...     ArrayType(SmallIntType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _create_type_verifier(MapType(VarCharType(), IntType()))({None: 1})
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> schema = RowType().add("a", IntType()).add("b", VarCharType(), False)
+    >>> _create_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ """
+
+ if name is None:
+ new_msg = lambda msg: msg
+ new_name = lambda n: "field %s" % n
+ else:
+ new_msg = lambda msg: "%s: %s" % (name, msg)
+ new_name = lambda n: "field %s in %s" % (n, name)
+
+ def verify_nullability(obj):
+ if obj is None:
+ if data_type.nullable:
+ return True
+ else:
+ raise ValueError(new_msg("This field is not nullable, but got
None"))
+ else:
+ return False
+
+ _type = type(data_type)
+
+    assert _type in _acceptable_types or isinstance(data_type, UserDefinedType),\
+        new_msg("unknown datatype: %s" % data_type)
+
+ def verify_acceptable_types(obj):
+        # subclasses of these types can not be converted via from_sql_type in the JVM
+ if type(obj) not in _acceptable_types[_type]:
+ raise TypeError(new_msg("%s can not accept object %r in type %s"
+ % (data_type, obj, type(obj))))
+
+ if isinstance(data_type, CharType):
+ def verify_char(obj):
+ verify_acceptable_types(obj)
+ if len(obj) != data_type.length:
+ raise ValueError(new_msg(
+ "length of object (%s) of CharType is not: %d" % (obj,
data_type.length)))
+
+ verify_value = verify_char
+
+ elif isinstance(data_type, VarCharType):
+ def verify_varchar(obj):
+ verify_acceptable_types(obj)
+ if len(obj) > data_type.length:
+ raise ValueError(new_msg(
+ "length of object (%s) of VarCharType exceeds: %d" % (obj,
data_type.length)))
+
+ verify_value = verify_varchar
+
+ elif isinstance(data_type, BinaryType):
+ def verify_binary(obj):
+ verify_acceptable_types(obj)
+ if len(obj) != data_type.length:
+ raise ValueError(new_msg(
+ "length of object (%s) of BinaryType is not: %d" % (obj,
data_type.length)))
+
+ verify_value = verify_binary
+
+ elif isinstance(data_type, VarBinaryType):
+ def verify_varbinary(obj):
+ verify_acceptable_types(obj)
+ if len(obj) > data_type.length:
+ raise ValueError(new_msg(
+ "length of object (%s) of VarBinaryType exceeds: %d"
+ % (obj, data_type.length)))
+
+ verify_value = verify_varbinary
+
+ elif isinstance(data_type, UserDefinedType):
+ verifier = _create_type_verifier(data_type.sql_type(), name=name)
+
+ def verify_udf(obj):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == data_type):
+ raise ValueError(new_msg("%r is not an instance of type %r" %
(obj, data_type)))
+ verifier(data_type.to_sql_type(obj))
+
+ verify_value = verify_udf
+
+ elif isinstance(data_type, TinyIntType):
+ def verify_tiny_int(obj):
+ verify_acceptable_types(obj)
+ if obj < -128 or obj > 127:
+ raise ValueError(new_msg("object of TinyIntType out of range,
got: %s" % obj))
+
+ verify_value = verify_tiny_int
+
+ elif isinstance(data_type, SmallIntType):
+ def verify_small_int(obj):
+ verify_acceptable_types(obj)
+ if obj < -32768 or obj > 32767:
+ raise ValueError(new_msg("object of SmallIntType out of range,
got: %s" % obj))
+
+ verify_value = verify_small_int
+
+ elif isinstance(data_type, IntType):
+ def verify_integer(obj):
+ verify_acceptable_types(obj)
+ if obj < -2147483648 or obj > 2147483647:
+ raise ValueError(
+ new_msg("object of IntType out of range, got: %s" % obj))
+
+ verify_value = verify_integer
+
+ elif isinstance(data_type, ArrayType):
+ element_verifier = _create_type_verifier(
+ data_type.element_type, name="element in array %s" % name)
+
+ def verify_array(obj):
+ verify_acceptable_types(obj)
+ for i in obj:
+ element_verifier(i)
+
+ verify_value = verify_array
+
+ elif isinstance(data_type, MapType):
+        key_verifier = _create_type_verifier(data_type.key_type, name="key of map %s" % name)
+        value_verifier = _create_type_verifier(data_type.value_type, name="value of map %s" % name)
+
+ def verify_map(obj):
+ verify_acceptable_types(obj)
+ for k, v in obj.items():
+ key_verifier(k)
+ value_verifier(v)
+
+ verify_value = verify_map
+
+ elif isinstance(data_type, RowType):
+ verifiers = []
+ for f in data_type.fields:
+            verifier = _create_type_verifier(f.data_type, name=new_name(f.name))
+ verifiers.append((f.name, verifier))
+
+ def verify_row_field(obj):
+ if isinstance(obj, dict):
+ for f, verifier in verifiers:
+ verifier(obj.get(f))
+ elif isinstance(obj, Row) and getattr(obj, "_from_dict", False):
+ # the order in obj could be different than dataType.fields
+ for f, verifier in verifiers:
+ verifier(obj[f])
+ elif isinstance(obj, (tuple, list)):
+ if len(obj) != len(verifiers):
+ raise ValueError(
+ new_msg("Length of object (%d) does not match with "
+                                "length of fields (%d)" % (len(obj), len(verifiers))))
+ for v, (_, verifier) in zip(obj, verifiers):
+ verifier(v)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ for f, verifier in verifiers:
+ verifier(d.get(f))
+ else:
+ raise TypeError(new_msg("RowType can not accept object %r in
type %s"
+ % (obj, type(obj))))
+
+ verify_value = verify_row_field
+
+ else:
+ def verify_default(obj):
+ verify_acceptable_types(obj)
+
+ verify_value = verify_default
+
+ def verify(obj):
+ if not verify_nullability(obj):
+ verify_value(obj)
+
+ return verify
+
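A short sketch of how the verifier above might be used directly, assuming the type classes defined earlier in this file; the schema, field names and values are illustrative only:

    schema = RowType([RowField("id", IntType()),
                      RowField("name", VarCharType(5, False))])   # illustrative schema
    verify = _create_type_verifier(schema)
    verify((1, "Alice"))                 # tuples/lists are verified field by field
    verify({"id": 2, "name": "Bob"})     # dicts are accepted for RowType as well
    verify((3, None))                    # raises ValueError: "name" is not nullable
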
+
+class DataTypes(object):
+
+ @classmethod
+ def NULL(cls):
+ return NullType()
+
+ @classmethod
+ def CHAR(cls, length=1, nullable=True):
+ return CharType(length, nullable)
+
+ @classmethod
+ def VARCHAR(cls, length=1, nullable=True):
+ return VarCharType(length, nullable)
+
+ @classmethod
+ def STRING(cls, nullable=True):
+ return DataTypes.VARCHAR(0x7fffffff, nullable)
+
+ @classmethod
+ def BOOLEAN(cls, nullable=True):
+ return BooleanType(nullable)
+
+ @classmethod
+ def BINARY(cls, length=1, nullable=True):
+ return BinaryType(length, nullable)
+
+ @classmethod
+ def VARBINARY(cls, length=1, nullable=True):
+ return VarBinaryType(length, nullable)
+
+ @classmethod
+ def BYTES(cls, nullable=True):
+ return DataTypes.VARBINARY(0x7fffffff, nullable)
+
+ @classmethod
+ def DECIMAL(cls, precision=10, scale=0, nullable=True):
+ return DecimalType(precision, scale, nullable)
+
+ @classmethod
+ def TINYINT(cls, nullable=True):
+ return TinyIntType(nullable)
+
+ @classmethod
+ def SMALLINT(cls, nullable=True):
+ return SmallIntType(nullable)
+
+ @classmethod
+ def INT(cls, nullable=True):
+ return IntType(nullable)
+
+ @classmethod
+ def BIGINT(cls, nullable=True):
+ return BigIntType(nullable)
+
+ @classmethod
+ def FLOAT(cls, nullable=True):
+ return FloatType(nullable)
+
+ @classmethod
+ def DOUBLE(cls, nullable=True):
+ return DoubleType(nullable)
+
+ @classmethod
+ def DATE(cls, nullable=True):
+ return DateType(nullable)
+
+ @classmethod
+ def TIME(cls, precision=0, nullable=True):
+ return TimeType(precision, nullable)
+
+ @classmethod
+ def TIMESTAMP(cls, kind=TimestampKind.REGULAR, precision=6, nullable=True):
+ return TimestampType(kind, precision, nullable)
+
+ @classmethod
+ def ARRAY(cls, element_type, nullable=True):
+ return ArrayType(element_type, nullable)
+
+ @classmethod
+ def MAP(cls, key_type, value_type, nullable=True):
+ return MapType(key_type, value_type, nullable)
+
+ @classmethod
+ def MULTISET(cls, element_type, nullable=True):
+ return MultisetType(element_type, nullable)
+
+ @classmethod
+ def ROW(cls, row_fields=[], nullable=True):
+ return RowType(row_fields, nullable)
+
+ @classmethod
+ def FIELD(cls, name, data_type, description=None):
+ return RowField(name, data_type, description)
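
A similar sketch of declaring a schema through the DataTypes factory methods added above (field names and values are illustrative only):

    schema = DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
                            DataTypes.FIELD("name", DataTypes.STRING()),
                            DataTypes.FIELD("score", DataTypes.DOUBLE(nullable=False))])
    _create_type_verifier(schema)((1, "Flink", 95.5))   # passes; (1, "Flink", None) would raise
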
diff --git a/flink-python/pyflink/util/type_utils.py b/flink-python/pyflink/util/type_utils.py
deleted file mode 100644
index 093ab00..0000000
--- a/flink-python/pyflink/util/type_utils.py
+++ /dev/null
@@ -1,55 +0,0 @@
-################################################################################
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-################################################################################
-
-import sys
-from threading import RLock
-
-from pyflink.java_gateway import get_gateway
-from pyflink.table.types import DataTypes
-
-if sys.version > '3':
- xrange = range
-
-_data_types_mapping = None
-_lock = RLock()
-
-
-def to_java_type(py_type):
- global _data_types_mapping
- global _lock
-
- if _data_types_mapping is None:
- with _lock:
- gateway = get_gateway()
- TYPES = gateway.jvm.org.apache.flink.api.common.typeinfo.Types
- _data_types_mapping = {
- DataTypes.STRING: TYPES.STRING,
- DataTypes.BOOLEAN: TYPES.BOOLEAN,
- DataTypes.BYTE: TYPES.BYTE,
- DataTypes.CHAR: TYPES.CHAR,
- DataTypes.SHORT: TYPES.SHORT,
- DataTypes.INT: TYPES.INT,
- DataTypes.LONG: TYPES.LONG,
- DataTypes.FLOAT: TYPES.FLOAT,
- DataTypes.DOUBLE: TYPES.DOUBLE,
- DataTypes.DATE: TYPES.SQL_DATE,
- DataTypes.TIME: TYPES.SQL_TIME,
- DataTypes.TIMESTAMP: TYPES.SQL_TIMESTAMP
- }
-
- return _data_types_mapping[py_type]