This is an automated email from the ASF dual-hosted git repository.
andygrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-java.git
The following commit(s) were added to refs/heads/main by this push:
new 757a078 feat(proto): execute Java-built protobuf plans via SessionContext (#13)
757a078 is described below
commit 757a078c3ed70c0627b29d6ccf081eaf281761bc
Author: Andy Grove <[email protected]>
AuthorDate: Wed May 13 00:33:54 2026 -0600
feat(proto): execute Java-built protobuf plans via SessionContext (#13)
---
native/Cargo.lock | 64 +++++++
native/Cargo.toml | 2 +
native/src/lib.rs | 3 +-
native/src/proto.rs | 81 ++++++++
.../java/org/apache/datafusion/SessionContext.java | 44 +++++
.../apache/datafusion/proto/SchemaConverter.java | 212 +++++++++++++++++++++
.../datafusion/proto/SchemaConverterTest.java | 107 +++++++++++
.../datafusion/proto/SessionContextProtoTest.java | 188 ++++++++++++++++++
8 files changed, 700 insertions(+), 1 deletion(-)
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 5ebef8e..004fa98 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1140,7 +1140,9 @@ version = "0.1.0"
dependencies = [
"arrow",
"datafusion",
+ "datafusion-proto",
"jni",
+ "prost",
"tokio",
]
@@ -1282,6 +1284,45 @@ dependencies = [
"tokio",
]
+[[package]]
+name = "datafusion-proto"
+version = "53.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a387aaef949dc16bb6abc81bd1af850ec7449183aef011214f9724957495738"
+dependencies = [
+ "arrow",
+ "chrono",
+ "datafusion-catalog",
+ "datafusion-catalog-listing",
+ "datafusion-common",
+ "datafusion-datasource",
+ "datafusion-datasource-arrow",
+ "datafusion-datasource-csv",
+ "datafusion-datasource-json",
+ "datafusion-datasource-parquet",
+ "datafusion-execution",
+ "datafusion-expr",
+ "datafusion-functions-table",
+ "datafusion-physical-expr",
+ "datafusion-physical-expr-common",
+ "datafusion-physical-plan",
+ "datafusion-proto-common",
+ "object_store",
+ "prost",
+ "rand",
+]
+
+[[package]]
+name = "datafusion-proto-common"
+version = "53.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "16e614c7c53a9c304c6a850b821010bb492e57300311835f1180613f9d2c63d9"
+dependencies = [
+ "arrow",
+ "datafusion-common",
+ "prost",
+]
+
[[package]]
name = "datafusion-pruning"
version = "53.1.0"
@@ -2281,6 +2322,29 @@ dependencies = [
"unicode-ident",
]
+[[package]]
+name = "prost"
+version = "0.14.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568"
+dependencies = [
+ "bytes",
+ "prost-derive",
+]
+
+[[package]]
+name = "prost-derive"
+version = "0.14.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
+dependencies = [
+ "anyhow",
+ "itertools",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
[[package]]
name = "psm"
version = "0.1.31"
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 033fb4c..f2a6818 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -27,5 +27,7 @@ crate-type = ["cdylib"]
[dependencies]
arrow = { version = "58", features = ["ffi"] }
datafusion = "53.1.0"
+datafusion-proto = "53.1.0"
jni = "0.21"
+prost = "0.14"
tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/native/src/lib.rs b/native/src/lib.rs
index 42ca5f5..463d075 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -16,6 +16,7 @@
// under the License.
mod errors;
+mod proto;
use std::sync::{Arc, OnceLock};
@@ -33,7 +34,7 @@ use tokio::runtime::Runtime;
use crate::errors::{try_unwrap_or_throw, JniResult};
-fn runtime() -> &'static Runtime {
+pub(crate) fn runtime() -> &'static Runtime {
static RT: OnceLock<Runtime> = OnceLock::new();
RT.get_or_init(|| Runtime::new().expect("failed to create Tokio runtime"))
}
diff --git a/native/src/proto.rs b/native/src/proto.rs
new file mode 100644
index 0000000..429a91e
--- /dev/null
+++ b/native/src/proto.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::ipc::writer::StreamWriter;
+use datafusion::dataframe::DataFrame;
+use datafusion::prelude::SessionContext;
+use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
+use datafusion_proto::protobuf::LogicalPlanNode;
+use jni::objects::{JByteArray, JClass, JString};
+use jni::sys::{jbyteArray, jlong};
+use jni::JNIEnv;
+use prost::Message;
+
+use crate::errors::{try_unwrap_or_throw, JniResult};
+use crate::runtime;
+
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_SessionContext_createDataFrameFromProto<'local>(
+ mut env: JNIEnv<'local>,
+ _class: JClass<'local>,
+ handle: jlong,
+ plan_bytes: JByteArray<'local>,
+) -> jlong {
+ try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
+ if handle == 0 {
+ return Err("SessionContext handle is null".into());
+ }
+ let ctx = unsafe { &*(handle as *const SessionContext) };
+ let bytes: Vec<u8> = env.convert_byte_array(&plan_bytes)?;
+ let node = LogicalPlanNode::decode(bytes.as_slice())?;
+ let codec = DefaultLogicalExtensionCodec {};
+ let task_ctx = ctx.task_ctx();
+ let plan = node.try_into_logical_plan(task_ctx.as_ref(), &codec)?;
+        let df: DataFrame = runtime().block_on(ctx.execute_logical_plan(plan))?;
+ Ok(Box::into_raw(Box::new(df)) as jlong)
+ })
+}
+
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_SessionContext_tableSchemaIpc<'local>(
+ mut env: JNIEnv<'local>,
+ _class: JClass<'local>,
+ handle: jlong,
+ name: JString<'local>,
+) -> jbyteArray {
+    try_unwrap_or_throw(&mut env, std::ptr::null_mut(), |env| -> JniResult<jbyteArray> {
+ if handle == 0 {
+ return Err("SessionContext handle is null".into());
+ }
+ let ctx = unsafe { &*(handle as *const SessionContext) };
+ let name: String = env.get_string(&name)?.into();
+
+ let df = runtime().block_on(ctx.table(name.as_str()))?;
+ let schema: SchemaRef = Arc::new(df.schema().as_arrow().clone());
+
+ let mut buf: Vec<u8> = Vec::new();
+ {
+ let mut writer = StreamWriter::try_new(&mut buf, schema.as_ref())?;
+ writer.finish()?;
+ }
+ let arr = env.byte_array_from_slice(&buf)?;
+ Ok(arr.into_raw())
+ })
+}
diff --git a/src/main/java/org/apache/datafusion/SessionContext.java b/src/main/java/org/apache/datafusion/SessionContext.java
index fb79d43..823ee13 100644
--- a/src/main/java/org/apache/datafusion/SessionContext.java
+++ b/src/main/java/org/apache/datafusion/SessionContext.java
@@ -19,6 +19,7 @@
package org.apache.datafusion;
+import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
@@ -27,6 +28,8 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
/**
@@ -63,6 +66,43 @@ public final class SessionContext implements AutoCloseable {
return new DataFrame(dfHandle);
}
+ /**
+   * Decode a DataFusion-Proto {@code LogicalPlanNode} and return a lazy {@link DataFrame}. The plan
+ * is not executed until {@link DataFrame#collect} is called.
+ *
+   * <p>The bytes must be a serialized {@code datafusion.LogicalPlanNode} (see {@code
+ * org.apache.datafusion.protobuf.LogicalPlanNode}).
+ *
+   * @throws RuntimeException if the bytes are not a valid {@code LogicalPlanNode} or if logical
+ * planning fails.
+ */
+ public DataFrame fromProto(byte[] planBytes) {
+ if (nativeHandle == 0) {
+ throw new IllegalStateException("SessionContext is closed");
+ }
+ long dfHandle = createDataFrameFromProto(nativeHandle, planBytes);
+ return new DataFrame(dfHandle);
+ }
+
+ /**
+   * Return the Arrow {@link Schema} of a registered table. Transferred via Arrow IPC; no {@link
+   * org.apache.arrow.memory.BufferAllocator} is required because a schema carries no buffer data.
+ *
+   * @throws RuntimeException if {@code tableName} is not registered in this context.
+ */
+ public Schema tableSchema(String tableName) {
+ if (nativeHandle == 0) {
+ throw new IllegalStateException("SessionContext is closed");
+ }
+ byte[] ipcBytes = tableSchemaIpc(nativeHandle, tableName);
+ try {
+ return MessageSerializer.deserializeSchema(
+          new ReadChannel(Channels.newChannel(new ByteArrayInputStream(ipcBytes))));
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to deserialize IPC schema", e);
+ }
+ }
+
public void registerParquet(String name, String path) {
registerParquet(name, path, new ParquetReadOptions());
}
@@ -142,6 +182,10 @@ public final class SessionContext implements AutoCloseable {
private static native long createDataFrame(long handle, String sql);
+  private static native long createDataFrameFromProto(long handle, byte[] planBytes);
+
+ private static native byte[] tableSchemaIpc(long handle, String tableName);
+
private static native void registerParquetWithOptions(
long handle,
String name,
diff --git a/src/main/java/org/apache/datafusion/proto/SchemaConverter.java b/src/main/java/org/apache/datafusion/proto/SchemaConverter.java
new file mode 100644
index 0000000..96d0d76
--- /dev/null
+++ b/src/main/java/org/apache/datafusion/proto/SchemaConverter.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion.proto;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import datafusion_common.DatafusionCommon;
+
+/**
+ * Convert between Arrow Java {@link Schema} and the {@code datafusion_common.Schema} protobuf shape
+ * used by DataFusion plan messages such as {@code ListingTableScanNode.schema}.
+ *
+ * <p>Supports the primitive Arrow types this project's tests exercise (Bool, signed/unsigned Int
+ * 8..64, Float32/64, Utf8, Utf8View, LargeUtf8, Date32, Decimal128). Anything else raises {@link
+ * UnsupportedOperationException} with a message naming the offending type.
+ */
+public final class SchemaConverter {
+
+ private SchemaConverter() {}
+
+ public static DatafusionCommon.Schema toProto(Schema arrow) {
+    DatafusionCommon.Schema.Builder builder = DatafusionCommon.Schema.newBuilder();
+ for (Field f : arrow.getFields()) {
+ builder.addColumns(fieldToProto(f));
+ }
+ if (arrow.getCustomMetadata() != null) {
+ builder.putAllMetadata(arrow.getCustomMetadata());
+ }
+ return builder.build();
+ }
+
+ public static Schema fromProto(DatafusionCommon.Schema proto) {
+ List<Field> fields = new ArrayList<>(proto.getColumnsCount());
+ for (DatafusionCommon.Field f : proto.getColumnsList()) {
+ fields.add(fieldFromProto(f));
+ }
+ Map<String, String> metadata = new LinkedHashMap<>(proto.getMetadataMap());
+ return new Schema(fields, metadata);
+ }
+
+ static DatafusionCommon.Field fieldToProto(Field f) {
+ DatafusionCommon.Field.Builder b =
+ DatafusionCommon.Field.newBuilder()
+ .setName(f.getName())
+ .setArrowType(arrowTypeToProto(f.getType()))
+ .setNullable(f.isNullable());
+ if (f.getMetadata() != null) {
+ b.putAllMetadata(f.getMetadata());
+ }
+ return b.build();
+ }
+
+ static Field fieldFromProto(DatafusionCommon.Field f) {
+ ArrowType type = arrowTypeFromProto(f.getArrowType());
+    FieldType ft = new FieldType(f.getNullable(), type, null, f.getMetadataMap());
+ return new Field(f.getName(), ft, null);
+ }
+
+ static DatafusionCommon.ArrowType arrowTypeToProto(ArrowType t) {
+    DatafusionCommon.EmptyMessage empty = DatafusionCommon.EmptyMessage.getDefaultInstance();
+    DatafusionCommon.ArrowType.Builder b = DatafusionCommon.ArrowType.newBuilder();
+
+ if (t instanceof ArrowType.Bool) {
+ return b.setBOOL(empty).build();
+ }
+ if (t instanceof ArrowType.Int) {
+ ArrowType.Int i = (ArrowType.Int) t;
+ if (i.getIsSigned()) {
+ switch (i.getBitWidth()) {
+ case 8:
+ return b.setINT8(empty).build();
+ case 16:
+ return b.setINT16(empty).build();
+ case 32:
+ return b.setINT32(empty).build();
+ case 64:
+ return b.setINT64(empty).build();
+ default:
+ throw new UnsupportedOperationException(
+ "Arrow type Int signed width "
+ + i.getBitWidth()
+ + " not yet supported by SchemaConverter");
+ }
+ } else {
+ switch (i.getBitWidth()) {
+ case 8:
+ return b.setUINT8(empty).build();
+ case 16:
+ return b.setUINT16(empty).build();
+ case 32:
+ return b.setUINT32(empty).build();
+ case 64:
+ return b.setUINT64(empty).build();
+ default:
+ throw new UnsupportedOperationException(
+ "Arrow type Int unsigned width "
+ + i.getBitWidth()
+ + " not yet supported by SchemaConverter");
+ }
+ }
+ }
+ if (t instanceof ArrowType.FloatingPoint) {
+ FloatingPointPrecision p = ((ArrowType.FloatingPoint) t).getPrecision();
+      if (p == FloatingPointPrecision.SINGLE) return b.setFLOAT32(empty).build();
+      if (p == FloatingPointPrecision.DOUBLE) return b.setFLOAT64(empty).build();
+ throw new UnsupportedOperationException(
+          "Arrow type FloatingPoint " + p + " not yet supported by SchemaConverter");
+ }
+ if (t instanceof ArrowType.Utf8) {
+ return b.setUTF8(empty).build();
+ }
+ if (t instanceof ArrowType.Utf8View) {
+ return b.setUTF8VIEW(empty).build();
+ }
+ if (t instanceof ArrowType.LargeUtf8) {
+ return b.setLARGEUTF8(empty).build();
+ }
+ if (t instanceof ArrowType.Date) {
+ DateUnit u = ((ArrowType.Date) t).getUnit();
+ if (u == DateUnit.DAY) return b.setDATE32(empty).build();
+ throw new UnsupportedOperationException(
+ "Arrow type Date " + u + " not yet supported by SchemaConverter");
+ }
+ if (t instanceof ArrowType.Decimal) {
+ ArrowType.Decimal d = (ArrowType.Decimal) t;
+ if (d.getBitWidth() != 128) {
+ throw new UnsupportedOperationException(
+ "Arrow type Decimal bit width "
+ + d.getBitWidth()
+ + " not yet supported by SchemaConverter");
+ }
+ return b.setDECIMAL128(
+ DatafusionCommon.Decimal128Type.newBuilder()
+ .setPrecision(d.getPrecision())
+ .setScale(d.getScale())
+ .build())
+ .build();
+ }
+ throw new UnsupportedOperationException(
+        "Arrow type " + t.getClass().getSimpleName() + " not yet supported by SchemaConverter");
+ }
+
+ static ArrowType arrowTypeFromProto(DatafusionCommon.ArrowType p) {
+ switch (p.getArrowTypeEnumCase()) {
+ case BOOL:
+ return ArrowType.Bool.INSTANCE;
+ case INT8:
+ return new ArrowType.Int(8, true);
+ case INT16:
+ return new ArrowType.Int(16, true);
+ case INT32:
+ return new ArrowType.Int(32, true);
+ case INT64:
+ return new ArrowType.Int(64, true);
+ case UINT8:
+ return new ArrowType.Int(8, false);
+ case UINT16:
+ return new ArrowType.Int(16, false);
+ case UINT32:
+ return new ArrowType.Int(32, false);
+ case UINT64:
+ return new ArrowType.Int(64, false);
+ case FLOAT32:
+ return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
+ case FLOAT64:
+ return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+ case UTF8:
+ return ArrowType.Utf8.INSTANCE;
+ case UTF8_VIEW:
+ return ArrowType.Utf8View.INSTANCE;
+ case LARGE_UTF8:
+ return ArrowType.LargeUtf8.INSTANCE;
+ case DATE32:
+ return new ArrowType.Date(DateUnit.DAY);
+ case DECIMAL128:
+ DatafusionCommon.Decimal128Type d = p.getDECIMAL128();
+ return new ArrowType.Decimal(d.getPrecision(), d.getScale(), 128);
+ default:
+ throw new UnsupportedOperationException(
+ "datafusion_common.ArrowType "
+ + p.getArrowTypeEnumCase()
+ + " not yet supported by SchemaConverter");
+ }
+ }
+}
diff --git a/src/test/java/org/apache/datafusion/proto/SchemaConverterTest.java b/src/test/java/org/apache/datafusion/proto/SchemaConverterTest.java
new file mode 100644
index 0000000..dd4cd23
--- /dev/null
+++ b/src/test/java/org/apache/datafusion/proto/SchemaConverterTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion.proto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.jupiter.api.Test;
+
+import datafusion_common.DatafusionCommon;
+
+class SchemaConverterTest {
+
+ @Test
+ void roundTripsPrimitivesAndMetadata() {
+ Map<String, String> schemaMeta = new HashMap<>();
+ schemaMeta.put("origin", "test");
+ Map<String, String> fieldMeta = new HashMap<>();
+ fieldMeta.put("ns", "demo");
+
+ List<Field> fields =
+ Arrays.asList(
+            new Field("a_bool", FieldType.nullable(ArrowType.Bool.INSTANCE), null),
+            new Field("a_i32", FieldType.notNullable(new ArrowType.Int(32, true)), null),
+            new Field("a_i64", FieldType.nullable(new ArrowType.Int(64, true)), null),
+            new Field("a_u32", FieldType.nullable(new ArrowType.Int(32, false)), null),
+ new Field(
+ "a_f64",
+                FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ null),
+            new Field("a_str", FieldType.nullable(ArrowType.Utf8.INSTANCE), null),
+            new Field("a_date", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null),
+ new Field(
+ "a_with_meta",
+                new FieldType(true, new ArrowType.Int(8, true), null, fieldMeta),
+ null));
+
+ Schema original = new Schema(fields, schemaMeta);
+
+ DatafusionCommon.Schema proto = SchemaConverter.toProto(original);
+ Schema roundTripped = SchemaConverter.fromProto(proto);
+
+ assertEquals(original, roundTripped);
+ }
+
+ @Test
+ void decimalPreservesPrecisionAndScale() {
+ Schema original =
+ new Schema(
+ List.of(
+                new Field("amount", FieldType.nullable(new ArrowType.Decimal(18, 5, 128)), null)));
+
+ DatafusionCommon.Schema proto = SchemaConverter.toProto(original);
+ Schema roundTripped = SchemaConverter.fromProto(proto);
+
+ assertEquals(original, roundTripped);
+    ArrowType.Decimal d = (ArrowType.Decimal) roundTripped.getFields().get(0).getType();
+ assertEquals(18, d.getPrecision());
+ assertEquals(5, d.getScale());
+ assertEquals(128, d.getBitWidth());
+ }
+
+ @Test
+ void unsupportedTypeRaisesUnsupportedOperationException() {
+ Field listField =
+ new Field(
+ "nested",
+ FieldType.nullable(new ArrowType.List()),
+            List.of(new Field("item", FieldType.nullable(new ArrowType.Int(32, true)), null)));
+ Schema original = new Schema(List.of(listField));
+
+ UnsupportedOperationException ex =
+        assertThrows(UnsupportedOperationException.class, () -> SchemaConverter.toProto(original));
+ assertTrue(
+ ex.getMessage().contains("List"),
+        "exception message should name the unsupported type, was: " + ex.getMessage());
+ }
+}
diff --git a/src/test/java/org/apache/datafusion/proto/SessionContextProtoTest.java b/src/test/java/org/apache/datafusion/proto/SessionContextProtoTest.java
new file mode 100644
index 0000000..7bb65e6
--- /dev/null
+++ b/src/test/java/org/apache/datafusion/proto/SessionContextProtoTest.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion.proto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.datafusion.DataFrame;
+import org.apache.datafusion.SessionContext;
+import org.apache.datafusion.protobuf.BareTableReference;
+import org.apache.datafusion.protobuf.EmptyRelationNode;
+import org.apache.datafusion.protobuf.ListingTableScanNode;
+import org.apache.datafusion.protobuf.LogicalExprNode;
+import org.apache.datafusion.protobuf.LogicalPlanNode;
+import org.apache.datafusion.protobuf.ProjectionColumns;
+import org.apache.datafusion.protobuf.ProjectionNode;
+import org.apache.datafusion.protobuf.SortExprNode;
+import org.apache.datafusion.protobuf.SortNode;
+import org.apache.datafusion.protobuf.TableReference;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+
+import datafusion_common.DatafusionCommon;
+
+class SessionContextProtoTest {
+
+ @Test
+ void fromProtoExecutesProjectionOverEmptyRelation() throws Exception {
+ LogicalPlanNode plan =
+ LogicalPlanNode.newBuilder()
+ .setProjection(
+ ProjectionNode.newBuilder()
+ .setInput(
+ LogicalPlanNode.newBuilder()
+ .setEmptyRelation(
+                                EmptyRelationNode.newBuilder().setProduceOneRow(true).build())
+ .build())
+ .addExpr(
+ LogicalExprNode.newBuilder()
+ .setLiteral(
+                                DatafusionCommon.ScalarValue.newBuilder().setInt32Value(1).build())
+ .build())
+ .build())
+ .build();
+
+ try (BufferAllocator allocator = new RootAllocator();
+ SessionContext ctx = new SessionContext();
+ DataFrame df = ctx.fromProto(plan.toByteArray());
+ ArrowReader reader = df.collect(allocator)) {
+ assertTrue(reader.loadNextBatch());
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ assertEquals(1, root.getRowCount());
+ IntVector col = (IntVector) root.getVector(0);
+ assertEquals(1, col.get(0));
+ }
+ }
+
+ @Test
+ void tableSchemaReturnsExpectedLineitemFields() {
+ Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet");
+ Assumptions.assumeTrue(
+        Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first");
+
+ try (SessionContext ctx = new SessionContext()) {
+ ctx.registerParquet("lineitem", lineitem.toAbsolutePath().toString());
+ Schema schema = ctx.tableSchema("lineitem");
+
+ assertEquals(16, schema.getFields().size());
+ assertEquals("l_orderkey", schema.getFields().get(0).getName());
+ assertTrue(schema.getFields().get(0).getType() instanceof ArrowType.Int);
+
+ // l_extendedprice = Decimal128(15, 2)
+      ArrowType.Decimal price = (ArrowType.Decimal) schema.findField("l_extendedprice").getType();
+ assertEquals(15, price.getPrecision());
+ assertEquals(2, price.getScale());
+
+ // l_shipdate = Date(DAY)
+      ArrowType.Date ship = (ArrowType.Date) schema.findField("l_shipdate").getType();
+ assertEquals(DateUnit.DAY, ship.getUnit());
+ }
+ }
+
+ @Test
+ void fromProtoListingScanMatchesSql() throws Exception {
+ Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet");
+ Assumptions.assumeTrue(
+        Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first");
+ String absPath = lineitem.toAbsolutePath().toString();
+
+ try (BufferAllocator allocator = new RootAllocator();
+ SessionContext ctx = new SessionContext()) {
+ ctx.registerParquet("lineitem", absPath);
+ Schema arrow = ctx.tableSchema("lineitem");
+ DatafusionCommon.Schema schemaProto = SchemaConverter.toProto(arrow);
+
+ LogicalExprNode orderKeyExpr =
+ LogicalExprNode.newBuilder()
+              .setColumn(DatafusionCommon.Column.newBuilder().setName("l_orderkey").build())
+ .build();
+
+ LogicalPlanNode plan =
+ LogicalPlanNode.newBuilder()
+ .setSort(
+ SortNode.newBuilder()
+ .setInput(
+ LogicalPlanNode.newBuilder()
+ .setListingScan(
+ ListingTableScanNode.newBuilder()
+ .setTableName(
+ TableReference.newBuilder()
+ .setBare(
+                                                  BareTableReference.newBuilder()
+ .setTable("lineitem")
+ .build())
+ .build())
+ .addPaths(absPath)
+ .setFileExtension(".parquet")
+ .setSchema(schemaProto)
+ .setProjection(
+ ProjectionColumns.newBuilder()
+ .addColumns("l_orderkey")
+ .build())
+ .setParquet(
+                                          DatafusionCommon.ParquetFormat.getDefaultInstance())
+ .setTargetPartitions(1)
+ .build())
+ .build())
+ .addExpr(
+ SortExprNode.newBuilder()
+ .setExpr(orderKeyExpr)
+ .setAsc(true)
+ .setNullsFirst(false)
+ .build())
+ .setFetch(1)
+ .build())
+ .build();
+
+ long protoValue;
+ try (DataFrame df = ctx.fromProto(plan.toByteArray());
+ ArrowReader reader = df.collect(allocator)) {
+ assertTrue(reader.loadNextBatch());
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ assertEquals(1, root.getRowCount());
+ protoValue = ((BigIntVector) root.getVector(0)).get(0);
+ }
+
+ long sqlValue;
+      try (DataFrame df = ctx.sql("SELECT l_orderkey FROM lineitem ORDER BY l_orderkey LIMIT 1");
+ ArrowReader reader = df.collect(allocator)) {
+ assertTrue(reader.loadNextBatch());
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ assertEquals(1, root.getRowCount());
+ sqlValue = ((BigIntVector) root.getVector(0)).get(0);
+ }
+
+ assertEquals(sqlValue, protoValue);
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]