Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22253#discussion_r213864459
  
    --- Diff: sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java ---
    @@ -0,0 +1,80 @@
    +package test.org.apache.spark.sql;
    +
    +import org.apache.spark.api.java.function.FilterFunction;
    +import org.apache.spark.sql.Column;
    +import org.apache.spark.sql.Dataset;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.test.TestSparkSession;
    +import org.apache.spark.sql.types.StructType;
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import java.util.*;
    +
    +import static org.apache.spark.sql.types.DataTypes.*;
    +
    +public class JavaColumnExpressionSuite {
    +
    +    private transient TestSparkSession spark;
    +
    +    @Before
    +    public void setUp() {
    +        spark = new TestSparkSession();
    +    }
    +
    +    @After
    +    public void tearDown() {
    +        spark.stop();
    +        spark = null;
    +    }
    +
    +    @Test
    +    public void isInCollectionWorksCorrectlyOnJava() {
    +        List<Row> rows = Arrays.asList(
    +                RowFactory.create(1, "x"),
    +                RowFactory.create(2, "y"),
    +                RowFactory.create(3, "z")
    +        );
    +        StructType schema = createStructType(Arrays.asList(
    +                createStructField("a", IntegerType, false),
    +                createStructField("b", StringType, false)
    +        ));
    +        Dataset<Row> df = spark.createDataFrame(rows, schema);
    +        // Test with different types of collections
    +        Assert.assertTrue(Arrays.equals(
    +                (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(),
    +                (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    +        ));
    +        Assert.assertTrue(Arrays.equals(
    +                (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(),
    +                (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
    +        ));
    +        Assert.assertTrue(Arrays.equals(
    +                (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(),
    +                (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect()
    +        ));
    +    }
    +
    +    @Test
    +    public void isInCollectionThrowsExceptionWithCorrectMessageOnJava() {
    --- End diff --
    
    Can we shorten this method name? E.g., `checkExceptionMessage`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to