This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new eae94643 refactor(python): Expose transform as a submodule for pyiceberg_core (#628)
eae94643 is described below

commit eae946437e7fa34e95fd604ac3431c8f86b70597
Author: Xuanwo <[email protected]>
AuthorDate: Thu Sep 12 02:06:31 2024 +0800

    refactor(python): Expose transform as a submodule for pyiceberg_core (#628)
---
 bindings/python/Cargo.toml               |   4 +-
 bindings/python/src/{lib.rs => error.rs} |  23 ++-----
 bindings/python/src/lib.rs               |  16 +----
 bindings/python/src/transform.rs         | 104 ++++++++++++++++---------------
 bindings/python/tests/test_basic.py      |  22 -------
 bindings/python/tests/test_transform.py  |  28 ++++++---
 6 files changed, 83 insertions(+), 114 deletions(-)

diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 0260f788..469fc769 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -32,5 +32,5 @@ crate-type = ["cdylib"]
 
 [dependencies]
 iceberg = { path = "../../crates/iceberg" }
-pyo3 = { version = "0.21.1", features = ["extension-module"] }
-arrow = { version = "52.2.0", features = ["pyarrow"] }
+pyo3 = { version = "0.21", features = ["extension-module"] }
+arrow = { version = "52", features = ["pyarrow"] }
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/error.rs
similarity index 64%
copy from bindings/python/src/lib.rs
copy to bindings/python/src/error.rs
index 5c3f77ff..a2d1424c 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/error.rs
@@ -15,23 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use iceberg::io::FileIOBuilder;
-use pyo3::prelude::*;
-use pyo3::wrap_pyfunction;
+use pyo3::exceptions::PyValueError;
+use pyo3::PyErr;
 
-mod transform;
-
-#[pyfunction]
-fn hello_world() -> PyResult<String> {
-    let _ = FileIOBuilder::new_fs_io().build().unwrap();
-    Ok("Hello, world!".to_string())
-}
-
-
-#[pymodule]
-fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(hello_world, m)?)?;
-
-    m.add_class::<transform::ArrowArrayTransform>()?;
-    Ok(())
+/// Convert an iceberg error to a python error
+pub fn to_py_err(err: iceberg::Error) -> PyErr {
+    PyValueError::new_err(err.to_string())
 }
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 5c3f77ff..a16bdac4 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -15,23 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use iceberg::io::FileIOBuilder;
 use pyo3::prelude::*;
-use pyo3::wrap_pyfunction;
 
+mod error;
 mod transform;
 
-#[pyfunction]
-fn hello_world() -> PyResult<String> {
-    let _ = FileIOBuilder::new_fs_io().build().unwrap();
-    Ok("Hello, world!".to_string())
-}
-
-
 #[pymodule]
-fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(hello_world, m)?)?;
-
-    m.add_class::<transform::ArrowArrayTransform>()?;
+fn pyiceberg_core_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
+    transform::register_module(py, m)?;
     Ok(())
 }
diff --git a/bindings/python/src/transform.rs b/bindings/python/src/transform.rs
index 8f4585b2..5b0d82f2 100644
--- a/bindings/python/src/transform.rs
+++ b/bindings/python/src/transform.rs
@@ -15,24 +15,55 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::array::{make_array, Array, ArrayData};
+use arrow::pyarrow::{FromPyArrow, ToPyArrow};
 use iceberg::spec::Transform;
 use iceberg::transform::create_transform_function;
+use pyo3::prelude::*;
 
-use arrow::{
-    array::{make_array, Array, ArrayData},
-};
-use arrow::pyarrow::{FromPyArrow, ToPyArrow};
-use pyo3::{exceptions::PyValueError, prelude::*};
+use crate::error::to_py_err;
+
+#[pyfunction]
+pub fn identity(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Identity)
+}
+
+#[pyfunction]
+pub fn void(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Void)
+}
+
+#[pyfunction]
+pub fn year(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Year)
+}
+
+#[pyfunction]
+pub fn month(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Month)
+}
 
-fn to_py_err(err: iceberg::Error) -> PyErr {
-    PyValueError::new_err(err.to_string())
+#[pyfunction]
+pub fn day(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Day)
 }
 
-#[pyclass]
-pub struct ArrowArrayTransform {
+#[pyfunction]
+pub fn hour(py: Python, array: PyObject) -> PyResult<PyObject> {
+    apply(py, array, Transform::Hour)
 }
 
-fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
+#[pyfunction]
+pub fn bucket(py: Python, array: PyObject, num_buckets: u32) -> PyResult<PyObject> {
+    apply(py, array, Transform::Bucket(num_buckets))
+}
+
+#[pyfunction]
+pub fn truncate(py: Python, array: PyObject, width: u32) -> PyResult<PyObject> {
+    apply(py, array, Transform::Truncate(width))
+}
+
+fn apply(py: Python, array: PyObject, transform: Transform) -> PyResult<PyObject> {
     // import
     let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
     let array = make_array(array);
@@ -43,45 +74,20 @@ fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject
     array.to_pyarrow(py)
 }
 
-#[pymethods]
-impl ArrowArrayTransform {
-    #[staticmethod]
-    pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Identity, py)
-    }
-
-    #[staticmethod]
-    pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Void, py)
-    }
-
-    #[staticmethod]
-    pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Year, py)
-    }
-
-    #[staticmethod]
-    pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Month, py)
-    }
-
-    #[staticmethod]
-    pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Day, py)
-    }
-
-    #[staticmethod]
-    pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Hour, py)
-    }
+pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
+    let this = PyModule::new_bound(py, "transform")?;
 
-    #[staticmethod]
-    pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Bucket(num_buckets), py)
-    }
+    this.add_function(wrap_pyfunction!(identity, &this)?)?;
+    this.add_function(wrap_pyfunction!(void, &this)?)?;
+    this.add_function(wrap_pyfunction!(year, &this)?)?;
+    this.add_function(wrap_pyfunction!(month, &this)?)?;
+    this.add_function(wrap_pyfunction!(day, &this)?)?;
+    this.add_function(wrap_pyfunction!(hour, &this)?)?;
+    this.add_function(wrap_pyfunction!(bucket, &this)?)?;
+    this.add_function(wrap_pyfunction!(truncate, &this)?)?;
 
-    #[staticmethod]
-    pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
-        apply(array, Transform::Truncate(width), py)
-    }
+    m.add_submodule(&this)?;
+    py.import_bound("sys")?
+        .getattr("modules")?
+        .set_item("pyiceberg_core.transform", this)
 }
diff --git a/bindings/python/tests/test_basic.py b/bindings/python/tests/test_basic.py
deleted file mode 100644
index 817793ba..00000000
--- a/bindings/python/tests/test_basic.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from pyiceberg_core import hello_world
-
-
-def test_hello_world():
-    hello_world()
diff --git a/bindings/python/tests/test_transform.py b/bindings/python/tests/test_transform.py
index 1fa2d577..4180b690 100644
--- a/bindings/python/tests/test_transform.py
+++ b/bindings/python/tests/test_transform.py
@@ -19,18 +19,18 @@ from datetime import date, datetime
 
 import pyarrow as pa
 import pytest
-from pyiceberg_core import ArrowArrayTransform
+from pyiceberg_core import transform
 
 
 def test_identity_transform():
     arr = pa.array([1, 2])
-    result = ArrowArrayTransform.identity(arr)
+    result = transform.identity(arr)
     assert result == arr
 
 
 def test_bucket_transform():
     arr = pa.array([1, 2])
-    result = ArrowArrayTransform.bucket(arr, 10)
+    result = transform.bucket(arr, 10)
     expected = pa.array([6, 2], type=pa.int32())
     assert result == expected
 
@@ -41,14 +41,14 @@ def test_bucket_transform_fails_for_list_type_input():
         ValueError,
         match=r"FeatureUnsupported => Unsupported data type for bucket transform",
     ):
-        ArrowArrayTransform.bucket(arr, 10)
+        transform.bucket(arr, 10)
 
 
 def test_bucket_chunked_array():
     chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
     result_chunks = []
     for arr in chunked.iterchunks():
-        result_chunks.append(ArrowArrayTransform.bucket(arr, 10))
+        result_chunks.append(transform.bucket(arr, 10))
 
     expected = pa.chunked_array(
         [pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
@@ -58,34 +58,42 @@ def test_bucket_chunked_array():
 
 def test_year_transform():
     arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
-    result = ArrowArrayTransform.year(arr)
+    result = transform.year(arr)
     expected = pa.array([0, 30], type=pa.int32())
     assert result == expected
 
 
 def test_month_transform():
     arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
-    result = ArrowArrayTransform.month(arr)
+    result = transform.month(arr)
     expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
     assert result == expected
 
 
 def test_day_transform():
     arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
-    result = ArrowArrayTransform.day(arr)
+    result = transform.day(arr)
     expected = pa.array([0, 11048], type=pa.int32())
     assert result == expected
 
 
 def test_hour_transform():
     arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
-    result = ArrowArrayTransform.hour(arr)
+    result = transform.hour(arr)
     expected = pa.array([19, 264420], type=pa.int32())
     assert result == expected
 
 
 def test_truncate_transform():
     arr = pa.array(["this is a long string", "hi my name is sung"])
-    result = ArrowArrayTransform.truncate(arr, 5)
+    result = transform.truncate(arr, 5)
     expected = pa.array(["this ", "hi my"])
     assert result == expected
+
+
+def test_identity_transform_with_direct_import():
+    from pyiceberg_core.transform import identity
+
+    arr = pa.array([1, 2])
+    result = identity(arr)
+    assert result == arr

Reply via email to