This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 13de145fc [spark][core] Support input_file_name UDF (#3094)
13de145fc is described below

commit 13de145fc92bb63f8055ef190d3044d26b144591
Author: Yann Byron <[email protected]>
AuthorDate: Tue Mar 26 17:32:39 2024 +0800

    [spark][core] Support input_file_name UDF (#3094)
---
 .../paimon/data/columnar/ColumnarRowIterator.java  |  17 ++-
 ...sitionIterator.java => FileRecordIterator.java} |  28 +++-
 .../org/apache/paimon/reader/RecordReader.java     |   4 +-
 .../deletionvectors/ApplyDeletionVectorReader.java |   6 +-
 .../org/apache/paimon/mergetree/LookupLevels.java  |   7 +-
 .../apache/paimon/format/orc/OrcReaderFactory.java |  15 +-
 .../format/parquet/ParquetReaderFactory.java       |  23 ++-
 .../paimon/spark/PaimonPartitionReader.scala       |   2 +-
 .../paimon/spark/PaimonRecordReaderIterator.scala  | 105 ++++++++++++++
 .../main/scala/org/apache/spark/sql/Utils.scala    |   9 ++
 .../apache/paimon/spark/sql/PaimonQueryTest.scala  | 161 +++++++++++++++++++++
 .../apache/paimon/spark/sql/WithTableOptions.scala |   2 +
 12 files changed, 345 insertions(+), 34 deletions(-)

diff --git a/paimon-common/src/main/java/org/apache/paimon/data/columnar/ColumnarRowIterator.java b/paimon-common/src/main/java/org/apache/paimon/data/columnar/ColumnarRowIterator.java
index 6de861af0..13d706cf6 100644
--- a/paimon-common/src/main/java/org/apache/paimon/data/columnar/ColumnarRowIterator.java
+++ b/paimon-common/src/main/java/org/apache/paimon/data/columnar/ColumnarRowIterator.java
@@ -20,8 +20,9 @@ package org.apache.paimon.data.columnar;
 
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.data.PartitionInfo;
+import org.apache.paimon.fs.Path;
+import org.apache.paimon.reader.FileRecordIterator;
 import org.apache.paimon.reader.RecordReader;
-import org.apache.paimon.reader.RecordWithPositionIterator;
 import org.apache.paimon.utils.RecyclableIterator;
 import org.apache.paimon.utils.VectorMappingUtils;
 
@@ -32,8 +33,9 @@ import javax.annotation.Nullable;
  * {@link ColumnarRow#setRowId}.
  */
 public class ColumnarRowIterator extends RecyclableIterator<InternalRow>
-        implements RecordWithPositionIterator<InternalRow> {
+        implements FileRecordIterator<InternalRow> {
 
+    private final Path filePath;
     private final ColumnarRow rowData;
     private final Runnable recycler;
 
@@ -41,8 +43,9 @@ public class ColumnarRowIterator extends RecyclableIterator<InternalRow>
     private int nextPos;
     private long nextGlobalPos;
 
-    public ColumnarRowIterator(ColumnarRow rowData, @Nullable Runnable 
recycler) {
+    public ColumnarRowIterator(Path filePath, ColumnarRow rowData, @Nullable 
Runnable recycler) {
         super(recycler);
+        this.filePath = filePath;
         this.rowData = rowData;
         this.recycler = recycler;
     }
@@ -74,8 +77,14 @@ public class ColumnarRowIterator extends RecyclableIterator<InternalRow>
         return nextGlobalPos - 1;
     }
 
+    @Override
+    public Path filePath() {
+        return this.filePath;
+    }
+
     public ColumnarRowIterator copy(ColumnVector[] vectors) {
-        ColumnarRowIterator newIterator = new 
ColumnarRowIterator(rowData.copy(vectors), recycler);
+        ColumnarRowIterator newIterator =
+                new ColumnarRowIterator(filePath, rowData.copy(vectors), 
recycler);
         newIterator.reset(num, nextGlobalPos);
         return newIterator;
     }
diff --git a/paimon-common/src/main/java/org/apache/paimon/reader/RecordWithPositionIterator.java b/paimon-common/src/main/java/org/apache/paimon/reader/FileRecordIterator.java
similarity index 77%
rename from paimon-common/src/main/java/org/apache/paimon/reader/RecordWithPositionIterator.java
rename to paimon-common/src/main/java/org/apache/paimon/reader/FileRecordIterator.java
index e4778413a..0cef8cc00 100644
--- a/paimon-common/src/main/java/org/apache/paimon/reader/RecordWithPositionIterator.java
+++ b/paimon-common/src/main/java/org/apache/paimon/reader/FileRecordIterator.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.reader;
 
+import org.apache.paimon.fs.Path;
 import org.apache.paimon.utils.Filter;
 
 import javax.annotation.Nullable;
@@ -30,7 +31,7 @@ import java.util.function.Function;
  *
  * @param <T> The type of the record.
  */
-public interface RecordWithPositionIterator<T> extends 
RecordReader.RecordIterator<T> {
+public interface FileRecordIterator<T> extends RecordReader.RecordIterator<T> {
 
     /**
      * Get the row position of the row returned by {@link 
RecordReader.RecordIterator#next}.
@@ -39,15 +40,23 @@ public interface RecordWithPositionIterator<T> extends RecordReader.RecordIterat
      */
     long returnedPosition();
 
+    /** @return the file path */
+    Path filePath();
+
     @Override
-    default <R> RecordWithPositionIterator<R> transform(Function<T, R> 
function) {
-        RecordWithPositionIterator<T> thisIterator = this;
-        return new RecordWithPositionIterator<R>() {
+    default <R> FileRecordIterator<R> transform(Function<T, R> function) {
+        FileRecordIterator<T> thisIterator = this;
+        return new FileRecordIterator<R>() {
             @Override
             public long returnedPosition() {
                 return thisIterator.returnedPosition();
             }
 
+            @Override
+            public Path filePath() {
+                return thisIterator.filePath();
+            }
+
             @Nullable
             @Override
             public R next() throws IOException {
@@ -66,14 +75,19 @@ public interface RecordWithPositionIterator<T> extends RecordReader.RecordIterat
     }
 
     @Override
-    default RecordWithPositionIterator<T> filter(Filter<T> filter) {
-        RecordWithPositionIterator<T> thisIterator = this;
-        return new RecordWithPositionIterator<T>() {
+    default FileRecordIterator<T> filter(Filter<T> filter) {
+        FileRecordIterator<T> thisIterator = this;
+        return new FileRecordIterator<T>() {
             @Override
             public long returnedPosition() {
                 return thisIterator.returnedPosition();
             }
 
+            @Override
+            public Path filePath() {
+                return thisIterator.filePath();
+            }
+
             @Nullable
             @Override
             public T next() throws IOException {
diff --git a/paimon-common/src/main/java/org/apache/paimon/reader/RecordReader.java b/paimon-common/src/main/java/org/apache/paimon/reader/RecordReader.java
index 276a85571..5c7482d9d 100644
--- a/paimon-common/src/main/java/org/apache/paimon/reader/RecordReader.java
+++ b/paimon-common/src/main/java/org/apache/paimon/reader/RecordReader.java
@@ -149,11 +149,11 @@ public interface RecordReader<T> extends Closeable {
      */
     default void forEachRemainingWithPosition(BiConsumer<Long, ? super T> 
action)
             throws IOException {
-        RecordWithPositionIterator<T> batch;
+        FileRecordIterator<T> batch;
         T record;
 
         try {
-            while ((batch = (RecordWithPositionIterator<T>) readBatch()) != 
null) {
+            while ((batch = (FileRecordIterator<T>) readBatch()) != null) {
                 while ((record = batch.next()) != null) {
                     action.accept(batch.returnedPosition(), record);
                 }
diff --git a/paimon-core/src/main/java/org/apache/paimon/deletionvectors/ApplyDeletionVectorReader.java b/paimon-core/src/main/java/org/apache/paimon/deletionvectors/ApplyDeletionVectorReader.java
index 3bba07506..dadde99ea 100644
--- a/paimon-core/src/main/java/org/apache/paimon/deletionvectors/ApplyDeletionVectorReader.java
+++ b/paimon-core/src/main/java/org/apache/paimon/deletionvectors/ApplyDeletionVectorReader.java
@@ -18,8 +18,8 @@
 
 package org.apache.paimon.deletionvectors;
 
+import org.apache.paimon.reader.FileRecordIterator;
 import org.apache.paimon.reader.RecordReader;
-import org.apache.paimon.reader.RecordWithPositionIterator;
 
 import javax.annotation.Nullable;
 
@@ -62,10 +62,10 @@ public class ApplyDeletionVectorReader<T> implements RecordReader<T> {
         }
 
         checkArgument(
-                batch instanceof RecordWithPositionIterator,
+                batch instanceof FileRecordIterator,
                 "There is a bug, RecordIterator in ApplyDeletionVectorReader 
must be RecordWithPositionIterator");
 
-        RecordWithPositionIterator<T> batchWithPosition = 
(RecordWithPositionIterator<T>) batch;
+        FileRecordIterator<T> batchWithPosition = (FileRecordIterator<T>) 
batch;
 
         return batchWithPosition.filter(
                 a -> 
!deletionVector.isDeleted(batchWithPosition.returnedPosition()));
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/LookupLevels.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/LookupLevels.java
index dd45e7fc1..d6024ebb8 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/LookupLevels.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/LookupLevels.java
@@ -28,8 +28,8 @@ import org.apache.paimon.lookup.LookupStoreReader;
 import org.apache.paimon.lookup.LookupStoreWriter;
 import org.apache.paimon.memory.MemorySegment;
 import org.apache.paimon.options.MemorySize;
+import org.apache.paimon.reader.FileRecordIterator;
 import org.apache.paimon.reader.RecordReader;
-import org.apache.paimon.reader.RecordWithPositionIterator;
 import org.apache.paimon.types.RowKind;
 import org.apache.paimon.types.RowType;
 import org.apache.paimon.utils.BloomFilter;
@@ -176,9 +176,8 @@ public class LookupLevels<T> implements Levels.DropFileCallback, Closeable {
         try (RecordReader<KeyValue> reader = fileReaderFactory.apply(file)) {
             KeyValue kv;
             if (valueProcessor.withPosition()) {
-                RecordWithPositionIterator<KeyValue> batch;
-                while ((batch = (RecordWithPositionIterator<KeyValue>) 
reader.readBatch())
-                        != null) {
+                FileRecordIterator<KeyValue> batch;
+                while ((batch = (FileRecordIterator<KeyValue>) 
reader.readBatch()) != null) {
                     while ((kv = batch.next()) != null) {
                         byte[] keyBytes = 
keySerializer.serializeToBytes(kv.key());
                         byte[] valueBytes =
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/orc/OrcReaderFactory.java b/paimon-format/src/main/java/org/apache/paimon/format/orc/OrcReaderFactory.java
index 55cff9298..8cf95fad3 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/orc/OrcReaderFactory.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/orc/OrcReaderFactory.java
@@ -28,6 +28,7 @@ import org.apache.paimon.format.OrcFormatReaderContext;
 import org.apache.paimon.format.fs.HadoopReadOnlyFileSystem;
 import org.apache.paimon.format.orc.filter.OrcFilters;
 import org.apache.paimon.fs.FileIO;
+import org.apache.paimon.fs.Path;
 import org.apache.paimon.reader.RecordReader.RecordIterator;
 import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.RowType;
@@ -94,7 +95,7 @@ public class OrcReaderFactory implements FormatReaderFactory {
                 context instanceof OrcFormatReaderContext
                         ? ((OrcFormatReaderContext) context).poolSize()
                         : 1;
-        Pool<OrcReaderBatch> poolOfBatches = createPoolOfBatches(poolSize);
+        Pool<OrcReaderBatch> poolOfBatches = 
createPoolOfBatches(context.filePath(), poolSize);
 
         RecordReader orcReader =
                 createRecordReader(
@@ -114,7 +115,7 @@ public class OrcReaderFactory implements FormatReaderFactory {
      * conversion from the ORC representation to the result format.
      */
     public OrcReaderBatch createReaderBatch(
-            VectorizedRowBatch orcBatch, Pool.Recycler<OrcReaderBatch> 
recycler) {
+            Path filePath, VectorizedRowBatch orcBatch, 
Pool.Recycler<OrcReaderBatch> recycler) {
         List<String> tableFieldNames = tableType.getFieldNames();
         List<DataType> tableFieldTypes = tableType.getFieldTypes();
 
@@ -125,17 +126,17 @@ public class OrcReaderFactory implements FormatReaderFactory {
             DataType type = tableFieldTypes.get(i);
             vectors[i] = 
createPaimonVector(orcBatch.cols[tableFieldNames.indexOf(name)], type);
         }
-        return new OrcReaderBatch(orcBatch, new 
VectorizedColumnBatch(vectors), recycler);
+        return new OrcReaderBatch(filePath, orcBatch, new 
VectorizedColumnBatch(vectors), recycler);
     }
 
     // ------------------------------------------------------------------------
 
-    private Pool<OrcReaderBatch> createPoolOfBatches(int numBatches) {
+    private Pool<OrcReaderBatch> createPoolOfBatches(Path filePath, int 
numBatches) {
         final Pool<OrcReaderBatch> pool = new Pool<>(numBatches);
 
         for (int i = 0; i < numBatches; i++) {
             final VectorizedRowBatch orcBatch = createBatchWrapper(schema, 
batchSize / numBatches);
-            final OrcReaderBatch batch = createReaderBatch(orcBatch, 
pool.recycler());
+            final OrcReaderBatch batch = createReaderBatch(filePath, orcBatch, 
pool.recycler());
             pool.add(batch);
         }
 
@@ -153,6 +154,7 @@ public class OrcReaderFactory implements FormatReaderFactory {
         private final ColumnarRowIterator result;
 
         protected OrcReaderBatch(
+                final Path filePath,
                 final VectorizedRowBatch orcVectorizedRowBatch,
                 final VectorizedColumnBatch paimonColumnBatch,
                 final Pool.Recycler<OrcReaderBatch> recycler) {
@@ -160,7 +162,8 @@ public class OrcReaderFactory implements FormatReaderFactory {
             this.recycler = checkNotNull(recycler);
             this.paimonColumnBatch = paimonColumnBatch;
             this.result =
-                    new ColumnarRowIterator(new 
ColumnarRow(paimonColumnBatch), this::recycle);
+                    new ColumnarRowIterator(
+                            filePath, new ColumnarRow(paimonColumnBatch), 
this::recycle);
         }
 
         /**
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
index ed778c0bf..004a0d655 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
@@ -28,6 +28,7 @@ import org.apache.paimon.format.FormatReaderFactory;
 import org.apache.paimon.format.parquet.reader.ColumnReader;
 import org.apache.paimon.format.parquet.reader.ParquetDecimalVector;
 import org.apache.paimon.format.parquet.reader.ParquetTimestampVector;
+import org.apache.paimon.fs.Path;
 import org.apache.paimon.options.Options;
 import org.apache.paimon.reader.RecordReader;
 import org.apache.paimon.reader.RecordReader.RecordIterator;
@@ -100,7 +101,8 @@ public class ParquetReaderFactory implements FormatReaderFactory {
 
         checkSchema(fileSchema, requestedSchema);
 
-        Pool<ParquetReaderBatch> poolOfBatches = 
createPoolOfBatches(requestedSchema);
+        Pool<ParquetReaderBatch> poolOfBatches =
+                createPoolOfBatches(context.filePath(), requestedSchema);
 
         return new ParquetReader(reader, requestedSchema, 
reader.getRecordCount(), poolOfBatches);
     }
@@ -174,21 +176,24 @@ public class ParquetReaderFactory implements FormatReaderFactory {
         }
     }
 
-    private Pool<ParquetReaderBatch> createPoolOfBatches(MessageType 
requestedSchema) {
+    private Pool<ParquetReaderBatch> createPoolOfBatches(
+            Path filePath, MessageType requestedSchema) {
         // In a VectorizedColumnBatch, the dictionary will be lazied 
deserialized.
         // If there are multiple batches at the same time, there may be thread 
safety problems,
         // because the deserialization of the dictionary depends on some 
internal structures.
         // We need set poolCapacity to 1.
         Pool<ParquetReaderBatch> pool = new Pool<>(1);
-        pool.add(createReaderBatch(requestedSchema, pool.recycler()));
+        pool.add(createReaderBatch(filePath, requestedSchema, 
pool.recycler()));
         return pool;
     }
 
     private ParquetReaderBatch createReaderBatch(
-            MessageType requestedSchema, Pool.Recycler<ParquetReaderBatch> 
recycler) {
+            Path filePath,
+            MessageType requestedSchema,
+            Pool.Recycler<ParquetReaderBatch> recycler) {
         WritableColumnVector[] writableVectors = 
createWritableVectors(requestedSchema);
         VectorizedColumnBatch columnarBatch = 
createVectorizedColumnBatch(writableVectors);
-        return createReaderBatch(writableVectors, columnarBatch, recycler);
+        return createReaderBatch(filePath, writableVectors, columnarBatch, 
recycler);
     }
 
     private WritableColumnVector[] createWritableVectors(MessageType 
requestedSchema) {
@@ -361,10 +366,11 @@ public class ParquetReaderFactory implements FormatReaderFactory {
     }
 
     private ParquetReaderBatch createReaderBatch(
+            Path filePath,
             WritableColumnVector[] writableVectors,
             VectorizedColumnBatch columnarBatch,
             Pool.Recycler<ParquetReaderBatch> recycler) {
-        return new ParquetReaderBatch(writableVectors, columnarBatch, 
recycler);
+        return new ParquetReaderBatch(filePath, writableVectors, 
columnarBatch, recycler);
     }
 
     private static class ParquetReaderBatch {
@@ -376,13 +382,16 @@ public class ParquetReaderFactory implements FormatReaderFactory {
         private final ColumnarRowIterator result;
 
         protected ParquetReaderBatch(
+                Path filePath,
                 WritableColumnVector[] writableVectors,
                 VectorizedColumnBatch columnarBatch,
                 Pool.Recycler<ParquetReaderBatch> recycler) {
             this.writableVectors = writableVectors;
             this.columnarBatch = columnarBatch;
             this.recycler = recycler;
-            this.result = new ColumnarRowIterator(new 
ColumnarRow(columnarBatch), this::recycle);
+            this.result =
+                    new ColumnarRowIterator(
+                            filePath, new ColumnarRow(columnarBatch), 
this::recycle);
         }
 
         public void recycle() {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala
index cfb8803b4..c4e694814 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonPartitionReader.scala
@@ -40,7 +40,7 @@ case class PaimonPartitionReader(
 
   private lazy val iterator = {
     val reader = readFunc(split)
-    new RecordReaderIterator[PaimonInternalRow](reader)
+    PaimonRecordReaderIterator(reader)
   }
 
   override def next(): Boolean = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
new file mode 100644
index 000000000..3debb5e18
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.data.{InternalRow => PaimonInternalRow}
+import org.apache.paimon.fs.Path
+import org.apache.paimon.reader.{FileRecordIterator, RecordReader}
+import org.apache.paimon.utils.CloseableIterator
+
+import org.apache.spark.sql.Utils
+
+import java.io.IOException
+
+case class PaimonRecordReaderIterator(reader: RecordReader[PaimonInternalRow])
+  extends CloseableIterator[PaimonInternalRow] {
+  // Iterates `reader` row by row and publishes each batch's data file as Spark's input file name.
+  private var lastFilePath: Path = _
+  private var currentIterator: RecordReader.RecordIterator[PaimonInternalRow] = readBatch()
+  private var advanced = false
+  private var currentResult: PaimonInternalRow = _
+
+  override def hasNext: Boolean = {
+    if (currentIterator == null) {
+      false
+    } else {
+      advanceIfNeeded()
+      currentResult != null
+    }
+  }
+
+  override def next(): PaimonInternalRow = {
+    if (!hasNext) {
+      null
+    } else {
+      advanced = false
+      currentResult
+    }
+  }
+
+  override def close(): Unit =
+    try {
+      if (currentIterator != null) {
+        currentIterator.releaseBatch()
+        currentIterator = null // a repeated close() must not release this batch twice
+      }
+      currentResult = null
+    } finally {
+      reader.close()
+      Utils.unsetInputFileName()
+    }
+
+  private def readBatch(): RecordReader.RecordIterator[PaimonInternalRow] = {
+    val iter = reader.readBatch()
+    iter match {
+      case fileRecordIterator: FileRecordIterator[_] =>
+        if (lastFilePath != fileRecordIterator.filePath()) { // only update on a new file
+          Utils.setInputFileName(fileRecordIterator.filePath().toUri.toString)
+          lastFilePath = fileRecordIterator.filePath()
+        }
+      case _ =>
+    }
+    iter
+  }
+
+  private def advanceIfNeeded(): Unit = {
+    if (!advanced) {
+      advanced = true
+      try {
+        var stop = false
+        while (!stop) {
+          currentResult = currentIterator.next
+          if (currentResult != null) {
+            stop = true
+          } else {
+            currentIterator.releaseBatch()
+            currentIterator = null // cleared first so close() skips it if readBatch() throws
+            currentIterator = readBatch()
+            if (currentIterator == null) {
+              stop = true
+            }
+          }
+        }
+      } catch {
+        case e: IOException =>
+          throw new RuntimeException(e)
+      }
+    }
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
index 8f7e5aaf7..4767dab39 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/Utils.scala
@@ -18,6 +18,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.expressions.{FieldReference, 
NamedReference}
@@ -69,4 +70,12 @@ object Utils {
   def bytesToString(size: Long): String = {
     SparkUtils.bytesToString(size)
   }
+
+  def setInputFileName(inputFileName: String): Unit = {
+    InputFileBlockHolder.set(inputFileName, 0, -1)
+  }
+
+  def unsetInputFileName(): Unit = {
+    InputFileBlockHolder.unset()
+  }
 }
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala
new file mode 100644
index 000000000..ef683366d
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.table.source.DataSplit
+
+import org.apache.spark.sql.{Row, SparkSession}
+import org.junit.jupiter.api.Assertions
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+class PaimonQueryTest extends PaimonSparkTestBase {
+  // Checks that input_file_name() yields the table's data file paths for Paimon scans.
+  fileFormats.foreach {
+    fileFormat =>
+      bucketModes.foreach {
+        bucketMode =>
+          test(s"Query input_file_name(): file.format=$fileFormat, bucket=$bucketMode") {
+            val _spark: SparkSession = spark
+            import _spark.implicits._
+
+            withTable("T") {
+              spark.sql(s"""
+                           |CREATE TABLE T (id INT, name STRING)
+                           |TBLPROPERTIES ('file.format'='$fileFormat', 'bucket'='$bucketMode')
+                           |""".stripMargin)
+
+              val location = loadTable("T").location().toUri.toString
+
+              spark.sql("INSERT INTO T VALUES (1, 'x1'), (3, 'x3')")
+
+              val res1 = spark.sql(s"""
+                                      |SELECT *,
+                                      |startswith(input_file_name(), '$location') AS start,
+                                      |endswith(input_file_name(), '.$fileFormat') AS end
+                                      |FROM T
+                                      |ORDER BY id
+                                      |""".stripMargin)
+              checkAnswer(res1, Row(1, "x1", true, true) :: Row(3, "x3", true, true) :: Nil)
+
+              spark.sql("INSERT INTO T VALUES (2, 'x2'), (4, 'x4'), (6, 'x6')")
+
+              val res2 =
+                spark.sql("SELECT input_file_name() FROM T").distinct().as[String].collect().sorted
+              val allDataFiles = getAllFiles("T", Seq.empty, null)
+              Assertions.assertEquals(allDataFiles.toSeq, res2.toSeq)
+            }
+          }
+
+      }
+  }
+
+  fileFormats.foreach {
+    fileFormat =>
+      bucketModes.foreach {
+        bucketMode =>
+          test(
+            s"Query input_file_name() for partitioned table: file.format=$fileFormat, bucket=$bucketMode") {
+            val _spark: SparkSession = spark
+            import _spark.implicits._
+
+            withTable("T") {
+              spark.sql(s"""
+                           |CREATE TABLE T (id INT, name STRING, pt STRING)
+                           |PARTITIONED BY (pt)
+                           |TBLPROPERTIES ('file.format'='$fileFormat', 'bucket'='$bucketMode')
+                           |""".stripMargin)
+
+              val location = loadTable("T").location().toUri.toString
+
+              spark.sql("INSERT INTO T VALUES (1, 'x1', '2024'), (3, 'x3', '2024')")
+
+              val res1 = spark.sql(s"""
+                                      |SELECT id, name, pt,
+                                      |startswith(input_file_name(), '$location') AS start,
+                                      |endswith(input_file_name(), '.$fileFormat') AS end
+                                      |FROM T
+                                      |ORDER BY id
+                                      |""".stripMargin)
+              checkAnswer(
+                res1,
+                Row(1, "x1", "2024", true, true) :: Row(3, "x3", "2024", true, true) :: Nil)
+
+              spark.sql("""
+                          |INSERT INTO T
+                          |VALUES (2, 'x2', '2025'), (4, 'x4', '2025'), (6, 'x6', '2026')
+                          |""".stripMargin)
+
+              val res2 =
+                spark
+                  .sql("SELECT input_file_name() FROM T WHERE pt='2026'")
+                  .distinct()
+                  .as[String]
+                  .collect()
+                  .sorted
+              val partitionFilter = new util.HashMap[String, String]()
+              partitionFilter.put("pt", "2026")
+              val partialDataFiles = getAllFiles("T", Seq("pt"), partitionFilter)
+              Assertions.assertEquals(partialDataFiles.toSeq, res2.toSeq)
+
+              val res3 =
+                spark.sql("SELECT input_file_name() FROM T").distinct().as[String].collect().sorted
+              val allDataFiles = getAllFiles("T", Seq("pt"), null)
+              Assertions.assertEquals(allDataFiles.toSeq, res3.toSeq)
+            }
+          }
+
+      }
+  }
+
+  private def getAllFiles(
+      tableName: String,
+      partitions: Seq[String],
+      partitionFilter: java.util.Map[String, String]): Array[String] = {
+    val paimonTable = loadTable(tableName)
+    val location = paimonTable.location()
+
+    val files = paimonTable
+      .newSnapshotReader()
+      .withPartitionFilter(partitionFilter) // callers pass null to list every partition
+      .read()
+      .splits()
+      .asScala
+      .collect { case ds: DataSplit => ds }
+      .flatMap {
+        ds =>
+          val prefix = if (partitions.isEmpty) {
+            s"$location/bucket-${ds.bucket}"
+          } else {
+            val partitionPath = partitions.zipWithIndex
+              .map {
+                case (pt, index) =>
+                  s"$pt=" + ds.partition().getString(index)
+              }
+              .mkString("/")
+            s"$location/$partitionPath/bucket-${ds.bucket}"
+          }
+          ds.dataFiles().asScala.map(f => prefix + "/" + f.fileName)
+      }
+    files.sorted.toArray
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
index 5b1c65525..e390058ba 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
@@ -23,6 +23,8 @@ trait WithTableOptions {
   // 3: fixed bucket, -1: dynamic bucket
   protected val bucketModes: Seq[Int] = Seq(3, -1)
 
+  protected val fileFormats: Seq[String] = Seq("orc", "parquet")
+
   protected val withPk: Seq[Boolean] = Seq(true, false)
 
 }

Reply via email to