This is an automated email from the ASF dual-hosted git repository.

rymurr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 11562e3  [Python] add to_byte_buffer to literal classes (#2655)
11562e3 is described below

commit 11562e38255e968079a6f5c8f6a0ce14b06825af
Author: jun-he <[email protected]>
AuthorDate: Thu Jun 3 00:29:05 2021 -0700

    [Python] add to_byte_buffer to literal classes (#2655)
---
 python/iceberg/api/expressions/literals.py | 50 ++++++++++++++++++------------
 python/iceberg/api/types/conversions.py    | 29 ++++++++---------
 python/iceberg/core/partition_summary.py   | 10 +++---
 python/tests/api/expressions/conftest.py   | 24 +++++++-------
 python/tests/api/test_conversions.py       | 15 +++++++--
 5 files changed, 76 insertions(+), 52 deletions(-)

diff --git a/python/iceberg/api/expressions/literals.py b/python/iceberg/api/expressions/literals.py
index 5090123..41387ce 100644
--- a/python/iceberg/api/expressions/literals.py
+++ b/python/iceberg/api/expressions/literals.py
@@ -27,6 +27,7 @@ from .expression import (FALSE,
                          TRUE)
 from .java_variables import (JAVA_MAX_FLOAT,
                              JAVA_MIN_FLOAT)
+from ..types.conversions import Conversions
 from ..types.type import TypeID
 
 
@@ -60,7 +61,7 @@ class Literals(object):
         elif isinstance(value, Decimal):
             return DecimalLiteral(value)
         else:
-            raise RuntimeError("Unimplemented Type Literal")
+            raise NotImplementedError("Unimplemented Type Literal for value: 
%s" % value)
 
     @staticmethod
     def above_max():
@@ -101,15 +102,20 @@ class Literal(object):
         elif isinstance(value, Decimal):
             return DecimalLiteral(value)
 
-    def to(self, type):
+    def to(self, type_var):
+        raise NotImplementedError()
+
+    def to_byte_buffer(self):
         raise NotImplementedError()
 
 
 class BaseLiteral(Literal):
-    def __init__(self, value):
+    def __init__(self, value, type_id):
         self.value = value
+        self.byte_buffer = None
+        self.type_id = type_id
 
-    def to(self, type):
+    def to(self, type_var):
         raise NotImplementedError()
 
     def __eq__(self, other):
@@ -129,11 +135,17 @@ class BaseLiteral(Literal):
     def __str__(self):
         return str(self.value)
 
+    def to_byte_buffer(self):
+        if self.byte_buffer is None:
+            self.byte_buffer = Conversions.to_byte_buffer(self.type_id, 
self.value)
+
+        return self.byte_buffer
+
 
 class ComparableLiteral(BaseLiteral):
 
-    def __init__(self, value):
-        super(ComparableLiteral, self).__init__(value)
+    def __init__(self, value, type_id):
+        super(ComparableLiteral, self).__init__(value, type_id)
 
     def to(self, type):
         raise NotImplementedError()
@@ -212,7 +224,7 @@ class BelowMin(Literal):
 class BooleanLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(BooleanLiteral, self).__init__(value)
+        super(BooleanLiteral, self).__init__(value, TypeID.BOOLEAN)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.BOOLEAN:
@@ -222,7 +234,7 @@ class BooleanLiteral(ComparableLiteral):
 class IntegerLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(IntegerLiteral, self).__init__(value)
+        super(IntegerLiteral, self).__init__(value, TypeID.INTEGER)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.INTEGER:
@@ -247,7 +259,7 @@ class IntegerLiteral(ComparableLiteral):
 class LongLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(LongLiteral, self).__init__(value)
+        super(LongLiteral, self).__init__(value, TypeID.LONG)
 
     def to(self, type_var):  # noqa: C901
         if type_var.type_id == TypeID.INTEGER:
@@ -279,7 +291,7 @@ class LongLiteral(ComparableLiteral):
 class FloatLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(FloatLiteral, self).__init__(value)
+        super(FloatLiteral, self).__init__(value, TypeID.FLOAT)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.FLOAT:
@@ -300,7 +312,7 @@ class FloatLiteral(ComparableLiteral):
 class DoubleLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(DoubleLiteral, self).__init__(value)
+        super(DoubleLiteral, self).__init__(value, TypeID.DOUBLE)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.FLOAT:
@@ -326,7 +338,7 @@ class DoubleLiteral(ComparableLiteral):
 class DateLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(DateLiteral, self).__init__(value)
+        super(DateLiteral, self).__init__(value, TypeID.DATE)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.DATE:
@@ -336,7 +348,7 @@ class DateLiteral(ComparableLiteral):
 class TimeLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(TimeLiteral, self).__init__(value)
+        super(TimeLiteral, self).__init__(value, TypeID.TIME)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.TIME:
@@ -346,7 +358,7 @@ class TimeLiteral(ComparableLiteral):
 class TimestampLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(TimestampLiteral, self).__init__(value)
+        super(TimestampLiteral, self).__init__(value, TypeID.TIMESTAMP)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.TIMESTAMP:
@@ -358,7 +370,7 @@ class TimestampLiteral(ComparableLiteral):
 class DecimalLiteral(ComparableLiteral):
 
     def __init__(self, value):
-        super(DecimalLiteral, self).__init__(value)
+        super(DecimalLiteral, self).__init__(value, TypeID.DECIMAL)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.DECIMAL and type_var.scale == 
abs(self.value.as_tuple().exponent):
@@ -367,7 +379,7 @@ class DecimalLiteral(ComparableLiteral):
 
 class StringLiteral(BaseLiteral):
     def __init__(self, value):
-        super(StringLiteral, self).__init__(value)
+        super(StringLiteral, self).__init__(value, TypeID.STRING)
 
     def to(self, type_var):  # noqa: C901
         import dateutil.parser
@@ -445,7 +457,7 @@ class StringLiteral(BaseLiteral):
 
 class UUIDLiteral(ComparableLiteral):
     def __init__(self, value):
-        super(UUIDLiteral, self).__init__(value)
+        super(UUIDLiteral, self).__init__(value, TypeID.UUID)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.UUID:
@@ -454,7 +466,7 @@ class UUIDLiteral(ComparableLiteral):
 
 class FixedLiteral(BaseLiteral):
     def __init__(self, value):
-        super(FixedLiteral, self).__init__(value)
+        super(FixedLiteral, self).__init__(value, TypeID.FIXED)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.FIXED:
@@ -499,7 +511,7 @@ class FixedLiteral(BaseLiteral):
 
 class BinaryLiteral(BaseLiteral):
     def __init__(self, value):
-        super(BinaryLiteral, self).__init__(value)
+        super(BinaryLiteral, self).__init__(value, TypeID.BINARY)
 
     def to(self, type_var):
         if type_var.type_id == TypeID.FIXED:
diff --git a/python/iceberg/api/types/conversions.py b/python/iceberg/api/types/conversions.py
index 6768b61..ae92a63 100644
--- a/python/iceberg/api/types/conversions.py
+++ b/python/iceberg/api/types/conversions.py
@@ -39,17 +39,18 @@ class Conversions(object):
                      TypeID.DECIMAL: lambda as_str: Decimal(as_str),
                      }
 
-    to_byte_buff_mapping = {TypeID.BOOLEAN: lambda type_var, value: 
struct.pack("<h", 1 if value else 0),
-                            TypeID.INTEGER: lambda type_var, value: 
struct.pack("<i", value),
-                            TypeID.DATE: lambda type_var, value: 
struct.pack("<i", value),
-                            TypeID.LONG: lambda type_var, value: 
struct.pack("<l", value),
-                            TypeID.TIME: lambda type_var, value: 
struct.pack("<l", value),
-                            TypeID.TIMESTAMP: lambda type_var, value: 
struct.pack("<l", value),
-                            TypeID.FLOAT: lambda type_var, value: 
struct.pack("<f", value),
-                            TypeID.DOUBLE: lambda type_var, value: 
struct.pack("<d", value),
-                            TypeID.STRING: lambda type_var, value: 
value.encode('UTF-8'),
-                            TypeID.UUID: lambda type_var, value: 
struct.pack('>QQ', (value.int >> 64) & 0xFFFFFFFFFFFFFFFF,
-                                                                             
value.int & 0xFFFFFFFFFFFFFFFF),
+    to_byte_buff_mapping = {TypeID.BOOLEAN: lambda type_id, value: 
struct.pack("<h", 1 if value else 0),
+                            TypeID.INTEGER: lambda type_id, value: 
struct.pack("<i", value),
+                            TypeID.DATE: lambda type_id, value: 
struct.pack("<i", value),
+                            TypeID.LONG: lambda type_id, value: 
struct.pack("<l", value),
+                            TypeID.TIME: lambda type_id, value: 
struct.pack("<l", value),
+                            TypeID.TIMESTAMP: lambda type_id, value: 
struct.pack("<l", value),
+                            TypeID.FLOAT: lambda type_id, value: 
struct.pack("<f", value),
+                            TypeID.DOUBLE: lambda type_id, value: 
struct.pack("<d", value),
+                            TypeID.STRING: lambda type_id, value: 
value.encode('UTF-8'),
+                            TypeID.UUID: lambda type_id, value: 
struct.pack('>QQ', (value.int >> 64)
+                                                                            & 
0xFFFFFFFFFFFFFFFF, value.int
+                                                                            & 
0xFFFFFFFFFFFFFFFF),
                             # TypeId.FIXED: lambda as_str: None,
                             # TypeId.BINARY: lambda as_str: None,
                             # TypeId.DECIMAL: lambda type_var, value: 
struct.pack(value.quantize(
@@ -81,11 +82,11 @@ class Conversions(object):
         return part_func(as_string)
 
     @staticmethod
-    def to_byte_buffer(type_var, value):
+    def to_byte_buffer(type_id, value):
         try:
-            return 
Conversions.to_byte_buff_mapping.get(type_var.type_id)(type_var, value)
+            return Conversions.to_byte_buff_mapping.get(type_id)(type_id, 
value)
         except KeyError:
-            raise RuntimeError("Cannot Serialize Type: %s" % type_var)
+            raise TypeError("Cannot Serialize Type: %s" % type_id)
 
     @staticmethod
     def from_byte_buffer(type_var, buffer_var):
diff --git a/python/iceberg/core/partition_summary.py b/python/iceberg/core/partition_summary.py
index 8b7ec3b..45239ae 100644
--- a/python/iceberg/core/partition_summary.py
+++ b/python/iceberg/core/partition_summary.py
@@ -42,18 +42,18 @@ class PartitionSummary(object):
 
 class PartitionFieldStats(object):
 
-    def __init__(self, type):
+    def __init__(self, type_var):
         self.contains_null = False
-        self.type = type
+        self.type = type_var
         self.min = None
         self.max = None
 
     def to_summary(self):
-        lower_bound = None if self.min is None else 
Conversions.to_byte_buffer(self.type, self.min)
-        upper_bound = None if self.max is None else 
Conversions.to_byte_buffer(self.type, self.max)
+        lower_bound = None if self.min is None else 
Conversions.to_byte_buffer(self.type.type_id, self.min)
+        upper_bound = None if self.max is None else 
Conversions.to_byte_buffer(self.type.type_id, self.max)
         return GenericPartitionFieldSummary(contains_null=self.contains_null,
                                             lower_bound=lower_bound,
-                                            uppwer_bound=upper_bound)
+                                            upper_bound=upper_bound)
 
     def update(self, value):
         if value is None:
diff --git a/python/tests/api/expressions/conftest.py b/python/tests/api/expressions/conftest.py
index 28cc43d..a0e748f 100644
--- a/python/tests/api/expressions/conftest.py
+++ b/python/tests/api/expressions/conftest.py
@@ -208,9 +208,9 @@ def file():
                         # null value counts
                         {4: 50, 5: 10, 6: 0},
                         # lower bounds
-                        {1: Conversions.to_byte_buffer(IntegerType.get(), 30)},
+                        {1: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 30)},
                         # upper bounds
-                        {1: Conversions.to_byte_buffer(IntegerType.get(), 79)})
+                        {1: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 79)})
 
 
 @pytest.fixture(scope="session")
@@ -220,10 +220,10 @@ def strict_file():
                         50,
                         {4: 50, 5: 50, 6: 50},
                         {4: 50, 5: 10, 6: 0},
-                        {1: Conversions.to_byte_buffer(IntegerType.get(), 30),
-                         7: Conversions.to_byte_buffer(IntegerType.get(), 5)},
-                        {1: Conversions.to_byte_buffer(IntegerType.get(), 79),
-                         7: Conversions.to_byte_buffer(IntegerType.get(), 5)}
+                        {1: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 30),
+                         7: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 5)},
+                        {1: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 79),
+                         7: 
Conversions.to_byte_buffer(IntegerType.get().type_id, 5)}
                         )
 
 
@@ -406,17 +406,17 @@ def inc_man_spec():
 def inc_man_file():
     return TestManifestFile("manifest-list.avro", 1024, 0, int(time.time() * 
1000), 5, 10, 0,
                             (TestFieldSummary(False,
-                                              
Conversions.to_byte_buffer(IntegerType.get(), 30),
-                                              
Conversions.to_byte_buffer(IntegerType.get(), 79)),
+                                              
Conversions.to_byte_buffer(IntegerType.get().type_id, 30),
+                                              
Conversions.to_byte_buffer(IntegerType.get().type_id, 79)),
                              TestFieldSummary(True,
                                               None,
                                               None),
                              TestFieldSummary(True,
-                                              
Conversions.to_byte_buffer(StringType.get(), 'a'),
-                                              
Conversions.to_byte_buffer(StringType.get(), 'z')),
+                                              
Conversions.to_byte_buffer(StringType.get().type_id, 'a'),
+                                              
Conversions.to_byte_buffer(StringType.get().type_id, 'z')),
                              TestFieldSummary(False,
-                                              
Conversions.to_byte_buffer(StringType.get(), 'a'),
-                                              
Conversions.to_byte_buffer(StringType.get(), 'z'))
+                                              
Conversions.to_byte_buffer(StringType.get().type_id, 'a'),
+                                              
Conversions.to_byte_buffer(StringType.get().type_id, 'z'))
                              ))
 
 
diff --git a/python/tests/api/test_conversions.py b/python/tests/api/test_conversions.py
index be8587a..7f6778c 100644
--- a/python/tests/api/test_conversions.py
+++ b/python/tests/api/test_conversions.py
@@ -17,9 +17,12 @@
 
 import unittest
 
-from iceberg.api.types import (DoubleType,
+from iceberg.api.expressions import Literal
+from iceberg.api.types import (DateType,
+                               DoubleType,
                                IntegerType,
-                               LongType)
+                               LongType,
+                               StringType)
 from iceberg.api.types.conversions import Conversions
 
 
@@ -32,3 +35,11 @@ class TestConversions(unittest.TestCase):
                                                             
b'\xd2\x04\x00\x00\x00\x00\x00\x00'))
         self.assertAlmostEqual(1.2345, 
Conversions.from_byte_buffer(DoubleType.get(),
                                                                     
b'\x8d\x97\x6e\x12\x83\xc0\xf3\x3f'))
+        self.assertEqual("foo", Conversions.from_byte_buffer(StringType.get(), 
b'foo'))
+
+    def test_to_bytes(self):
+        self.assertEqual(b'\x00\x00', Literal.of(False).to_byte_buffer())
+        self.assertEqual(b'\x01\x00', Literal.of(True).to_byte_buffer())
+        self.assertEqual(b'foo', Literal.of("foo").to_byte_buffer())
+        self.assertEqual(b'\xd2\x04\x00\x00', 
Literal.of(1234).to_byte_buffer())
+        self.assertEqual(b'\xe8\x03\x00\x00', 
Literal.of(1000).to(DateType.get()).to_byte_buffer())

Reply via email to