GitHub user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22295#discussion_r226166020
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -3863,6 +3863,145 @@ def test_jvm_default_session_already_set(self):
                 spark.stop()
     
     
    +class SparkSessionTests2(unittest.TestCase):
    +
    +    def test_active_session(self):
    +        spark = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        try:
    +            activeSession = SparkSession.getActiveSession()
    +            df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
    +            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
    +        finally:
    +            spark.stop()
    +
    +    def test_get_active_session_when_no_active_session(self):
    +        active = SparkSession.getActiveSession()
    +        self.assertEqual(active, None)
    +        spark = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        active = SparkSession.getActiveSession()
    +        self.assertEqual(active, spark)
    +        spark.stop()
    +        active = SparkSession.getActiveSession()
    +        self.assertEqual(active, None)
    +
    +    def test_SparkSession(self):
    +        spark = SparkSession.builder \
    +            .master("local") \
    +            .config("some-config", "v2") \
    +            .getOrCreate()
    +        try:
    +            self.assertEqual(spark.conf.get("some-config"), "v2")
    +            self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
    +            self.assertEqual(spark.version, spark.sparkContext.version)
    +            spark.sql("CREATE DATABASE test_db")
    +            spark.catalog.setCurrentDatabase("test_db")
    +            self.assertEqual(spark.catalog.currentDatabase(), "test_db")
    +            spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
    +            self.assertEqual(spark.table("table1").columns, ['name', 'age'])
    +            self.assertEqual(spark.range(3).count(), 3)
    +        finally:
    +            spark.stop()
    +
    +    def test_global_default_session(self):
    +        spark = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        try:
    +            self.assertEqual(SparkSession.builder.getOrCreate(), spark)
    +        finally:
    +            spark.stop()
    +
    +    def test_default_and_active_session(self):
    +        spark = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        activeSession = spark._jvm.SparkSession.getActiveSession()
    +        defaultSession = spark._jvm.SparkSession.getDefaultSession()
    +        try:
    +            self.assertEqual(activeSession, defaultSession)
    +        finally:
    +            spark.stop()
    +
    +    def test_config_option_propagated_to_existing_SparkSession(self):
    +        session1 = SparkSession.builder \
    +            .master("local") \
    +            .config("spark-config1", "a") \
    +            .getOrCreate()
    +        self.assertEqual(session1.conf.get("spark-config1"), "a")
    +        session2 = SparkSession.builder \
    +            .config("spark-config1", "b") \
    +            .getOrCreate()
    +        try:
    +            self.assertEqual(session1, session2)
    +            self.assertEqual(session1.conf.get("spark-config1"), "b")
    +        finally:
    +            session1.stop()
    +
    +    def test_new_session(self):
    +        session = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        newSession = session.newSession()
    +        try:
    +            self.assertNotEqual(session, newSession)
    +        finally:
    +            session.stop()
    +            newSession.stop()
    +
    +    def test_create_new_session_if_old_session_stopped(self):
    +        session = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        session.stop()
    +        newSession = SparkSession.builder \
    +            .master("local") \
    +            .getOrCreate()
    +        try:
    +            self.assertNotEqual(session, newSession)
    +        finally:
    +            newSession.stop()
    +
    +    def test_active_session_with_None_and_not_None_context(self):
    +        from pyspark.context import SparkContext
    +        from pyspark.conf import SparkConf
    +        sc = SparkContext._active_spark_context
    +        self.assertEqual(sc, None)
    +        activeSession = SparkSession.getActiveSession()
    +        self.assertEqual(activeSession, None)
    +        sparkConf = SparkConf()
    +        sc = SparkContext.getOrCreate(sparkConf)
    +        activeSession = sc._jvm.SparkSession.getActiveSession()
    +        self.assertFalse(activeSession.isDefined())
    +        session = SparkSession(sc)
    +        activeSession = sc._jvm.SparkSession.getActiveSession()
    +        self.assertTrue(activeSession.isDefined())
    +        activeSession2 = SparkSession.getActiveSession()
    +        self.assertNotEqual(activeSession2, None)
    +
    +
    +class SparkSessionTests3(ReusedSQLTestCase):
    +
    +    def test_get_active_session_after_create_dataframe(self):
    +        activeSession1 = SparkSession.getActiveSession()
    +        session1 = self.spark
    +        self.assertEqual(session1, activeSession1)
    +        session2 = self.spark.newSession()
    +        activeSession2 = SparkSession.getActiveSession()
    +        self.assertEqual(session1, activeSession2)
    +        self.assertNotEqual(session2, activeSession2)
    +        session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
    +        activeSession3 = SparkSession.getActiveSession()
    +        self.assertEqual(session2, activeSession3)
    +        session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
    +        activeSession4 = SparkSession.getActiveSession()
    +        self.assertEqual(session1, activeSession4)
    +        session2.stop()
    --- End diff --
    
    I think you can put this in a try-finally block, so that `session2.stop()` still runs even if one of the assertions above fails.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to