This is an automated email from the ASF dual-hosted git repository.

andygrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-java.git


The following commit(s) were added to refs/heads/main by this push:
     new 717b735  feat(dataframe): add select/filter/count/show methods (#19)
717b735 is described below

commit 717b735d3ac2204808c47448d600e487ac968b9c
Author: Andy Grove <[email protected]>
AuthorDate: Wed May 13 00:33:41 2026 -0600

    feat(dataframe): add select/filter/count/show methods (#19)
---
 native/src/lib.rs                                  |  99 ++++++++++-
 src/main/java/org/apache/datafusion/DataFrame.java |  56 ++++++
 .../datafusion/DataFrameTransformationsTest.java   | 196 +++++++++++++++++++++
 3 files changed, 349 insertions(+), 2 deletions(-)

diff --git a/native/src/lib.rs b/native/src/lib.rs
index cc58b7d..42ca5f5 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -26,8 +26,8 @@ use datafusion::arrow::record_batch::RecordBatchIterator;
 use datafusion::dataframe::DataFrame;
 use datafusion::error::DataFusionError;
 use datafusion::prelude::{ParquetReadOptions, SessionContext};
-use jni::objects::{JByteArray, JClass, JString};
-use jni::sys::{jboolean, jlong};
+use jni::objects::{JByteArray, JClass, JObjectArray, JString};
+use jni::sys::{jboolean, jint, jlong};
 use jni::JNIEnv;
 use tokio::runtime::Runtime;
 
@@ -98,6 +98,101 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo
     })
 }
 
+/// JNI entry point backing `DataFrame.countRows`: runs the plan behind
+/// `handle` and returns how many rows it produces.
+///
+/// A zero handle or a failed count is passed to `try_unwrap_or_throw`
+/// as an error, with `0` supplied as the fallback return value.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_countRows<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+) -> jlong {
+    try_unwrap_or_throw(&mut env, 0, |_env| -> JniResult<jlong> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        // SAFETY (assumed): a non-zero handle is expected to be a pointer that
+        // came from `Box::into_raw` on a `DataFrame` and has not been freed yet.
+        let frame = unsafe { &*(handle as *const DataFrame) };
+        // `count` consumes its receiver, so run it on a clone and leave the
+        // DataFrame behind the handle usable for later calls.
+        let row_count = runtime().block_on(frame.clone().count())?;
+        Ok(row_count as jlong)
+    })
+}
+
+/// JNI entry point backing `DataFrame.showDataFrame`: runs the plan behind
+/// `handle` and prints the result through `DataFrame::show`.
+///
+/// A zero handle or a failure while showing is passed to
+/// `try_unwrap_or_throw` as an error; nothing is returned to Java.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_showDataFrame<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+) {
+    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        // SAFETY (assumed): a non-zero handle is expected to be a pointer that
+        // came from `Box::into_raw` on a `DataFrame` and has not been freed yet.
+        // Cloning keeps the DataFrame behind the handle usable afterwards.
+        let df = unsafe { &*(handle as *const DataFrame) }.clone();
+        runtime().block_on(async { df.show().await })?;
+        Ok(())
+    })
+}
+
+/// JNI entry point backing `DataFrame.showDataFrameWithLimit`: runs the plan
+/// behind `handle` and prints at most `limit` rows through
+/// `DataFrame::show_limit`.
+///
+/// A zero handle, a negative `limit`, or a failure while showing is passed to
+/// `try_unwrap_or_throw` as an error; nothing is returned to Java.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_showDataFrameWithLimit<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+    limit: jint,
+) {
+    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        // `jint` is signed: casting a negative value straight to `usize` would
+        // sign-extend into an enormous limit instead of failing, so reject it.
+        if limit < 0 {
+            return Err("limit must be non-negative".into());
+        }
+        // SAFETY (assumed): a non-zero handle is expected to be a pointer that
+        // came from `Box::into_raw` on a `DataFrame` and has not been freed yet.
+        let df = unsafe { &*(handle as *const DataFrame) }.clone();
+        // The cast is lossless here because `limit` was checked to be >= 0.
+        runtime().block_on(async { df.show_limit(limit as usize).await })?;
+        Ok(())
+    })
+}
+
+/// JNI entry point backing `DataFrame.selectColumns`: projects the DataFrame
+/// behind `handle` onto the columns named in `column_names`.
+///
+/// Returns a pointer to a newly boxed `DataFrame`, encoded as a `jlong`. The
+/// DataFrame behind `handle` is only cloned, so it stays usable. A zero
+/// handle, a JNI failure while reading the names, or a failed projection is
+/// passed to `try_unwrap_or_throw` as an error, with `0` as the fallback
+/// return value.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_selectColumns<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+    column_names: JObjectArray<'local>,
+) -> jlong {
+    try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        // SAFETY (assumed): a non-zero handle is expected to be a pointer that
+        // came from `Box::into_raw` on a `DataFrame` and has not been freed yet.
+        let df = unsafe { &*(handle as *const DataFrame) }.clone();
+
+        // Copy every Java string into Rust-owned storage first; the `&str`
+        // slice handed to `select_columns` borrows from `owned`.
+        let len = env.get_array_length(&column_names)?;
+        let mut owned: Vec<String> = Vec::with_capacity(len as usize);
+        for i in 0..len {
+            let elem = env.get_object_array_element(&column_names, i)?;
+            let jstr: JString = elem.into();
+            owned.push(env.get_string(&jstr)?.into());
+        }
+        let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
+
+        let new_df = df.select_columns(&refs)?;
+        // Ownership of the boxed result moves to the caller as a raw pointer.
+        Ok(Box::into_raw(Box::new(new_df)) as jlong)
+    })
+}
+
+/// JNI entry point backing `DataFrame.filterRows`: parses `predicate` as a SQL
+/// expression against the DataFrame behind `handle` and applies it as a filter.
+///
+/// Returns a pointer to a newly boxed `DataFrame`, encoded as a `jlong`. The
+/// DataFrame behind `handle` is only cloned, so it stays usable. A zero
+/// handle, an unreadable string, a parse failure, or a failed filter is passed
+/// to `try_unwrap_or_throw` as an error, with `0` as the fallback return value.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_DataFrame_filterRows<'local>(
+    mut env: JNIEnv<'local>,
+    _class: JClass<'local>,
+    handle: jlong,
+    predicate: JString<'local>,
+) -> jlong {
+    try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
+        if handle == 0 {
+            return Err("DataFrame handle is null".into());
+        }
+        // SAFETY (assumed): a non-zero handle is expected to be a pointer that
+        // came from `Box::into_raw` on a `DataFrame` and has not been freed yet.
+        let source = unsafe { &*(handle as *const DataFrame) }.clone();
+        let predicate_sql: String = env.get_string(&predicate)?.into();
+        let condition = source.parse_sql_expr(&predicate_sql)?;
+        let filtered = Box::new(source.filter(condition)?);
+        // Ownership of the boxed result moves to the caller as a raw pointer.
+        Ok(Box::into_raw(filtered) as jlong)
+    })
+}
+
 #[no_mangle]
 pub extern "system" fn 
Java_org_apache_datafusion_DataFrame_closeDataFrame<'local>(
     mut env: JNIEnv<'local>,
diff --git a/src/main/java/org/apache/datafusion/DataFrame.java b/src/main/java/org/apache/datafusion/DataFrame.java
index f285879..0bd77e7 100644
--- a/src/main/java/org/apache/datafusion/DataFrame.java
+++ b/src/main/java/org/apache/datafusion/DataFrame.java
@@ -70,6 +70,52 @@ public final class DataFrame implements AutoCloseable {
     }
   }
 
+  /**
+   * Execute the plan and return the number of rows.
+   *
+   * @return the row count reported by the native layer
+   * @throws IllegalStateException if this DataFrame is closed or already collected
+   */
+  public long count() {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    return countRows(nativeHandle);
+  }
+
+  /**
+   * Execute the plan and print formatted batches to native stdout.
+   *
+   * @throws IllegalStateException if this DataFrame is closed or already collected
+   */
+  public void show() {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    showDataFrame(nativeHandle);
+  }
+
+  /**
+   * Execute the plan and print the first {@code limit} rows to native stdout.
+   *
+   * @param limit maximum number of rows to print; must not be negative
+   * @throws IllegalStateException if this DataFrame is closed or already collected
+   * @throws IllegalArgumentException if {@code limit} is negative
+   */
+  public void show(int limit) {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    // The native side converts the limit to an unsigned size, where a negative
+    // value would become an enormous limit rather than an error.
+    if (limit < 0) {
+      throw new IllegalArgumentException("limit must be non-negative: " + limit);
+    }
+    showDataFrameWithLimit(nativeHandle, limit);
+  }
+
+  /**
+   * Project the listed columns into a new DataFrame. The receiver remains usable and must still be
+   * closed independently.
+   *
+   * @param columnNames names of the columns to keep, in the order they should appear
+   * @return a new DataFrame wrapping the handle returned by the native projection
+   * @throws IllegalStateException if this DataFrame is closed or already collected
+   */
+  public DataFrame select(String... columnNames) {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    return new DataFrame(selectColumns(nativeHandle, columnNames));
+  }
+
+  /**
+   * Apply a SQL predicate to produce a filtered DataFrame. The predicate is parsed against this
+   * DataFrame's own schema. The receiver remains usable and must still be closed independently.
+   *
+   * @param predicate SQL boolean expression, for example {@code "x > 2"}
+   * @return a new DataFrame wrapping the handle returned by the native filter
+   * @throws IllegalStateException if this DataFrame is closed or already collected
+   */
+  public DataFrame filter(String predicate) {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("DataFrame is closed or already collected");
+    }
+    return new DataFrame(filterRows(nativeHandle, predicate));
+  }
+
   @Override
   public void close() {
     if (nativeHandle != 0) {
@@ -81,4 +127,14 @@ public final class DataFrame implements AutoCloseable {
   private static native void collectDataFrame(long handle, long ffiStreamAddr);
 
   private static native void closeDataFrame(long handle);
+
+  private static native long countRows(long handle); // used by count()
+
+  private static native void showDataFrame(long handle); // used by show()
+
+  private static native void showDataFrameWithLimit(long handle, int limit); // used by show(int)
+
+  private static native long selectColumns(long handle, String[] columnNames); // used by select()
+
+  private static native long filterRows(long handle, String predicate); // used by filter()
 }
diff --git a/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java b/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
new file mode 100644
index 0000000..09c7912
--- /dev/null
+++ b/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+
+class DataFrameTransformationsTest {
+  @Test
+  void countReturnsRowCount() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS 
t(x)")) {
+      assertEquals(3L, df.count());
+    }
+  }
+
+  @Test
+  void showDoesNotThrow() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2)) AS t(x)")) {
+      df.show();
+    }
+  }
+
+  @Test
+  void showWithLimitDoesNotThrow() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS 
t(x)")) {
+      df.show(0);
+      df.show(1);
+      df.show(1_000_000);
+    }
+  }
+
+  @Test
+  void selectProjectsAndReordersColumns() throws Exception {
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext();
+        DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b, 3 AS c");
+        DataFrame projected = source.select("b", "a");
+        ArrowReader reader = projected.collect(allocator)) {
+      assertTrue(reader.loadNextBatch());
+      VectorSchemaRoot root = reader.getVectorSchemaRoot();
+      assertEquals(1, root.getRowCount());
+      assertArrayEquals(
+          new String[] {"b", "a"},
+          root.getSchema().getFields().stream().map(f -> 
f.getName()).toArray(String[]::new));
+    }
+  }
+
+  @Test
+  void selectIsNonDestructive() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b")) {
+      try (DataFrame first = source.select("a")) {
+        assertEquals(1L, first.count());
+      }
+      try (DataFrame second = source.select("b")) {
+        assertEquals(1L, second.count());
+      }
+      assertEquals(1L, source.count());
+    }
+  }
+
+  @Test
+  void filterRemovesRows() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (2), (3), (4)) 
AS t(x)");
+        DataFrame filtered = source.filter("x > 2")) {
+      assertEquals(2L, filtered.count());
+    }
+  }
+
+  @Test
+  void filterIsNonDestructive() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (2), (3), (4)) 
AS t(x)")) {
+      try (DataFrame filtered = source.filter("x > 2")) {
+        assertEquals(2L, filtered.count());
+      }
+      assertEquals(4L, source.count());
+    }
+  }
+
+  @Test
+  void chainFilterSelectCount() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b UNION ALL SELECT 10 
AS a, 20 AS b");
+        DataFrame chained = source.filter("a > 5").select("b")) {
+      assertEquals(1L, chained.count());
+    }
+  }
+
+  @Test
+  void methodsThrowAfterClose() {
+    try (SessionContext ctx = new SessionContext()) {
+      DataFrame df = ctx.sql("SELECT 1 AS x");
+      df.close();
+      assertThrows(IllegalStateException.class, () -> df.select("x"));
+      assertThrows(IllegalStateException.class, () -> df.filter("x > 0"));
+      assertThrows(IllegalStateException.class, df::count);
+      assertThrows(IllegalStateException.class, df::show);
+      assertThrows(IllegalStateException.class, () -> df.show(5));
+    }
+  }
+
+  @Test
+  void methodsThrowAfterCollect() throws Exception {
+    try (BufferAllocator allocator = new RootAllocator();
+        SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT 1 AS x")) {
+      try (ArrowReader reader = df.collect(allocator)) {
+        assertTrue(reader.loadNextBatch());
+      }
+      assertThrows(IllegalStateException.class, () -> df.select("x"));
+      assertThrows(IllegalStateException.class, () -> df.filter("x > 0"));
+      assertThrows(IllegalStateException.class, df::count);
+      assertThrows(IllegalStateException.class, df::show);
+      assertThrows(IllegalStateException.class, () -> df.show(5));
+    }
+  }
+
+  @Test
+  void selectInvalidColumnThrows() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT 1 AS x")) {
+      assertThrows(RuntimeException.class, () -> df.select("not_a_column"));
+    }
+  }
+
+  @Test
+  void filterMalformedPredicateThrows() {
+    try (SessionContext ctx = new SessionContext();
+        DataFrame df = ctx.sql("SELECT 1 AS x")) {
+      assertThrows(RuntimeException.class, () -> df.filter("this is not sql"));
+    }
+  }
+
+  @Test
+  void lineitemFilterCountAgainstSqlBaseline() throws Exception {
+    Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet");
+    Assumptions.assumeTrue(
+        Files.exists(lineitem), "TPC-H SF1 data not found; run `make 
tpch-data` first");
+
+    try (SessionContext ctx = new SessionContext()) {
+      ctx.registerParquet("lineitem", lineitem.toAbsolutePath().toString());
+
+      long viaDataFrame;
+      try (DataFrame df = ctx.sql("SELECT * FROM lineitem");
+          DataFrame filtered = df.filter("l_orderkey < 100")) {
+        viaDataFrame = filtered.count();
+      }
+
+      long viaSql;
+      try (BufferAllocator allocator = new RootAllocator();
+          DataFrame df = ctx.sql("SELECT COUNT(*) FROM lineitem WHERE 
l_orderkey < 100");
+          ArrowReader reader = df.collect(allocator)) {
+        assertTrue(reader.loadNextBatch());
+        VectorSchemaRoot root = reader.getVectorSchemaRoot();
+        viaSql = ((BigIntVector) root.getVector(0)).get(0);
+      }
+
+      assertEquals(viaSql, viaDataFrame);
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to