This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new eb5dd50 ARROW-840: [Python] Expose extension types
eb5dd50 is described below
commit eb5dd508ee3f592bf1c2a04cce09ee95e137e89b
Author: Antoine Pitrou <[email protected]>
AuthorDate: Fri Jun 14 07:53:40 2019 -0500
ARROW-840: [Python] Expose extension types
Add infrastructure to consume C++ extension types and extension arrays from
Python.
Also allow creating Python-specific extension types by subclassing
`ExtensionType`, and creating extension arrays by passing the type and storage
array to `ExtensionArray.from_storage`.
Author: Antoine Pitrou <[email protected]>
Closes #4532 from pitrou/ARROW-840-py-ext-types and squashes the following
commits:
95ca6148e <Antoine Pitrou> Add IPC tests
44ac0a156 <Antoine Pitrou> ARROW-840: Expose extension types
---
cpp/src/arrow/array.cc | 11 +-
cpp/src/arrow/extension_type.cc | 18 +++
cpp/src/arrow/extension_type.h | 9 +-
cpp/src/arrow/python/CMakeLists.txt | 1 +
cpp/src/arrow/python/extension_type.cc | 196 +++++++++++++++++++++++++
cpp/src/arrow/python/extension_type.h | 77 ++++++++++
cpp/src/arrow/python/pyarrow.h | 1 +
python/pyarrow/__init__.py | 4 +-
python/pyarrow/array.pxi | 42 +++++-
python/pyarrow/includes/libarrow.pxd | 32 ++++
python/pyarrow/lib.pxd | 20 ++-
python/pyarrow/public-api.pxi | 12 +-
python/pyarrow/tests/test_extension_type.py | 219 ++++++++++++++++++++++++++++
python/pyarrow/types.pxi | 150 +++++++++++++++++--
14 files changed, 775 insertions(+), 17 deletions(-)
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 7a3d36e..9d37b45 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -1259,7 +1259,16 @@ struct ValidateVisitor {
return Status::OK();
}
-  Status Visit(const ExtensionArray& array) { return ValidateArray(*array.storage()); }
+  // An extension array is valid when its storage array has exactly the storage
+  // type declared by its ExtensionType and that storage array is itself valid.
+  Status Visit(const ExtensionArray& array) {
+    const auto& ext_type = checked_cast<const ExtensionType&>(*array.type());
+
+    if (!array.storage()->type()->Equals(*ext_type.storage_type())) {
+      return Status::Invalid("Extension array of type '", array.type()->ToString(),
+                             "' has storage array of incompatible type '",
+                             array.storage()->type()->ToString(), "'");
+    }
+    return ValidateArray(*array.storage());
+  }
protected:
template <typename ArrayType>
diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc
index e104c03..25945f3 100644
--- a/cpp/src/arrow/extension_type.cc
+++ b/cpp/src/arrow/extension_type.cc
@@ -27,10 +27,14 @@
#include "arrow/array.h"
#include "arrow/status.h"
#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
#include "arrow/util/visibility.h"
namespace arrow {
+using internal::checked_cast;
+
DataTypeLayout ExtensionType::layout() const { return storage_type_->layout();
}
std::string ExtensionType::ToString() const {
@@ -41,7 +45,21 @@ std::string ExtensionType::ToString() const {
std::string ExtensionType::name() const { return "extension"; }
+ExtensionArray::ExtensionArray(const std::shared_ptr<ArrayData>& data) {
SetData(data); }
+
+ExtensionArray::ExtensionArray(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Array>& storage) {
+ DCHECK_EQ(type->id(), Type::EXTENSION);
+ DCHECK(
+ storage->type()->Equals(*checked_cast<const
ExtensionType&>(*type).storage_type()));
+ auto data = storage->data()->Copy();
+ // XXX This pointer is reverted below in SetData()...
+ data->type = type;
+ SetData(data);
+}
+
void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
this->Array::SetData(data);
auto storage_data = data->Copy();
diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h
index b3df2b3..6a1ca0b 100644
--- a/cpp/src/arrow/extension_type.h
+++ b/cpp/src/arrow/extension_type.h
@@ -84,7 +84,14 @@ class ARROW_EXPORT ExtensionType : public DataType {
/// \brief Base array class for user-defined extension types
class ARROW_EXPORT ExtensionArray : public Array {
public:
-  explicit ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+  /// \brief Construct an ExtensionArray from an ArrayData.
+  ///
+  /// The ArrayData must have the right ExtensionType.
+  explicit ExtensionArray(const std::shared_ptr<ArrayData>& data);
+
+  /// \brief Construct an ExtensionArray from a type and the underlying storage.
+  ///
+  /// \param[in] type the ExtensionType of the resulting array
+  /// \param[in] storage the physical storage; its type must equal the
+  ///            storage type of `type`
+  ExtensionArray(const std::shared_ptr<DataType>& type,
+                 const std::shared_ptr<Array>& storage);
/// \brief The physical storage for the extension array
std::shared_ptr<Array> storage() const { return storage_; }
diff --git a/cpp/src/arrow/python/CMakeLists.txt
b/cpp/src/arrow/python/CMakeLists.txt
index d6376f5..0d17a9f 100644
--- a/cpp/src/arrow/python/CMakeLists.txt
+++ b/cpp/src/arrow/python/CMakeLists.txt
@@ -34,6 +34,7 @@ set(ARROW_PYTHON_SRCS
config.cc
decimal.cc
deserialize.cc
+ extension_type.cc
helpers.cc
inference.cc
init.cc
diff --git a/cpp/src/arrow/python/extension_type.cc
b/cpp/src/arrow/python/extension_type.cc
new file mode 100644
index 0000000..b130030
--- /dev/null
+++ b/cpp/src/arrow/python/extension_type.cc
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/python/extension_type.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+namespace {
+
+// Serialize a Python ExtensionType instance by calling its
+// __arrow_ext_serialize__() method, which must return a bytes object.
+// On success the bytes are copied into *out.
+Status SerializeExtInstance(PyObject* type_instance, std::string* out) {
+  OwnedRef res(PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr));
+  if (!res) {
+    return ConvertPyError();
+  }
+  if (!PyBytes_Check(res.obj())) {
+    return Status::TypeError(
+        "__arrow_ext_serialize__ should return bytes object, "
+        "got ",
+        internal::PyObject_StdStringRepr(res.obj()));
+  }
+  *out = internal::PyBytes_AsStdString(res.obj());
+  return Status::OK();
+}
+
+// Deserialize a Python ExtensionType instance by calling
+// type_class.__arrow_ext_deserialize__(storage_type, serialized_data).
+// Returns a new reference, or nullptr on failure (a Python error is
+// expected to be set).  The storage type is only read, hence the const ref.
+PyObject* DeserializeExtInstance(PyObject* type_class,
+                                 const std::shared_ptr<DataType>& storage_type,
+                                 const std::string& serialized_data) {
+  OwnedRef storage_ref(wrap_data_type(storage_type));
+  if (!storage_ref) {
+    return nullptr;
+  }
+  OwnedRef data_ref(PyBytes_FromStringAndSize(
+      serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size())));
+  if (!data_ref) {
+    return nullptr;
+  }
+
+  return PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO",
+                             storage_ref.obj(), data_ref.obj());
+}
+
+} // namespace
+
+// Extension name shared by every Python-defined extension type.
+constexpr char kExtensionName[] = "arrow.py_extension_type";
+
+// `typ` is the Python class and `inst` a weakref to an instance (or null).
+// FromClass() increments the class refcount before handing it over here.
+PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                                 PyObject* inst)
+    : ExtensionType(storage_type), type_class_(typ), type_instance_(inst) {}
+
+std::string PyExtensionType::extension_name() const { return kExtensionName; }
+
+// Two Python extension types are equal when both are class-only and their
+// classes compare equal, or both are instance-backed and their instances
+// compare equal (Python ==).  Python errors cannot be propagated through this
+// bool-returning API, so they are reported as unraisable and yield false.
+bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const {
+  if (other.extension_name() != extension_name()) {
+    return false;
+  }
+  const auto& other_ext = checked_cast<const PyExtensionType&>(other);
+  // A class-only type (no Python instance) never equals an instance-backed one.
+  if (!type_instance_ != !other_ext.type_instance_) {
+    return false;
+  }
+
+  // Only acquire the GIL once Python objects actually need to be touched.
+  PyAcquireGIL lock;
+  int res = -1;
+  if (!type_instance_) {
+    // Compare Python types
+    res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(),
+                                   Py_EQ);
+  } else {
+    // Compare Python instances.  Fetch the right-hand instance only if the
+    // left-hand one was obtained: Python APIs must not be called while an
+    // exception is pending.
+    OwnedRef left(GetInstance());
+    if (left) {
+      OwnedRef right(other_ext.GetInstance());
+      if (right) {
+        res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ);
+      }
+    }
+  }
+  if (res == -1) {
+    // Cannot propagate error
+    PyErr_WriteUnraisable(nullptr);
+    return false;
+  }
+  return res == 1;
+}
+
+// Python extension types use the generic ExtensionArray class.
+std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
+  DCHECK_EQ(kExtensionName,
+            checked_cast<const ExtensionType&>(*data->type).extension_name());
+  return std::make_shared<ExtensionArray>(data);
+}
+
+// Return the serialized form cached by SetInstance().  Only types bound to a
+// Python instance have one.
+std::string PyExtensionType::Serialize() const {
+  DCHECK(type_instance_);
+  return serialized_;
+}
+
+// Rebuild a Python extension type from its serialized form by calling
+// type_class.__arrow_ext_deserialize__(), then unwrap the resulting Python
+// DataType into *out.  Acquires the GIL.
+Status PyExtensionType::Deserialize(std::shared_ptr<DataType> storage_type,
+                                    const std::string& serialized_data,
+                                    std::shared_ptr<DataType>* out) const {
+  PyAcquireGIL lock;
+
+  // Presumably needed for wrap_data_type()/unwrap_data_type() to be usable.
+  // import_pyarrow() returns 0 on success, -1 on error.
+  if (import_pyarrow()) {
+    return ConvertPyError();
+  }
+  OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data));
+  if (!res) {
+    return ConvertPyError();
+  }
+  return unwrap_data_type(res.obj(), out);
+}
+
+// Return a new reference to the Python instance bound to this type.  If the
+// weakly-referenced instance has been collected, reconstruct it from the
+// cached serialized form.  Returns nullptr with a TypeError set when no
+// instance was ever bound.  Does not acquire the GIL itself.
+PyObject* PyExtensionType::GetInstance() const {
+  if (!type_instance_) {
+    PyErr_SetString(PyExc_TypeError, "Not an instance");
+    return nullptr;
+  }
+  DCHECK(PyWeakref_CheckRef(type_instance_.obj()));
+  // Borrowed reference; Py_None once the referent is gone.
+  PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj());
+  if (inst == Py_None) {
+    // Must reconstruct from serialized form
+    // XXX cache again?
+    return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_);
+  }
+  // Cached instance still alive
+  Py_INCREF(inst);
+  return inst;
+}
+
+// Bind a Python instance to this type.  `inst` must be exactly of the class
+// given at construction.  Stores a weakref to it together with its serialized
+// form.  Serialization happens first, so on failure this type is left
+// untouched and type_instance_/serialized_ are always set together.
+Status PyExtensionType::SetInstance(PyObject* inst) const {
+  // Check we have the right type
+  PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst));
+  if (typ != type_class_.obj()) {
+    return Status::TypeError("Unexpected Python ExtensionType class ",
+                             internal::PyObject_StdStringRepr(typ), " expected ",
+                             internal::PyObject_StdStringRepr(type_class_.obj()));
+  }
+
+  std::string serialized;
+  RETURN_NOT_OK(SerializeExtInstance(inst, &serialized));
+
+  PyObject* wr = PyWeakref_NewRef(inst, nullptr);
+  if (wr == nullptr) {
+    return ConvertPyError();
+  }
+  type_instance_.reset(wr);
+  serialized_.swap(serialized);
+  return Status::OK();
+}
+
+// Create a class-only Python extension type (no instance bound yet; see
+// SetInstance()).  The new type holds its own reference to `typ`.
+Status PyExtensionType::FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                                  std::shared_ptr<ExtensionType>* out) {
+  // Take a new reference for type_class_ to hold.
+  Py_INCREF(typ);
+  out->reset(new PyExtensionType(storage_type, typ));
+  return Status::OK();
+}
+
+// Register `type` (expected to be a PyExtensionType) in the global
+// extension type registry under kExtensionName.
+Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) {
+  DCHECK_EQ(type->id(), Type::EXTENSION);
+  auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type);
+  // DCHECKs vanish in release builds: never hand a null type to the registry.
+  if (ext_type == nullptr) {
+    return Status::TypeError("Expected an extension type, got ", type->ToString());
+  }
+  DCHECK_EQ(ext_type->extension_name(), kExtensionName);
+  return RegisterExtensionType(ext_type);
+}
+
+// Remove the Python extension type from the global registry.
+Status UnregisterPyExtensionType() { return UnregisterExtensionType(kExtensionName); }
+
+// The extension name shared by all Python-defined extension types.
+std::string PyExtensionName() { return kExtensionName; }
+
+} // namespace py
+} // namespace arrow
diff --git a/cpp/src/arrow/python/extension_type.h
b/cpp/src/arrow/python/extension_type.h
new file mode 100644
index 0000000..12f9108
--- /dev/null
+++ b/cpp/src/arrow/python/extension_type.h
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/extension_type.h"
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace py {
+
+/// \brief An ExtensionType implemented by a Python class.
+///
+/// All Python-defined extension types share the extension name returned by
+/// PyExtensionName(); they are told apart by their Python class and instance.
+class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType {
+ public:
+  // Implement ExtensionType API
+  std::string extension_name() const override;
+
+  bool ExtensionEquals(const ExtensionType& other) const override;
+
+  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+  Status Deserialize(std::shared_ptr<DataType> storage_type,
+                     const std::string& serialized_data,
+                     std::shared_ptr<DataType>* out) const override;
+
+  std::string Serialize() const override;
+
+  // For use from Cython: create a type bound to the Python class `typ`, with
+  // no instance attached yet (see SetInstance()).
+  static Status FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                          std::shared_ptr<ExtensionType>* out);
+
+  // Return a new reference to the bound Python instance, reconstructing it from
+  // its serialized form if it was collected.  Returns nullptr with a Python
+  // TypeError set if no instance was bound.  Caller must hold the GIL (this
+  // method does not acquire it).
+  PyObject* GetInstance() const;
+  // Bind a Python instance (of exactly the class given at construction):
+  // keep a weakref to it and cache its serialized form.
+  Status SetInstance(PyObject* inst) const;
+
+ protected:
+  PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                  PyObject* inst = NULLPTR);
+
+  // These fields are mutable because of two-step initialization.
+  mutable OwnedRefNoGIL type_class_;
+  // A weakref or null.  Storing a strong reference to the Python extension type
+  // instance would create an unreclaimable reference cycle between Python and C++
+  // (the Python instance has to keep a strong reference to the C++ ExtensionType
+  // in the other direction).  Instead, we store a weakref to the instance.
+  // If the weakref is dead, we reconstruct the instance from its serialized form.
+  mutable OwnedRefNoGIL type_instance_;
+  // Empty if type_instance_ is null
+  mutable std::string serialized_;
+};
+
+/// \brief The extension name shared by all Python-defined extension types.
+ARROW_PYTHON_EXPORT std::string PyExtensionName();
+
+/// \brief Register a PyExtensionType in the global extension type registry.
+ARROW_PYTHON_EXPORT Status RegisterPyExtensionType(const std::shared_ptr<DataType>&);
+
+/// \brief Unregister the Python extension type from the global registry.
+ARROW_PYTHON_EXPORT Status UnregisterPyExtensionType();
+
+} // namespace py
+} // namespace arrow
diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h
index a5a3910..5e42333 100644
--- a/cpp/src/arrow/python/pyarrow.h
+++ b/cpp/src/arrow/python/pyarrow.h
@@ -39,6 +39,7 @@ class Tensor;
namespace py {
+// Returns 0 on success, -1 on error.
ARROW_PYTHON_EXPORT int import_pyarrow();
ARROW_PYTHON_EXPORT bool is_buffer(PyObject* buffer);
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index f9ba819..556b87d 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -58,6 +58,8 @@ from pyarrow.lib import (null, bool_,
DataType, DictionaryType, ListType, StructType,
UnionType, TimestampType, Time32Type, Time64Type,
FixedSizeBinaryType, Decimal128Type,
+ BaseExtensionType, ExtensionType,
+ UnknownExtensionType,
DictionaryMemo,
Field,
Schema,
@@ -78,7 +80,7 @@ from pyarrow.lib import (null, bool_,
DictionaryArray,
Date32Array, Date64Array,
TimestampArray, Time32Array, Time64Array,
- Decimal128Array, StructArray,
+ Decimal128Array, StructArray, ExtensionArray,
ArrayValue, Scalar, NA, _NULL as NULL,
BooleanValue,
Int8Value, Int16Value, Int32Value, Int64Value,
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index cce967e..607d7ae 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -415,7 +415,7 @@ cdef class Array(_PandasConvertible):
"the `pyarrow.Array.from_*` functions instead."
.format(self.__class__.__name__))
- cdef void init(self, const shared_ptr[CArray]& sp_array):
+ cdef void init(self, const shared_ptr[CArray]& sp_array) except *:
self.sp_array = sp_array
self.ap = sp_array.get()
self.type = pyarrow_wrap_data_type(self.sp_array.get().type())
@@ -1458,6 +1458,45 @@ cdef class StructArray(Array):
return pyarrow_wrap_array(c_result)
+cdef class ExtensionArray(Array):
+    """
+    Concrete class for Arrow extension arrays.
+    """
+
+    @property
+    def storage(self):
+        """
+        The underlying storage array.
+        """
+        cdef:
+            CExtensionArray* ext_array = <CExtensionArray*>(self.ap)
+
+        return pyarrow_wrap_array(ext_array.storage())
+
+    @staticmethod
+    def from_storage(BaseExtensionType typ not None, Array storage not None):
+        """
+        Construct ExtensionArray from type and storage array.
+
+        Parameters
+        ----------
+        typ : BaseExtensionType
+            The extension type for the result array.
+        storage : Array
+            The underlying storage for the result array.
+
+        Returns
+        -------
+        ext_array : ExtensionArray
+
+        Raises
+        ------
+        TypeError
+            If ``storage.type`` differs from ``typ.storage_type``.
+        """
+        # ``not None`` above: typed cdef fields (storage.type, typ.sp_type)
+        # are read directly, which would crash on None.
+        cdef:
+            shared_ptr[CExtensionArray] ext_array
+
+        if storage.type != typ.storage_type:
+            raise TypeError("Incompatible storage type {0} "
+                            "for extension type {1}".format(storage.type, typ))
+
+        ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array)
+        return pyarrow_wrap_array(<shared_ptr[CArray]> ext_array)
+
+
cdef dict _array_classes = {
_Type_NA: NullArray,
_Type_BOOL: BooleanArray,
@@ -1485,6 +1524,7 @@ cdef dict _array_classes = {
_Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray,
_Type_DECIMAL: Decimal128Array,
_Type_STRUCT: StructArray,
+ _Type_EXTENSION: ExtensionArray,
}
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index f979cd6..178a250 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -73,6 +73,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
_Type_DICTIONARY" arrow::Type::DICTIONARY"
_Type_MAP" arrow::Type::MAP"
+ _Type_EXTENSION" arrow::Type::EXTENSION"
+
enum UnionMode" arrow::UnionMode::type":
_UnionMode_SPARSE" arrow::UnionMode::SPARSE"
_UnionMode_DENSE" arrow::UnionMode::DENSE"
@@ -1272,6 +1274,36 @@ cdef extern from 'arrow/python/inference.h' namespace
'arrow::py':
c_bool IsPyFloat(object o)
+cdef extern from 'arrow/extension_type.h' namespace 'arrow':
+    cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType):
+        c_string extension_name()
+        shared_ptr[CDataType] storage_type()
+
+    cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray):
+        CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage)
+
+        shared_ptr[CArray] storage()
+
+
+# Only members that actually exist in arrow/python/extension_type.h are
+# declared here (there is no PyExtensionType::FromInstance).
+cdef extern from 'arrow/python/extension_type.h' namespace 'arrow::py':
+    cdef cppclass CPyExtensionType \
+            " arrow::py::PyExtensionType"(CExtensionType):
+        @staticmethod
+        CStatus FromClass(shared_ptr[CDataType] storage_type,
+                          object typ, shared_ptr[CExtensionType]* out)
+
+        object GetInstance()
+        CStatus SetInstance(object)
+
+    c_string PyExtensionName()
+    CStatus RegisterPyExtensionType(shared_ptr[CDataType])
+    CStatus UnregisterPyExtensionType()
+
+
cdef extern from 'arrow/python/benchmark.h' namespace 'arrow::py::benchmark':
void Benchmark_PandasObjectIsNull(object lst) except *
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 998848d..79ab947 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -53,8 +53,9 @@ cdef class DataType:
shared_ptr[CDataType] sp_type
CDataType* type
bytes pep3118_format
+ object __weakref__
- cdef void init(self, const shared_ptr[CDataType]& type)
+ cdef void init(self, const shared_ptr[CDataType]& type) except *
cdef Field child(self, int i)
@@ -106,6 +107,16 @@ cdef class Decimal128Type(FixedSizeBinaryType):
const CDecimal128Type* decimal128_type
+cdef class BaseExtensionType(DataType):
+ cdef:
+ # Non-owning pointer into sp_type, cast to arrow::ExtensionType (set in init())
+ const CExtensionType* ext_type
+
+
+# Base class for extension types implemented in Python
+cdef class ExtensionType(BaseExtensionType):
+ cdef:
+ # Non-owning pointer into sp_type, cast to arrow::py::PyExtensionType
+ const CPyExtensionType* cpy_ext_type
+
cdef class Field:
cdef:
shared_ptr[CField] sp_field
@@ -199,11 +210,12 @@ cdef class Array(_PandasConvertible):
cdef:
shared_ptr[CArray] sp_array
CArray* ap
+ object __weakref__
cdef readonly:
DataType type
- cdef void init(self, const shared_ptr[CArray]& sp_array)
+ cdef void init(self, const shared_ptr[CArray]& sp_array) except *
cdef getitem(self, int64_t i)
cdef int64_t length(self)
@@ -316,6 +328,10 @@ cdef class DictionaryArray(Array):
object _indices, _dictionary
+cdef class ExtensionArray(Array):
+ # No extra C-level fields: the storage array is reached through self.ap
+ pass
+
+
cdef wrap_array_output(PyObject* output)
cdef object box_scalar(DataType type,
const shared_ptr[CArray]& sp_array,
diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi
index 9392259..33bc803 100644
--- a/python/pyarrow/public-api.pxi
+++ b/python/pyarrow/public-api.pxi
@@ -66,7 +66,10 @@ cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type(
cdef api object pyarrow_wrap_data_type(
const shared_ptr[CDataType]& type):
- cdef DataType out
+ cdef:
+ const CExtensionType* ext_type
+ const CPyExtensionType* cpy_ext_type
+ DataType out
if type.get() == NULL:
return None
@@ -85,6 +88,13 @@ cdef api object pyarrow_wrap_data_type(
out = FixedSizeBinaryType.__new__(FixedSizeBinaryType)
elif type.get().id() == _Type_DECIMAL:
out = Decimal128Type.__new__(Decimal128Type)
+ elif type.get().id() == _Type_EXTENSION:
+ ext_type = <const CExtensionType*> type.get()
+ if ext_type.extension_name() == PyExtensionName():
+ cpy_ext_type = <const CPyExtensionType*> ext_type
+ return cpy_ext_type.GetInstance()
+ else:
+ out = BaseExtensionType.__new__(BaseExtensionType)
else:
out = DataType.__new__(DataType)
diff --git a/python/pyarrow/tests/test_extension_type.py
b/python/pyarrow/tests/test_extension_type.py
new file mode 100644
index 0000000..d688d3c
--- /dev/null
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pickle
+import weakref
+
+import pyarrow as pa
+
+import pytest
+
+
+class UuidType(pa.ExtensionType):
+    """Test extension type with 16-byte binary storage and no parameters."""
+
+    def __init__(self):
+        super(UuidType, self).__init__(pa.binary(16))
+
+    def __reduce__(self):
+        return (UuidType, ())
+
+
+class ParamExtType(pa.ExtensionType):
+    """Test extension type parametrized by its binary width."""
+
+    def __init__(self, width):
+        # Set before the base __init__, which pickles ``self`` (via
+        # __reduce__, reading ``width``) to cache its serialized form.
+        self.width = width
+        super(ParamExtType, self).__init__(pa.binary(width))
+
+    def __reduce__(self):
+        return (ParamExtType, (self.width,))
+
+
+def ipc_write_batch(batch):
+    """Write ``batch`` to an in-memory IPC stream and return the buffer."""
+    sink = pa.BufferOutputStream()
+    stream_writer = pa.RecordBatchStreamWriter(sink, batch.schema)
+    stream_writer.write_batch(batch)
+    stream_writer.close()
+    return sink.getvalue()
+
+
+def ipc_read_batch(buf):
+    """Read the first record batch back from an IPC stream buffer."""
+    return pa.RecordBatchStreamReader(buf).read_next_batch()
+
+
+def test_ext_type_basics():
+    assert UuidType().extension_name == "arrow.py_extension_type"
+
+
+def test_ext_type__lifetime():
+    uuid_type = UuidType()
+    ref = weakref.ref(uuid_type)
+    del uuid_type
+    # No reference cycle keeps the Python type alive through its C++ twin
+    assert ref() is None
+
+
+def test_ext_type__storage_type():
+    cases = [
+        (UuidType, (), pa.binary(16)),
+        (ParamExtType, (5,), pa.binary(5)),
+    ]
+    for cls, args, expected_storage in cases:
+        ext_type = cls(*args)
+        assert ext_type.storage_type == expected_storage
+        assert ext_type.__class__ is cls
+
+
+def test_uuid_type_pickle():
+    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+        payload = pickle.dumps(UuidType(), protocol=protocol)
+        restored = pickle.loads(payload)
+        ref = weakref.ref(restored)
+        assert restored.extension_name == "arrow.py_extension_type"
+        del restored
+        assert ref() is None
+
+
+def test_ext_type_equality():
+    width5 = ParamExtType(5)
+    width6 = ParamExtType(6)
+    width6_bis = ParamExtType(6)
+    assert width5 != width6
+    assert width6 == width6_bis
+    uuid1 = UuidType()
+    uuid2 = UuidType()
+    assert width5 != uuid1
+    assert uuid1 == uuid2
+
+
+def test_ext_array_basics():
+    ext_type = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    ext_arr = pa.ExtensionArray.from_storage(ext_type, storage)
+    ext_arr.validate()
+    assert ext_arr.type is ext_type
+    assert ext_arr.storage.equals(storage)
+
+
+def test_ext_array_lifetime():
+    ext_type = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    ext_arr = pa.ExtensionArray.from_storage(ext_type, storage)
+
+    weak_refs = list(map(weakref.ref, (ext_type, ext_arr, storage)))
+    del ext_type, storage, ext_arr
+    assert all(r() is None for r in weak_refs)
+
+
+def test_ext_array_errors():
+    mismatched_type = ParamExtType(4)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    with pytest.raises(TypeError, match="Incompatible storage type"):
+        pa.ExtensionArray.from_storage(mismatched_type, storage)
+
+
+def test_ext_array_equality():
+    full1 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+    full2 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+    empty = pa.array([], type=pa.binary(16))
+    uuid_type = UuidType()
+    param_type = ParamExtType(16)
+    from_storage = pa.ExtensionArray.from_storage
+
+    uuid_arr = from_storage(uuid_type, full1)
+    assert uuid_arr.equals(from_storage(uuid_type, full2))
+    assert not uuid_arr.equals(from_storage(uuid_type, empty))
+    param_arr = from_storage(param_type, full1)
+    assert not uuid_arr.equals(param_arr)
+    assert param_arr.equals(from_storage(param_type, full2))
+    assert not param_arr.equals(from_storage(param_type, empty))
+
+
+def test_ext_array_pickling():
+    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+        ext_type = ParamExtType(3)
+        storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+        original = pa.ExtensionArray.from_storage(ext_type, storage)
+        payload = pickle.dumps(original, protocol=protocol)
+        del ext_type, storage, original
+
+        restored = pickle.loads(payload)
+        restored.validate()
+        assert isinstance(restored, pa.ExtensionArray)
+        assert restored.type == ParamExtType(3)
+        assert restored.type.storage_type == pa.binary(3)
+        assert restored.storage.type == pa.binary(3)
+        assert restored.storage.to_pylist() == [b"foo", b"bar"]
+
+
+def example_batch():
+    """Build a one-column record batch holding a ParamExtType(3) array."""
+    ext_type = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    ext_arr = pa.ExtensionArray.from_storage(ext_type, storage)
+    return pa.RecordBatch.from_arrays([ext_arr], ["exts"])
+
+
+def check_example_batch(batch):
+    """Check a batch built by example_batch() and return its column."""
+    column = batch.column(0)
+    assert isinstance(column, pa.ExtensionArray)
+    assert column.type.storage_type == pa.binary(3)
+    assert column.storage.to_pylist() == [b"foo", b"bar"]
+    return column
+
+
+def test_ipc():
+    buf = ipc_write_batch(example_batch())
+
+    batch = ipc_read_batch(buf)
+    ext_arr = check_example_batch(batch)
+    assert ext_arr.type == ParamExtType(3)
+
+
+def test_ipc_unknown_type():
+    buf = ipc_write_batch(example_batch())
+
+    module_globals = globals()
+    saved_cls = module_globals['ParamExtType']
+    try:
+        # Simulate the original Python type being unavailable.
+        # Deserialization should not fail but return a placeholder type.
+        del module_globals['ParamExtType']
+
+        batch = ipc_read_batch(buf)
+        ext_arr = check_example_batch(batch)
+        assert isinstance(ext_arr.type, pa.UnknownExtensionType)
+
+        # Can be serialized again
+        buf2 = ipc_write_batch(batch)
+        del batch, ext_arr
+
+        batch = ipc_read_batch(buf2)
+        ext_arr = check_example_batch(batch)
+        assert isinstance(ext_arr.type, pa.UnknownExtensionType)
+    finally:
+        module_globals['ParamExtType'] = saved_cls
+
+    # Deserialize again with the type restored
+    batch = ipc_read_batch(buf2)
+    ext_arr = check_example_batch(batch)
+    assert ext_arr.type == ParamExtType(3)
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 9a92761..1f0db4c 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -19,6 +19,7 @@ import re
import warnings
from pyarrow import compat
+from pyarrow.compat import builtin_pickle
# These are imprecise because the type (in pandas 0.x) depends on the presence
@@ -103,7 +104,7 @@ cdef class DataType:
"functions like pyarrow.int64, pyarrow.list_, etc. "
"instead.".format(self.__class__.__name__))
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
self.sp_type = type
self.type = type.get()
self.pep3118_format = _datatype_to_pep3118(self.type)
@@ -203,7 +204,7 @@ cdef class DictionaryType(DataType):
Concrete class for dictionary data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.dict_type = <const CDictionaryType*> type.get()
@@ -239,7 +240,7 @@ cdef class ListType(DataType):
Concrete class for list data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.list_type = <const CListType*> type.get()
@@ -259,7 +260,7 @@ cdef class StructType(DataType):
Concrete class for struct data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.struct_type = <const CStructType*> type.get()
@@ -318,7 +319,7 @@ cdef class UnionType(DataType):
Concrete class for struct data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
@property
@@ -370,7 +371,7 @@ cdef class TimestampType(DataType):
Concrete class for timestamp data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.ts_type = <const CTimestampType*> type.get()
@@ -411,7 +412,7 @@ cdef class Time32Type(DataType):
Concrete class for time32 data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.time_type = <const CTime32Type*> type.get()
@@ -428,7 +429,7 @@ cdef class Time64Type(DataType):
Concrete class for time64 data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.time_type = <const CTime64Type*> type.get()
@@ -445,7 +446,7 @@ cdef class FixedSizeBinaryType(DataType):
Concrete class for fixed-size binary data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
DataType.init(self, type)
self.fixed_size_binary_type = (
<const CFixedSizeBinaryType*> type.get())
@@ -466,7 +467,7 @@ cdef class Decimal128Type(FixedSizeBinaryType):
Concrete class for decimal128 data types.
"""
- cdef void init(self, const shared_ptr[CDataType]& type):
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
FixedSizeBinaryType.init(self, type)
self.decimal128_type = <const CDecimal128Type*> type.get()
@@ -488,6 +489,132 @@ cdef class Decimal128Type(FixedSizeBinaryType):
return self.decimal128_type.scale()
+cdef class BaseExtensionType(DataType):
+    """
+    Concrete base class for extension types.
+
+    Wraps a C++ arrow::ExtensionType and exposes its extension name and
+    underlying storage type.
+    """
+
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
+        DataType.init(self, type)
+        # Non-owning view of the same C++ type, as an ExtensionType
+        self.ext_type = <const CExtensionType*> type.get()
+
+    @property
+    def storage_type(self):
+        """
+        The underlying storage type.
+        """
+        return pyarrow_wrap_data_type(self.ext_type.storage_type())
+
+    @property
+    def extension_name(self):
+        """
+        The extension type name.
+        """
+        return frombytes(self.ext_type.extension_name())
+
+
+cdef class ExtensionType(BaseExtensionType):
+    """
+    Concrete base class for Python-defined extension types.
+
+    Subclasses must call ``ExtensionType.__init__`` with their storage type
+    and implement ``__reduce__``, which the default
+    ``__arrow_ext_serialize__`` relies on (it pickles the instance).
+    """
+
+    def __cinit__(self):
+        if type(self) is ExtensionType:
+            raise TypeError("Can only instantiate subclasses of "
+                            "ExtensionType")
+
+    def __init__(self, DataType storage_type):
+        """
+        Parameters
+        ----------
+        storage_type : DataType
+            The underlying storage type of the extension type.
+        """
+        cdef:
+            shared_ptr[CExtensionType] cpy_ext_type
+
+        # Explicit check rather than ``assert`` (stripped under ``python -O``),
+        # so a None storage type can never reach the C++ layer.
+        if storage_type is None:
+            raise TypeError("storage_type must be a DataType, not None")
+        check_status(CPyExtensionType.FromClass(storage_type.sp_type,
+                                                type(self), &cpy_ext_type))
+        self.init(<shared_ptr[CDataType]> cpy_ext_type)
+
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
+        BaseExtensionType.init(self, type)
+        self.cpy_ext_type = <const CPyExtensionType*> type.get()
+        # Store weakref and serialized version of self on C++ type instance
+        check_status(self.cpy_ext_type.SetInstance(self))
+
+    def __eq__(self, other):
+        # Default implementation to avoid infinite recursion through
+        # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__
+        if isinstance(other, ExtensionType):
+            return (type(self) == type(other) and
+                    self.storage_type == other.storage_type)
+        else:
+            return NotImplemented
+
+    def __reduce__(self):
+        raise NotImplementedError("Please implement {0}.__reduce__"
+                                  .format(type(self).__name__))
+
+    def __arrow_ext_serialize__(self):
+        return builtin_pickle.dumps(self)
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        # SECURITY NOTE(review): ``serialized`` may come from IPC metadata
+        # written by another process.  Unpickling untrusted data can execute
+        # arbitrary code, so only read extension types from trusted sources.
+        try:
+            ty = builtin_pickle.loads(serialized)
+        except Exception:
+            # For some reason, it's impossible to deserialize the
+            # ExtensionType instance. Perhaps the serialized data is
+            # corrupt, or more likely the type is being deserialized
+            # in an environment where the original Python class or module
+            # is not available. Fall back on a generic BaseExtensionType.
+            return UnknownExtensionType(storage_type, serialized)
+
+        if ty.storage_type != storage_type:
+            raise TypeError("Expected storage type {0} but got {1}"
+                            .format(ty.storage_type, storage_type))
+        return ty
+
+
+cdef class UnknownExtensionType(ExtensionType):
+    """
+    A concrete class for Python-defined extension types that refer to
+    an unknown Python implementation.
+
+    Parameters
+    ----------
+    storage_type : DataType
+        The underlying storage type.
+    serialized : bytes
+        The opaque serialized form of the original type.  It is returned
+        unchanged by ``__arrow_ext_serialize__``, so the type round-trips.
+    """
+
+    cdef:
+        bytes serialized
+
+    def __init__(self, DataType storage_type, serialized):
+        # Set before ExtensionType.__init__, which serializes ``self``.
+        self.serialized = serialized
+        ExtensionType.__init__(self, storage_type)
+
+    def __reduce__(self):
+        # The inherited __reduce__ raises NotImplementedError.  A placeholder
+        # type is fully described by its storage type and opaque payload, so
+        # arrays carrying it can still be pickled.
+        return UnknownExtensionType, (self.storage_type, self.serialized)
+
+    def __arrow_ext_serialize__(self):
+        return self.serialized
+
+
+cdef class _ExtensionTypesInitializer:
+    #
+    # Private helper owning the process-wide registration of the Python
+    # ExtensionType with the C++ extension type registry: registration
+    # happens on creation, unregistration on deallocation.
+    #
+
+    def __cinit__(self):
+        cdef:
+            DataType dummy_storage = null()
+            shared_ptr[CExtensionType] cpy_ext_type
+
+        # A placeholder C++ ExtensionType backed by the ExtensionType class
+        check_status(CPyExtensionType.FromClass(dummy_storage.sp_type,
+                                                ExtensionType, &cpy_ext_type))
+        check_status(
+            RegisterPyExtensionType(<shared_ptr[CDataType]> cpy_ext_type))
+
+    def __dealloc__(self):
+        # Unregistering must happen before the interpreter is finalized:
+        # a C++ type destroyed later during process teardown would call
+        # CPython APIs such as Py_DECREF on a dead interpreter.
+        check_status(UnregisterPyExtensionType())
+
+
cdef class Field:
"""
A named field, with a data type, nullability, and optional metadata.
@@ -1726,3 +1853,6 @@ def is_integer_value(object obj):
def is_float_value(object obj):
return IsPyFloat(obj)
+
+
+_extension_types_initializer = _ExtensionTypesInitializer()