allisonwang-db commented on code in PR #52317: URL: https://github.com/apache/spark/pull/52317#discussion_r2353560389
########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ########## @@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]: expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int") assertDataFrameEqual(sql_result_df2, expected_df2) + def test_arrow_udtf_with_partition_by(self): + @arrow_udtf(returnType="partition_key int, sum_value int") + class SumUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + table = pa.table(table_data) + partition_key = pc.unique(table["partition_key"]).to_pylist() + assert ( + len(partition_key) == 1 + ), f"Expected exactly one partition key, got {partition_key}" + sum_value = pc.sum(table["value"]).as_py() + result_table = pa.table( + { + "partition_key": pa.array([partition_key[0]], type=pa.int32()), + "sum_value": pa.array([sum_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 10), + (2, 5), + (1, 20), + (2, 15), + (1, 30), + (3, 100), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("sum_udtf", SumUDTF) + input_df.createOrReplaceTempView("test_data") + + result_df = self.spark.sql( + """ + SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key) + """ + ) + + expected_data = [ + (1, 60), + (2, 20), + (3, 100), + ] Review Comment: thanks for catching this! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org