Repository: spark
Updated Branches:
  refs/heads/branch-2.0 0a8fd2eb8 -> 8b7e56121


[SPARK-15159][SPARKR] SparkR SparkSession API

## What changes were proposed in this pull request?

This PR introduces the new SparkSession API for SparkR.
`sparkR.session.getOrCreate()` and `sparkR.session.stop()`

"getOrCreate" is a bit unusual in R but it's important to name this clearly.

SparkR implementation should
- SparkSession is the main entrypoint (vs SparkContext; due to limited 
functionality supported with SparkContext in SparkR)
- SparkSession replaces SQLContext and HiveContext (both a wrapper around 
SparkSession, and because of API changes, supporting all 3 would be a lot more 
work)
- Changes to SparkSession is mostly transparent to users due to SPARK-10903
- Full backward compatibility is expected - users should be able to initialize 
everything just in Spark 1.6.1 (`sparkR.init()`), but with deprecation warning
- Mostly cosmetic changes to parameter list - users should be able to move to 
`sparkR.session.getOrCreate()` easily
- An advanced syntax with named parameters (aka varargs aka "...") is 
supported; that should be closer to the Builder syntax that is in Scala/Python 
(which unfortunately does not work in R because it will look like this: 
`enableHiveSupport(config(config(master(appName(builder(), "foo"), "local"), 
"first", "value"), "next, "value"))`
- Updating config on an existing SparkSession is supported, the behavior is the 
same as Python, in which config is applied to both SparkContext and SparkSession
- Some SparkSession changes are not matched in SparkR, mostly because it would 
be breaking API change: `catalog` object, `createOrReplaceTempView`
- Other SQLContext workarounds are replicated in SparkR, eg. `tables`, 
`tableNames`
- `sparkR` shell is updated to use the SparkSession entrypoint (`sqlContext` is 
removed, just like with Scale/Python)
- All tests are updated to use the SparkSession entrypoint
- A bug in `read.jdbc` is fixed

TODO
- [x] Add more tests
- [ ] Separate PR - update all roxygen2 doc coding example
- [ ] Separate PR - update SparkR programming guide

## How was this patch tested?

unit tests, manual tests

shivaram sun-rui rxin

Author: Felix Cheung <felixcheun...@hotmail.com>
Author: felixcheung <felixcheun...@hotmail.com>

Closes #13635 from felixcheung/rsparksession.

(cherry picked from commit 8c198e246d64b5779dc3a2625d06ec958553a20b)
Signed-off-by: Shivaram Venkataraman <shiva...@cs.berkeley.edu>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8b7e5612
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8b7e5612
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8b7e5612

Branch: refs/heads/branch-2.0
Commit: 8b7e561210a29d66317ce66f598d4bd2ad2c8087
Parents: 0a8fd2eb
Author: Felix Cheung <felixcheun...@hotmail.com>
Authored: Fri Jun 17 21:36:01 2016 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Fri Jun 17 21:36:10 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   8 +-
 R/pkg/R/DataFrame.R                             |   8 +-
 R/pkg/R/SQLContext.R                            | 109 +++++------
 R/pkg/R/backend.R                               |   2 +-
 R/pkg/R/sparkR.R                                | 183 ++++++++++++++-----
 R/pkg/R/utils.R                                 |   9 +
 R/pkg/inst/profile/shell.R                      |  12 +-
 R/pkg/inst/tests/testthat/jarTest.R             |   4 +-
 R/pkg/inst/tests/testthat/packageInAJarTest.R   |   4 +-
 R/pkg/inst/tests/testthat/test_Serde.R          |   2 +-
 R/pkg/inst/tests/testthat/test_binaryFile.R     |   3 +-
 .../inst/tests/testthat/test_binary_function.R  |   3 +-
 R/pkg/inst/tests/testthat/test_broadcast.R      |   3 +-
 R/pkg/inst/tests/testthat/test_context.R        |  41 +++--
 R/pkg/inst/tests/testthat/test_includePackage.R |   3 +-
 R/pkg/inst/tests/testthat/test_mllib.R          |   5 +-
 .../tests/testthat/test_parallelize_collect.R   |   3 +-
 R/pkg/inst/tests/testthat/test_rdd.R            |   3 +-
 R/pkg/inst/tests/testthat/test_shuffle.R        |   3 +-
 R/pkg/inst/tests/testthat/test_sparkSQL.R       |  86 +++++++--
 R/pkg/inst/tests/testthat/test_take.R           |  17 +-
 R/pkg/inst/tests/testthat/test_textFile.R       |   3 +-
 R/pkg/inst/tests/testthat/test_utils.R          |  16 +-
 .../org/apache/spark/sql/api/r/SQLUtils.scala   |  76 ++++++--
 24 files changed, 420 insertions(+), 186 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 9412ec3..82e56ca 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -6,10 +6,15 @@ importFrom(methods, setGeneric, setMethod, setOldClass)
 #useDynLib(SparkR, stringHashCode)
 
 # S3 methods exported
+export("sparkR.session")
 export("sparkR.init")
 export("sparkR.stop")
+export("sparkR.session.stop")
 export("print.jobj")
 
+export("sparkRSQL.init",
+       "sparkRHive.init")
+
 # MLlib integration
 exportMethods("glm",
               "spark.glm",
@@ -287,9 +292,6 @@ exportMethods("%in%",
 exportClasses("GroupedData")
 exportMethods("agg")
 
-export("sparkRSQL.init",
-       "sparkRHive.init")
-
 export("as.DataFrame",
        "cacheTable",
        "clearCache",

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/R/DataFrame.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 4e04456..ea091c8 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2333,9 +2333,7 @@ setMethod("write.df",
           signature(df = "SparkDataFrame", path = "character"),
           function(df, path, source = NULL, mode = "error", ...){
             if (is.null(source)) {
-              sqlContext <- getSqlContext()
-              source <- callJMethod(sqlContext, "getConf", 
"spark.sql.sources.default",
-                                    "org.apache.spark.sql.parquet")
+              source <- getDefaultSqlSource()
             }
             jmode <- convertToJSaveMode(mode)
             options <- varargsToEnv(...)
@@ -2393,9 +2391,7 @@ setMethod("saveAsTable",
           signature(df = "SparkDataFrame", tableName = "character"),
           function(df, tableName, source = NULL, mode="error", ...){
             if (is.null(source)) {
-              sqlContext <- getSqlContext()
-              source <- callJMethod(sqlContext, "getConf", 
"spark.sql.sources.default",
-                                    "org.apache.spark.sql.parquet")
+              source <- getDefaultSqlSource()
             }
             jmode <- convertToJSaveMode(mode)
             options <- varargsToEnv(...)

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/R/SQLContext.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 914b02a..3232241 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -53,7 +53,8 @@ dispatchFunc <- function(newFuncSig, x, ...) {
   # Strip sqlContext from list of parameters and then pass the rest along.
   contextNames <- c("org.apache.spark.sql.SQLContext",
                     "org.apache.spark.sql.hive.HiveContext",
-                    "org.apache.spark.sql.hive.test.TestHiveContext")
+                    "org.apache.spark.sql.hive.test.TestHiveContext",
+                    "org.apache.spark.sql.SparkSession")
   if (missing(x) && length(list(...)) == 0) {
     f()
   } else if (class(x) == "jobj" &&
@@ -65,14 +66,12 @@ dispatchFunc <- function(newFuncSig, x, ...) {
   }
 }
 
-#' return the SQL Context
-getSqlContext <- function() {
-  if (exists(".sparkRHivesc", envir = .sparkREnv)) {
-    get(".sparkRHivesc", envir = .sparkREnv)
-  } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
-    get(".sparkRSQLsc", envir = .sparkREnv)
+#' return the SparkSession
+getSparkSession <- function() {
+  if (exists(".sparkRsession", envir = .sparkREnv)) {
+    get(".sparkRsession", envir = .sparkREnv)
   } else {
-    stop("SQL context not initialized")
+    stop("SparkSession not initialized")
   }
 }
 
@@ -109,6 +108,13 @@ infer_type <- function(x) {
   }
 }
 
+getDefaultSqlSource <- function() {
+  sparkSession <- getSparkSession()
+  conf <- callJMethod(sparkSession, "conf")
+  source <- callJMethod(conf, "get", "spark.sql.sources.default", 
"org.apache.spark.sql.parquet")
+  source
+}
+
 #' Create a SparkDataFrame
 #'
 #' Converts R data.frame or list into SparkDataFrame.
@@ -131,7 +137,7 @@ infer_type <- function(x) {
 
 # TODO(davies): support sampling and infer type from NA
 createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   if (is.data.frame(data)) {
       # get the names of columns, they will be put into RDD
       if (is.null(schema)) {
@@ -158,7 +164,7 @@ createDataFrame.default <- function(data, schema = NULL, 
samplingRatio = 1.0) {
       data <- do.call(mapply, append(args, data))
   }
   if (is.list(data)) {
-    sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sqlContext)
+    sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
     rdd <- parallelize(sc, data)
   } else if (inherits(data, "RDD")) {
     rdd <- data
@@ -201,7 +207,7 @@ createDataFrame.default <- function(data, schema = NULL, 
samplingRatio = 1.0) {
   jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
   srdd <- callJMethod(jrdd, "rdd")
   sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
-                     srdd, schema$jobj, sqlContext)
+                     srdd, schema$jobj, sparkSession)
   dataFrame(sdf)
 }
 
@@ -265,10 +271,10 @@ setMethod("toDF", signature(x = "RDD"),
 #' @method read.json default
 
 read.json.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "json", paths)
   dataFrame(sdf)
 }
@@ -336,10 +342,10 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, 
samplingRatio = 1.0) {
 #' @method read.parquet default
 
 read.parquet.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "parquet", paths)
   dataFrame(sdf)
 }
@@ -385,10 +391,10 @@ parquetFile <- function(x, ...) {
 #' @method read.text default
 
 read.text.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "text", paths)
   dataFrame(sdf)
 }
@@ -418,8 +424,8 @@ read.text <- function(x, ...) {
 #' @method sql default
 
 sql.default <- function(sqlQuery) {
-  sqlContext <- getSqlContext()
-  sdf <- callJMethod(sqlContext, "sql", sqlQuery)
+  sparkSession <- getSparkSession()
+  sdf <- callJMethod(sparkSession, "sql", sqlQuery)
   dataFrame(sdf)
 }
 
@@ -449,8 +455,8 @@ sql <- function(x, ...) {
 #' @note since 2.0.0
 
 tableToDF <- function(tableName) {
-  sqlContext <- getSqlContext()
-  sdf <- callJMethod(sqlContext, "table", tableName)
+  sparkSession <- getSparkSession()
+  sdf <- callJMethod(sparkSession, "table", tableName)
   dataFrame(sdf)
 }
 
@@ -472,12 +478,8 @@ tableToDF <- function(tableName) {
 #' @method tables default
 
 tables.default <- function(databaseName = NULL) {
-  sqlContext <- getSqlContext()
-  jdf <- if (is.null(databaseName)) {
-    callJMethod(sqlContext, "tables")
-  } else {
-    callJMethod(sqlContext, "tables", databaseName)
-  }
+  sparkSession <- getSparkSession()
+  jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", 
sparkSession, databaseName)
   dataFrame(jdf)
 }
 
@@ -503,12 +505,11 @@ tables <- function(x, ...) {
 #' @method tableNames default
 
 tableNames.default <- function(databaseName = NULL) {
-  sqlContext <- getSqlContext()
-  if (is.null(databaseName)) {
-    callJMethod(sqlContext, "tableNames")
-  } else {
-    callJMethod(sqlContext, "tableNames", databaseName)
-  }
+  sparkSession <- getSparkSession()
+  callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+              "getTableNames",
+              sparkSession,
+              databaseName)
 }
 
 tableNames <- function(x, ...) {
@@ -536,8 +537,9 @@ tableNames <- function(x, ...) {
 #' @method cacheTable default
 
 cacheTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "cacheTable", tableName)
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "cacheTable", tableName)
 }
 
 cacheTable <- function(x, ...) {
@@ -565,8 +567,9 @@ cacheTable <- function(x, ...) {
 #' @method uncacheTable default
 
 uncacheTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "uncacheTable", tableName)
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "uncacheTable", tableName)
 }
 
 uncacheTable <- function(x, ...) {
@@ -587,8 +590,9 @@ uncacheTable <- function(x, ...) {
 #' @method clearCache default
 
 clearCache.default <- function() {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "clearCache")
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "clearCache")
 }
 
 clearCache <- function() {
@@ -615,11 +619,12 @@ clearCache <- function() {
 #' @method dropTempTable default
 
 dropTempTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   if (class(tableName) != "character") {
     stop("tableName must be a string.")
   }
-  callJMethod(sqlContext, "dropTempTable", tableName)
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "dropTempView", tableName)
 }
 
 dropTempTable <- function(x, ...) {
@@ -655,21 +660,21 @@ dropTempTable <- function(x, ...) {
 #' @method read.df default
 
 read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   options <- varargsToEnv(...)
   if (!is.null(path)) {
     options[["path"]] <- path
   }
   if (is.null(source)) {
-    source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
-                          "org.apache.spark.sql.parquet")
+    source <- getDefaultSqlSource()
   }
   if (!is.null(schema)) {
     stopifnot(class(schema) == "structType")
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", 
sqlContext, source,
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", 
sparkSession, source,
                        schema$jobj, options)
   } else {
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", 
sqlContext, source, options)
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                       "loadDF", sparkSession, source, options)
   }
   dataFrame(sdf)
 }
@@ -715,12 +720,13 @@ loadDF <- function(x, ...) {
 #' @method createExternalTable default
 
 createExternalTable.default <- function(tableName, path = NULL, source = NULL, 
...) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   options <- varargsToEnv(...)
   if (!is.null(path)) {
     options[["path"]] <- path
   }
-  sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, 
options)
+  catalog <- callJMethod(sparkSession, "catalog")
+  sdf <- callJMethod(catalog, "createExternalTable", tableName, source, 
options)
   dataFrame(sdf)
 }
 
@@ -767,12 +773,11 @@ read.jdbc <- function(url, tableName,
                       partitionColumn = NULL, lowerBound = NULL, upperBound = 
NULL,
                       numPartitions = 0L, predicates = list(), ...) {
   jprops <- varargsToJProperties(...)
-
-  read <- callJMethod(sqlContext, "read")
+  sparkSession <- getSparkSession()
+  read <- callJMethod(sparkSession, "read")
   if (!is.null(partitionColumn)) {
     if (is.null(numPartitions) || numPartitions == 0) {
-      sqlContext <- getSqlContext()
-      sc <- callJMethod(sqlContext, "sparkContext")
+      sc <- callJMethod(sparkSession, "sparkContext")
       numPartitions <- callJMethod(sc, "defaultParallelism")
     } else {
       numPartitions <- numToInt(numPartitions)

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/R/backend.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
index 6c81492..03e70bb 100644
--- a/R/pkg/R/backend.R
+++ b/R/pkg/R/backend.R
@@ -68,7 +68,7 @@ isRemoveMethod <- function(isStatic, objId, methodName) {
 # methodName - name of method to be invoked
 invokeJava <- function(isStatic, objId, methodName, ...) {
   if (!exists(".sparkRCon", .sparkREnv)) {
-    stop("No connection to backend found. Please re-run sparkR.init")
+    stop("No connection to backend found. Please re-run sparkR.session()")
   }
 
   # If this isn't a removeJObject call

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 04a8b1e..0dfd7b7 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -28,10 +28,21 @@ connExists <- function(env) {
   })
 }
 
-#' Stop the Spark context.
-#'
-#' Also terminates the backend this R session is connected to
+#' @rdname sparkR.session.stop
+#' @name sparkR.stop
+#' @export
 sparkR.stop <- function() {
+  sparkR.session.stop()
+}
+
+#' Stop the Spark Session and Spark Context.
+#'
+#' Also terminates the backend this R session is connected to.
+#' @rdname sparkR.session.stop
+#' @name sparkR.session.stop
+#' @export
+#' @note since 2.0.0
+sparkR.session.stop <- function() {
   env <- .sparkREnv
   if (exists(".sparkRCon", envir = env)) {
     if (exists(".sparkRjsc", envir = env)) {
@@ -39,12 +50,8 @@ sparkR.stop <- function() {
       callJMethod(sc, "stop")
       rm(".sparkRjsc", envir = env)
 
-      if (exists(".sparkRSQLsc", envir = env)) {
-        rm(".sparkRSQLsc", envir = env)
-      }
-
-      if (exists(".sparkRHivesc", envir = env)) {
-        rm(".sparkRHivesc", envir = env)
+      if (exists(".sparkRsession", envir = env)) {
+        rm(".sparkRsession", envir = env)
       }
     }
 
@@ -80,7 +87,7 @@ sparkR.stop <- function() {
   clearJobjs()
 }
 
-#' Initialize a new Spark Context.
+#' (Deprecated) Initialize a new Spark Context.
 #'
 #' This function initializes a new SparkContext. For details on how to 
initialize
 #' and use SparkR, refer to SparkR programming guide at
@@ -93,6 +100,8 @@ sparkR.stop <- function() {
 #' @param sparkExecutorEnv Named list of environment variables to be used when 
launching executors
 #' @param sparkJars Character vector of jar files to pass to the worker nodes
 #' @param sparkPackages Character vector of packages from spark-packages.org
+#' @seealso \link{sparkR.session}
+#' @rdname sparkR.init-deprecated
 #' @export
 #' @examples
 #'\dontrun{
@@ -114,18 +123,35 @@ sparkR.init <- function(
   sparkExecutorEnv = list(),
   sparkJars = "",
   sparkPackages = "") {
+  .Deprecated("sparkR.session")
+  sparkR.sparkContext(master,
+     appName,
+     sparkHome,
+     convertNamedListToEnv(sparkEnvir),
+     convertNamedListToEnv(sparkExecutorEnv),
+     sparkJars,
+     sparkPackages)
+}
+
+# Internal function to handle creating the SparkContext.
+sparkR.sparkContext <- function(
+  master = "",
+  appName = "SparkR",
+  sparkHome = Sys.getenv("SPARK_HOME"),
+  sparkEnvirMap = new.env(),
+  sparkExecutorEnvMap = new.env(),
+  sparkJars = "",
+  sparkPackages = "") {
 
   if (exists(".sparkRjsc", envir = .sparkREnv)) {
     cat(paste("Re-using existing Spark Context.",
-              "Please stop SparkR with sparkR.stop() or restart R to create a 
new Spark Context\n"))
+              "Call sparkR.session.stop() or restart R to create a new Spark 
Context\n"))
     return(get(".sparkRjsc", envir = .sparkREnv))
   }
 
   jars <- processSparkJars(sparkJars)
   packages <- processSparkPackages(sparkPackages)
 
-  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
-
   existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
   if (existingPort != "") {
     backendPort <- existingPort
@@ -183,7 +209,6 @@ sparkR.init <- function(
     sparkHome <- suppressWarnings(normalizePath(sparkHome))
   }
 
-  sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
   if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
     sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
       paste0("$LD_LIBRARY_PATH:", Sys.getenv("LD_LIBRARY_PATH"))
@@ -225,12 +250,17 @@ sparkR.init <- function(
   sc
 }
 
-#' Initialize a new SQLContext.
+#' (Deprecated) Initialize a new SQLContext.
 #'
 #' This function creates a SparkContext from an existing JavaSparkContext and
 #' then uses it to initialize a new SQLContext
 #'
+#' Starting SparkR 2.0, a SparkSession is initialized and returned instead.
+#' This API is deprecated and kept for backward compatibility only.
+#'
 #' @param jsc The existing JavaSparkContext created with SparkR.init()
+#' @seealso \link{sparkR.session}
+#' @rdname sparkRSQL.init-deprecated
 #' @export
 #' @examples
 #'\dontrun{
@@ -239,29 +269,26 @@ sparkR.init <- function(
 #'}
 
 sparkRSQL.init <- function(jsc = NULL) {
-  if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
-    return(get(".sparkRSQLsc", envir = .sparkREnv))
-  }
+  .Deprecated("sparkR.session")
 
-  # If jsc is NULL, create a Spark Context
-  sc <- if (is.null(jsc)) {
-    sparkR.init()
-  } else {
-    jsc
+  if (exists(".sparkRsession", envir = .sparkREnv)) {
+    return(get(".sparkRsession", envir = .sparkREnv))
   }
 
-  sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
-                            "createSQLContext",
-                            sc)
-  assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv)
-  sqlContext
+  # Default to without Hive support for backward compatibility.
+  sparkR.session(enableHiveSupport = FALSE)
 }
 
-#' Initialize a new HiveContext.
+#' (Deprecated) Initialize a new HiveContext.
 #'
 #' This function creates a HiveContext from an existing JavaSparkContext
 #'
+#' Starting SparkR 2.0, a SparkSession is initialized and returned instead.
+#' This API is deprecated and kept for backward compatibility only.
+#'
 #' @param jsc The existing JavaSparkContext created with SparkR.init()
+#' @seealso \link{sparkR.session}
+#' @rdname sparkRHive.init-deprecated
 #' @export
 #' @examples
 #'\dontrun{
@@ -270,27 +297,93 @@ sparkRSQL.init <- function(jsc = NULL) {
 #'}
 
 sparkRHive.init <- function(jsc = NULL) {
-  if (exists(".sparkRHivesc", envir = .sparkREnv)) {
-    return(get(".sparkRHivesc", envir = .sparkREnv))
+  .Deprecated("sparkR.session")
+
+  if (exists(".sparkRsession", envir = .sparkREnv)) {
+    return(get(".sparkRsession", envir = .sparkREnv))
   }
 
-  # If jsc is NULL, create a Spark Context
-  sc <- if (is.null(jsc)) {
-    sparkR.init()
-  } else {
-    jsc
+  # Default to without Hive support for backward compatibility.
+  sparkR.session(enableHiveSupport = TRUE)
+}
+
+#' Get the existing SparkSession or initialize a new SparkSession.
+#'
+#' Additional Spark properties can be set (...), and these named parameters 
take priority over
+#' over values in master, appName, named lists of sparkConfig.
+#'
+#' @param master The Spark master URL
+#' @param appName Application name to register with cluster manager
+#' @param sparkHome Spark Home directory
+#' @param sparkConfig Named list of Spark configuration to set on worker nodes
+#' @param sparkJars Character vector of jar files to pass to the worker nodes
+#' @param sparkPackages Character vector of packages from spark-packages.org
+#' @param enableHiveSupport Enable support for Hive, fallback if not built 
with Hive support; once
+#'        set, this cannot be turned off on an existing session
+#' @export
+#' @examples
+#'\dontrun{
+#' sparkR.session()
+#' df <- read.json(path)
+#'
+#' sparkR.session("local[2]", "SparkR", "/home/spark")
+#' sparkR.session("yarn-client", "SparkR", "/home/spark",
+#'                list(spark.executor.memory="4g"),
+#'                c("one.jar", "two.jar", "three.jar"),
+#'                c("com.databricks:spark-avro_2.10:2.0.1"))
+#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g")
+#'}
+#' @note since 2.0.0
+
+sparkR.session <- function(
+  master = "",
+  appName = "SparkR",
+  sparkHome = Sys.getenv("SPARK_HOME"),
+  sparkConfig = list(),
+  sparkJars = "",
+  sparkPackages = "",
+  enableHiveSupport = TRUE,
+  ...) {
+
+  sparkConfigMap <- convertNamedListToEnv(sparkConfig)
+  namedParams <- list(...)
+  if (length(namedParams) > 0) {
+    paramMap <- convertNamedListToEnv(namedParams)
+    # Override for certain named parameters
+    if (exists("spark.master", envir = paramMap)) {
+      master <- paramMap[["spark.master"]]
+    }
+    if (exists("spark.app.name", envir = paramMap)) {
+      appName <- paramMap[["spark.app.name"]]
+    }
+    overrideEnvs(sparkConfigMap, paramMap)
   }
 
-  ssc <- callJMethod(sc, "sc")
-  hiveCtx <- tryCatch({
-    newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
-  },
-  error = function(err) {
-    stop("Spark SQL is not built with Hive support")
-  })
+  if (!exists(".sparkRjsc", envir = .sparkREnv)) {
+    sparkExecutorEnvMap <- new.env()
+    sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, 
sparkExecutorEnvMap,
+       sparkJars, sparkPackages)
+    stopifnot(exists(".sparkRjsc", envir = .sparkREnv))
+  }
 
-  assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
-  hiveCtx
+  if (exists(".sparkRsession", envir = .sparkREnv)) {
+    sparkSession <- get(".sparkRsession", envir = .sparkREnv)
+    # Apply config to Spark Context and Spark Session if already there
+    # Cannot change enableHiveSupport
+    callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                "setSparkContextSessionConf",
+                sparkSession,
+                sparkConfigMap)
+  } else {
+    jsc <- get(".sparkRjsc", envir = .sparkREnv)
+    sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                                "getOrCreateSparkSession",
+                                jsc,
+                                sparkConfigMap,
+                                enableHiveSupport)
+    assign(".sparkRsession", sparkSession, envir = .sparkREnv)
+  }
+  sparkSession
 }
 
 #' Assigns a group ID to all the jobs started by this thread until the group 
ID is set to a

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/R/utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index b1b8ada..aafb344 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -317,6 +317,15 @@ convertEnvsToList <- function(keys, vals) {
          })
 }
 
+# Utility function to merge 2 environments with the second overriding values 
in the first
+# env1 is changed in place
+overrideEnvs <- function(env1, env2) {
+  lapply(ls(env2),
+         function(name) {
+           env1[[name]] <- env2[[name]]
+         })
+}
+
 # Utility function to capture the varargs into environment object
 varargsToEnv <- function(...) {
   # Based on http://stackoverflow.com/a/3057419/4577954

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/profile/shell.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R
index 90a3761..8a8111a 100644
--- a/R/pkg/inst/profile/shell.R
+++ b/R/pkg/inst/profile/shell.R
@@ -18,17 +18,17 @@
 .First <- function() {
   home <- Sys.getenv("SPARK_HOME")
   .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
-  Sys.setenv(NOAWT=1)
+  Sys.setenv(NOAWT = 1)
 
   # Make sure SparkR package is the last loaded one
   old <- getOption("defaultPackages")
   options(defaultPackages = c(old, "SparkR"))
 
-  sc <- SparkR::sparkR.init()
-  assign("sc", sc, envir=.GlobalEnv)
-  sqlContext <- SparkR::sparkRSQL.init(sc)
+  spark <- SparkR::sparkR.session()
+  assign("spark", spark, envir = .GlobalEnv)
+  sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", spark)
+  assign("sc", sc, envir = .GlobalEnv)
   sparkVer <- SparkR:::callJMethod(sc, "version")
-  assign("sqlContext", sqlContext, envir=.GlobalEnv)
   cat("\n Welcome to")
   cat("\n")
   cat("    ____              __", "\n")
@@ -43,5 +43,5 @@
   cat("    /_/", "\n")
   cat("\n")
 
-  cat("\n Spark context is available as sc, SQL context is available as 
sqlContext\n")
+  cat("\n SparkSession available as 'spark'.\n")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/jarTest.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/jarTest.R 
b/R/pkg/inst/tests/testthat/jarTest.R
index d68bb20..84e4845 100644
--- a/R/pkg/inst/tests/testthat/jarTest.R
+++ b/R/pkg/inst/tests/testthat/jarTest.R
@@ -16,7 +16,7 @@
 #
 library(SparkR)
 
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
 
 helloTest <- SparkR:::callJStatic("sparkR.test.hello",
                                   "helloWorld",
@@ -27,6 +27,6 @@ basicFunction <- 
SparkR:::callJStatic("sparkR.test.basicFunction",
                                       2L,
                                       2L)
 
-sparkR.stop()
+sparkR.session.stop()
 output <- c(helloTest, basicFunction)
 writeLines(output)

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/packageInAJarTest.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R 
b/R/pkg/inst/tests/testthat/packageInAJarTest.R
index c26b28b..940c91f 100644
--- a/R/pkg/inst/tests/testthat/packageInAJarTest.R
+++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R
@@ -17,13 +17,13 @@
 library(SparkR)
 library(sparkPackageTest)
 
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
 
 run1 <- myfunc(5L)
 
 run2 <- myfunc(-4L)
 
-sparkR.stop()
+sparkR.session.stop()
 
 if (run1 != 6) quit(save = "no", status = 1)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_Serde.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R 
b/R/pkg/inst/tests/testthat/test_Serde.R
index dddce54..96fb6dd 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/inst/tests/testthat/test_Serde.R
@@ -17,7 +17,7 @@
 
 context("SerDe functionality")
 
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
 
 test_that("SerDe of primitive types", {
   x <- callJStatic("SparkRHandler", "echo", 1L)

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_binaryFile.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R 
b/R/pkg/inst/tests/testthat/test_binaryFile.R
index 976a755..b69f017 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/inst/tests/testthat/test_binaryFile.R
@@ -18,7 +18,8 @@
 context("functions on binary files")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 mockFile <- c("Spark is pretty.", "Spark is awesome.")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_binary_function.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R 
b/R/pkg/inst/tests/testthat/test_binary_function.R
index 7bad4d2..6f51d20 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/inst/tests/testthat/test_binary_function.R
@@ -18,7 +18,8 @@
 context("binary functions")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Data
 nums <- 1:10

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_broadcast.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R 
b/R/pkg/inst/tests/testthat/test_broadcast.R
index 8be6efc..cf1d432 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/inst/tests/testthat/test_broadcast.R
@@ -18,7 +18,8 @@
 context("broadcast variables")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Partitioned data
 nums <- 1:2

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_context.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_context.R 
b/R/pkg/inst/tests/testthat/test_context.R
index 126484c..f123187 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -56,31 +56,33 @@ test_that("Check masked functions", {
 
 test_that("repeatedly starting and stopping SparkR", {
   for (i in 1:4) {
-    sc <- sparkR.init()
+    sc <- suppressWarnings(sparkR.init())
     rdd <- parallelize(sc, 1:20, 2L)
     expect_equal(count(rdd), 20)
-    sparkR.stop()
+    suppressWarnings(sparkR.stop())
   }
 })
 
-test_that("repeatedly starting and stopping SparkR SQL", {
-  for (i in 1:4) {
-    sc <- sparkR.init()
-    sqlContext <- sparkRSQL.init(sc)
-    df <- createDataFrame(data.frame(a = 1:20))
-    expect_equal(count(df), 20)
-    sparkR.stop()
-  }
-})
+# Does not work consistently even with Hive off
+# nolint start
+# test_that("repeatedly starting and stopping SparkR", {
+#   for (i in 1:4) {
+#     sparkR.session(enableHiveSupport = FALSE)
+#     df <- createDataFrame(data.frame(dummy=1:i))
+#     expect_equal(count(df), i)
+#     sparkR.session.stop()
+#     Sys.sleep(5) # Need more time to shutdown Hive metastore
+#   }
+# })
+# nolint end
 
 test_that("rdd GC across sparkR.stop", {
-  sparkR.stop()
-  sc <- sparkR.init() # sc should get id 0
+  sc <- sparkR.sparkContext() # sc should get id 0
   rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1
   rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2
-  sparkR.stop()
+  sparkR.session.stop()
 
-  sc <- sparkR.init() # sc should get id 0 again
+  sc <- sparkR.sparkContext() # sc should get id 0 again
 
   # GC rdd1 before creating rdd3 and rdd2 after
   rm(rdd1)
@@ -97,15 +99,17 @@ test_that("rdd GC across sparkR.stop", {
 })
 
 test_that("job group functions can be called", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   setJobGroup(sc, "groupId", "job description", TRUE)
   cancelJobGroup(sc, "groupId")
   clearJobGroup(sc)
+  sparkR.session.stop()
 })
 
 test_that("utility function can be called", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   setLogLevel(sc, "ERROR")
+  sparkR.session.stop()
 })
 
 test_that("getClientModeSparkSubmitOpts() returns spark-submit args from 
whitelist", {
@@ -156,7 +160,8 @@ test_that("sparkJars sparkPackages as comma-separated 
strings", {
 })
 
 test_that("spark.lapply should perform simple transforms", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   doubled <- spark.lapply(sc, 1:10, function(x) { 2 * x })
   expect_equal(doubled, as.list(2 * 1:10))
+  sparkR.session.stop()
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_includePackage.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R 
b/R/pkg/inst/tests/testthat/test_includePackage.R
index 8152b44..d6a3766 100644
--- a/R/pkg/inst/tests/testthat/test_includePackage.R
+++ b/R/pkg/inst/tests/testthat/test_includePackage.R
@@ -18,7 +18,8 @@
 context("include R packages")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Partitioned data
 nums <- 1:2

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R 
b/R/pkg/inst/tests/testthat/test_mllib.R
index 59ef15c..c8c5ef2 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -20,10 +20,7 @@ library(testthat)
 context("MLlib functions")
 
 # Tests for MLlib functions in SparkR
-
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
 
 test_that("formula of spark.glm", {
   training <- suppressWarnings(createDataFrame(iris))

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_parallelize_collect.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R 
b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
index 2552127..f79a8a7 100644
--- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R
+++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
@@ -33,7 +33,8 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 
3))
 strPairs <- list(list(strList, strList), list(strList, strList))
 
 # JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Tests
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_rdd.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R 
b/R/pkg/inst/tests/testthat/test_rdd.R
index b6c8e1d..429311d 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -18,7 +18,8 @@
 context("basic RDD functions")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Data
 nums <- 1:10

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_shuffle.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R 
b/R/pkg/inst/tests/testthat/test_shuffle.R
index d3d0f8a..7d4f342 100644
--- a/R/pkg/inst/tests/testthat/test_shuffle.R
+++ b/R/pkg/inst/tests/testthat/test_shuffle.R
@@ -18,7 +18,8 @@
 context("partitionBy, groupByKey, reduceByKey etc.")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 # Data
 intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200))

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R 
b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 607bd9c..fcc2ab3 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -33,26 +33,35 @@ markUtf8 <- function(s) {
 }
 
 setHiveContext <- function(sc) {
-  ssc <- callJMethod(sc, "sc")
-  hiveCtx <- tryCatch({
-    newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
-  },
-  error = function(err) {
-    skip("Hive is not build with SparkSQL, skipped")
-  })
-  assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
-  hiveCtx
+  if (exists(".testHiveSession", envir = .sparkREnv)) {
+    hiveSession <- get(".testHiveSession", envir = .sparkREnv)
+  } else {
+    # initialize once and reuse
+    ssc <- callJMethod(sc, "sc")
+    hiveCtx <- tryCatch({
+      newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+    },
+    error = function(err) {
+      skip("Hive is not build with SparkSQL, skipped")
+    })
+    hiveSession <- callJMethod(hiveCtx, "sparkSession")
+  }
+  previousSession <- get(".sparkRsession", envir = .sparkREnv)
+  assign(".sparkRsession", hiveSession, envir = .sparkREnv)
+  assign(".prevSparkRsession", previousSession, envir = .sparkREnv)
+  hiveSession
 }
 
 unsetHiveContext <- function() {
-  remove(".sparkRHivesc", envir = .sparkREnv)
+  previousSession <- get(".prevSparkRsession", envir = .sparkREnv)
+  assign(".sparkRsession", previousSession, envir = .sparkREnv)
+  remove(".prevSparkRsession", envir = .sparkREnv)
 }
 
 # Tests for SparkSQL functions in SparkR
 
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 mockLines <- c("{\"name\":\"Michael\"}",
                "{\"name\":\"Andy\", \"age\":30}",
@@ -79,7 +88,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", 
fileext = ".tmp")
 writeLines(mockLinesComplexType, complexTypeJsonPath)
 
 test_that("calling sparkRSQL.init returns existing SQL context", {
-  expect_equal(sparkRSQL.init(sc), sqlContext)
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
+  expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext)
+})
+
+test_that("calling sparkRSQL.init returns existing SparkSession", {
+  expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession)
+})
+
+test_that("calling sparkR.session returns existing SparkSession", {
+  expect_equal(sparkR.session(), sparkSession)
 })
 
 test_that("infer types and check types", {
@@ -431,6 +449,7 @@ test_that("read/write json files", {
 })
 
 test_that("jsonRDD() on a RDD with json string", {
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
   rdd <- parallelize(sc, mockLines)
   expect_equal(count(rdd), 3)
   df <- suppressWarnings(jsonRDD(sqlContext, rdd))
@@ -2228,7 +2247,6 @@ test_that("gapply() on a DataFrame", {
 })
 
 test_that("Window functions on a DataFrame", {
-  setHiveContext(sc)
   df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), 
list(2L, "2")),
                         schema = c("key", "value"))
   ws <- orderBy(window.partitionBy("key"), "value")
@@ -2253,10 +2271,10 @@ test_that("Window functions on a DataFrame", {
   result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 
1), ws)))
   names(result) <- c("key", "value")
   expect_equal(result, expected)
-  unsetHiveContext()
 })
 
 test_that("createDataFrame sqlContext parameter backward compatibility", {
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
   a <- 1:3
   b <- c("a", "b", "c")
   ldf <- data.frame(a, b)
@@ -2283,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward 
compatibility", {
 test_that("randomSplit", {
   num <- 4000
   df <- createDataFrame(data.frame(id = 1:num))
-
   weights <- c(2, 3, 5)
   df_list <- randomSplit(df, weights)
   expect_equal(length(weights), length(df_list))
@@ -2298,6 +2315,41 @@ test_that("randomSplit", {
   expect_true(all(sapply(abs(counts / num - weights / sum(weights)), 
function(e) { e < 0.05 })))
 })
 
+test_that("Change config on SparkSession", {
+  # first, set it to a random but known value
+  conf <- callJMethod(sparkSession, "conf")
+  property <- paste0("spark.testing.", as.character(runif(1)))
+  value1 <- as.character(runif(1))
+  callJMethod(conf, "set", property, value1)
+
+  # next, change the same property to the new value
+  value2 <- as.character(runif(1))
+  l <- list(value2)
+  names(l) <- property
+  sparkR.session(sparkConfig = l)
+
+  conf <- callJMethod(sparkSession, "conf")
+  newValue <- callJMethod(conf, "get", property, "")
+  expect_equal(value2, newValue)
+
+  value <- as.character(runif(1))
+  sparkR.session(spark.app.name = "sparkSession test", 
spark.testing.r.session.r = value)
+  conf <- callJMethod(sparkSession, "conf")
+  appNameValue <- callJMethod(conf, "get", "spark.app.name", "")
+  testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "")
+  expect_equal(appNameValue, "sparkSession test")
+  expect_equal(testValue, value)
+})
+
+test_that("enableHiveSupport on SparkSession", {
+  setHiveContext(sc)
+  unsetHiveContext()
+  # if we are still here, it must be built with hive
+  conf <- callJMethod(sparkSession, "conf")
+  value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "")
+  expect_equal(value, "hive")
+})
+
 unlink(parquetPath)
 unlink(jsonPath)
 unlink(jsonPathNa)

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_take.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_take.R 
b/R/pkg/inst/tests/testthat/test_take.R
index c2c724c..daf5e41 100644
--- a/R/pkg/inst/tests/testthat/test_take.R
+++ b/R/pkg/inst/tests/testthat/test_take.R
@@ -30,10 +30,11 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my 
teeth on edge, ",
                 "raising me. But they're both dead now. I didn't kill them. 
Honest.")
 
 # JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 test_that("take() gives back the original elements in correct count and 
order", {
-  numVectorRDD <- parallelize(jsc, numVector, 10)
+  numVectorRDD <- parallelize(sc, numVector, 10)
   # case: number of elements to take is less than the size of the first 
partition
   expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
   # case: number of elements to take is the same as the size of the first 
partition
@@ -42,20 +43,20 @@ test_that("take() gives back the original elements in 
correct count and order",
   expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
   expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))
 
-  numListRDD <- parallelize(jsc, numList, 1)
-  numListRDD2 <- parallelize(jsc, numList, 4)
+  numListRDD <- parallelize(sc, numList, 1)
+  numListRDD2 <- parallelize(sc, numList, 4)
   expect_equal(take(numListRDD, 3), take(numListRDD2, 3))
   expect_equal(take(numListRDD, 5), take(numListRDD2, 5))
   expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1)))
   expect_equal(take(numListRDD2, 999), numList)
 
-  strVectorRDD <- parallelize(jsc, strVector, 2)
-  strVectorRDD2 <- parallelize(jsc, strVector, 3)
+  strVectorRDD <- parallelize(sc, strVector, 2)
+  strVectorRDD2 <- parallelize(sc, strVector, 3)
   expect_equal(take(strVectorRDD, 4), as.list(strVector))
   expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2)))
 
-  strListRDD <- parallelize(jsc, strList, 4)
-  strListRDD2 <- parallelize(jsc, strList, 1)
+  strListRDD <- parallelize(sc, strList, 4)
+  strListRDD2 <- parallelize(sc, strList, 1)
   expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
   expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_textFile.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_textFile.R 
b/R/pkg/inst/tests/testthat/test_textFile.R
index e64ef1b..7b2cc74 100644
--- a/R/pkg/inst/tests/testthat/test_textFile.R
+++ b/R/pkg/inst/tests/testthat/test_textFile.R
@@ -18,7 +18,8 @@
 context("the textFile() function")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 mockFile <- c("Spark is pretty.", "Spark is awesome.")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/R/pkg/inst/tests/testthat/test_utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_utils.R 
b/R/pkg/inst/tests/testthat/test_utils.R
index 54d2eca..21a119a 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -18,7 +18,8 @@
 context("functions in utils.R")
 
 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", 
"getJavaSparkContext", sparkSession)
 
 test_that("convertJListToRList() gives back (deserializes) the original JLists
           of strings and integers", {
@@ -168,3 +169,16 @@ test_that("convertToJSaveMode", {
 test_that("hashCode", {
   expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
 })
+
+test_that("overrideEnvs", {
+  config <- new.env()
+  config[["spark.master"]] <- "foo"
+  config[["config_only"]] <- "ok"
+  param <- new.env()
+  param[["spark.master"]] <- "local"
+  param[["param_only"]] <- "blah"
+  overrideEnvs(config, param)
+  expect_equal(config[["spark.master"]], "local")
+  expect_equal(config[["param_only"]], "blah")
+  expect_equal(config[["config_only"]], "ok")
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/8b7e5612/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index fe426fa..0a995d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -18,27 +18,61 @@
 package org.apache.spark.sql.api.r
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, 
DataOutputStream}
+import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.util.matching.Regex
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.SparkContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, 
SaveMode, SQLContext}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.types._
 
-private[sql] object SQLUtils {
+private[sql] object SQLUtils extends Logging {
   SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
 
-  def createSQLContext(jsc: JavaSparkContext): SQLContext = {
-    SQLContext.getOrCreate(jsc.sc)
+  private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = {
+    sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
+    sc
   }
 
-  def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = {
-    new JavaSparkContext(sqlCtx.sparkContext)
+  def getOrCreateSparkSession(
+      jsc: JavaSparkContext,
+      sparkConfigMap: JMap[Object, Object],
+      enableHiveSupport: Boolean): SparkSession = {
+    val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) {
+      
SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate()
+    } else {
+      if (enableHiveSupport) {
+        logWarning("SparkR: enableHiveSupport is requested for SparkSession 
but " +
+          "Spark is not built with Hive; falling back to without Hive 
support.")
+      }
+      SparkSession.builder().sparkContext(jsc.sc).getOrCreate()
+    }
+    setSparkContextSessionConf(spark, sparkConfigMap)
+    spark
+  }
+
+  def setSparkContextSessionConf(
+      spark: SparkSession,
+      sparkConfigMap: JMap[Object, Object]): Unit = {
+    for ((name, value) <- sparkConfigMap.asScala) {
+      spark.conf.set(name.toString, value.toString)
+    }
+    for ((name, value) <- sparkConfigMap.asScala) {
+      spark.sparkContext.conf.set(name.toString, value.toString)
+    }
+  }
+
+  def getJavaSparkContext(spark: SparkSession): JavaSparkContext = {
+    new JavaSparkContext(spark.sparkContext)
   }
 
   def createStructType(fields : Seq[StructField]): StructType = {
@@ -95,10 +129,10 @@ private[sql] object SQLUtils {
     StructField(name, dtObj, nullable)
   }
 
-  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: 
SQLContext): DataFrame = {
+  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sparkSession: 
SparkSession): DataFrame = {
     val num = schema.fields.length
     val rowRDD = rdd.map(bytesToRow(_, schema))
-    sqlContext.createDataFrame(rowRDD, schema)
+    sparkSession.createDataFrame(rowRDD, schema)
   }
 
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
@@ -191,18 +225,18 @@ private[sql] object SQLUtils {
   }
 
   def loadDF(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       source: String,
       options: java.util.Map[String, String]): DataFrame = {
-    sqlContext.read.format(source).options(options).load()
+    sparkSession.read.format(source).options(options).load()
   }
 
   def loadDF(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       source: String,
       schema: StructType,
       options: java.util.Map[String, String]): DataFrame = {
-    sqlContext.read.format(source).schema(schema).options(options).load()
+    sparkSession.read.format(source).schema(schema).options(options).load()
   }
 
   def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
@@ -227,4 +261,22 @@ private[sql] object SQLUtils {
         false
     }
   }
+
+  def getTables(sparkSession: SparkSession, databaseName: String): DataFrame = 
{
+    databaseName match {
+      case n: String if n != null && n.trim.nonEmpty =>
+        Dataset.ofRows(sparkSession, ShowTablesCommand(Some(n), None))
+      case _ =>
+        Dataset.ofRows(sparkSession, ShowTablesCommand(None, None))
+    }
+  }
+
+  def getTableNames(sparkSession: SparkSession, databaseName: String): 
Array[String] = {
+    databaseName match {
+      case n: String if n != null && n.trim.nonEmpty =>
+        sparkSession.catalog.listTables(n).collect().map(_.name)
+      case _ =>
+        sparkSession.catalog.listTables().collect().map(_.name)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to