spark: Expose socketReadTimeoutMs to Spark connector

This patch exposes socketReadTimeoutMs in the KuduContext and the DefaultSource.
This patch also performs a bit of cleanup by renaming the KuduConnection object to KuduClientCache, which seems like a more appropriate name. Because socketReadTimeout is a KuduClient configuration parameter related to connection handling, socketReadTimeout was incorporated into the client cache key. Manually tested in spark-shell using spark-on-yarn. Added a basic test to ensure that the parameter is properly parsed by the DefaultSource and configured in the KuduRelation instance. Change-Id: I0ab0ff0b242790caffb7e2848958148ffe547c4d Reviewed-on: http://gerrit.cloudera.org:8080/10839 Tested-by: Kudu Jenkins Reviewed-by: Dan Burkert <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/kudu/repo Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/eee82d90 Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/eee82d90 Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/eee82d90 Branch: refs/heads/master Commit: eee82d90a54108f2d7e18e84ec0bbd391fcc129a Parents: eaed285 Author: Mike Percy <[email protected]> Authored: Thu Jun 28 00:35:55 2018 -0700 Committer: Mike Percy <[email protected]> Committed: Mon Jul 16 19:26:47 2018 +0000 ---------------------------------------------------------------------- .../apache/kudu/spark/kudu/DefaultSource.scala | 21 ++++++--- .../apache/kudu/spark/kudu/KuduContext.scala | 47 ++++++++++++-------- .../org/apache/kudu/spark/kudu/KuduRDD.scala | 3 +- .../kudu/spark/kudu/DefaultSourceTest.scala | 15 +++++++ .../kudu/spark/kudu/KuduContextTest.scala | 2 +- 5 files changed, 61 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala 
b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala index dd5b824..090e5fb 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala @@ -52,6 +52,7 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider val IGNORE_NULL = "kudu.ignoreNull" val IGNORE_DUPLICATE_ROW_ERRORS = "kudu.ignoreDuplicateRowErrors" val SCAN_REQUEST_TIMEOUT_MS = "kudu.scanRequestTimeoutMs" + val SOCKET_READ_TIMEOUT_MS = "kudu.socketReadTimeoutMs" def defaultMasterAddrs: String = InetAddress.getLocalHost.getCanonicalHostName @@ -59,6 +60,10 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider parameters.get(SCAN_REQUEST_TIMEOUT_MS).map(_.toLong) } + def getSocketReadTimeoutMs(parameters: Map[String, String]): Option[Long] = { + parameters.get(SOCKET_READ_TIMEOUT_MS).map(_.toLong) + } + /** * Construct a BaseRelation using the provided context and parameters. 
* @@ -82,8 +87,8 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider val writeOptions = new KuduWriteOptions(ignoreDuplicateRowErrors, ignoreNull) new KuduRelation(tableName, kuduMaster, faultTolerantScanner, - scanLocality, getScanRequestTimeoutMs(parameters), operationType, None, - writeOptions)(sqlContext) + scanLocality, getScanRequestTimeoutMs(parameters), getSocketReadTimeoutMs(parameters), + operationType, None, writeOptions)(sqlContext) } /** @@ -119,7 +124,8 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider val scanLocality = getScanLocalityType(parameters.getOrElse(SCAN_LOCALITY, "closest_replica")) new KuduRelation(tableName, kuduMaster, faultTolerantScanner, - scanLocality, getScanRequestTimeoutMs(parameters), operationType, Some(schema))(sqlContext) + scanLocality, getScanRequestTimeoutMs(parameters), getSocketReadTimeoutMs(parameters), + operationType, Some(schema))(sqlContext) } private def getOperationType(opParam: String): OperationType = { @@ -163,6 +169,7 @@ class KuduRelation(private val tableName: String, private val faultTolerantScanner: Boolean, private val scanLocality: ReplicaSelection, private[kudu] val scanRequestTimeoutMs: Option[Long], + private[kudu] val socketReadTimeoutMs: Option[Long], private val operationType: OperationType, private val userSchema: Option[StructType], private val writeOptions: KuduWriteOptions = new KuduWriteOptions)( @@ -171,13 +178,13 @@ class KuduRelation(private val tableName: String, with PrunedFilteredScan with InsertableRelation { - import KuduRelation._ + private val context: KuduContext = new KuduContext(masterAddrs, sqlContext.sparkContext, + socketReadTimeoutMs) - private val context: KuduContext = new KuduContext(masterAddrs, sqlContext.sparkContext) private val table: KuduTable = context.syncClient.openTable(tableName) override def unhandledFilters(filters: Array[Filter]): Array[Filter] = - filters.filterNot(supportsFilter) + 
filters.filterNot(KuduRelation.supportsFilter) /** * Generates a SparkSQL schema object so SparkSQL knows what is being @@ -200,7 +207,7 @@ class KuduRelation(private val tableName: String, val predicates = filters.flatMap(filterToPredicate) new KuduRDD(context, 1024 * 1024 * 20, requiredColumns, predicates, table, faultTolerantScanner, scanLocality, scanRequestTimeoutMs, - sqlContext.sparkContext) + socketReadTimeoutMs, sqlContext.sparkContext) } /** http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala index a27e395..5f7d1f3 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala @@ -18,24 +18,22 @@ package org.apache.kudu.spark.kudu import java.security.{AccessController, PrivilegedAction} + import javax.security.auth.Subject import javax.security.auth.login.{AppConfigurationEntry, Configuration, LoginContext} import scala.collection.JavaConverters._ import scala.collection.mutable - import org.apache.hadoop.util.ShutdownHookManager import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.AccumulatorV2 -import org.apache.yetus.audience.InterfaceStability +import org.apache.yetus.audience.{InterfaceAudience, InterfaceStability} import org.slf4j.{Logger, LoggerFactory} - import org.apache.kudu.client.SessionConfiguration.FlushMode import org.apache.kudu.client._ -import org.apache.kudu.spark.kudu import org.apache.kudu.spark.kudu.SparkUtil._ import org.apache.kudu.{Schema, 
Type} @@ -48,7 +46,10 @@ import org.apache.kudu.{Schema, Type} */ @InterfaceStability.Unstable class KuduContext(val kuduMaster: String, - sc: SparkContext) extends Serializable { + sc: SparkContext, + val socketReadTimeoutMs: Option[Long]) extends Serializable { + + def this(kuduMaster: String, sc: SparkContext) = this(kuduMaster, sc, None) /** * TimestampAccumulator accumulates the maximum value of client's @@ -88,8 +89,6 @@ class KuduContext(val kuduMaster: String, val timestampAccumulator = new TimestampAccumulator() sc.register(timestampAccumulator) - import kudu.KuduContext._ - @Deprecated() def this(kuduMaster: String) { this(kuduMaster, new SparkContext()) @@ -98,7 +97,7 @@ class KuduContext(val kuduMaster: String, @transient lazy val syncClient: KuduClient = asyncClient.syncClient() @transient lazy val asyncClient: AsyncKuduClient = { - val c = KuduConnection.getAsyncClient(kuduMaster) + val c = KuduClientCache.getAsyncClient(kuduMaster, socketReadTimeoutMs) if (authnCredentials != null) { c.importAuthenticationCredentials(authnCredentials) } @@ -107,7 +106,7 @@ class KuduContext(val kuduMaster: String, // Visible for testing. 
private[kudu] val authnCredentials : Array[Byte] = { - Subject.doAs(getSubject(sc), new PrivilegedAction[Array[Byte]] { + Subject.doAs(KuduContext.getSubject(sc), new PrivilegedAction[Array[Byte]] { override def run(): Array[Byte] = syncClient.exportAuthenticationCredentials() }) } @@ -128,7 +127,7 @@ class KuduContext(val kuduMaster: String, // TODO: localityScan, etc) to KuduRDD new KuduRDD(this, 1024*1024*20, columnProjection.toArray, Array(), syncClient.openTable(tableName), false, ReplicaSelection.LEADER_ONLY, - None, sc) + None, None, sc) } /** @@ -391,8 +390,8 @@ private object KuduContext { } } -private object KuduConnection { - private[kudu] val asyncCache = new mutable.HashMap[String, AsyncKuduClient]() +private object KuduClientCache { + private case class CacheKey(kuduMaster: String, socketReadTimeoutMs: Option[Long]) /** * Set to @@ -403,17 +402,29 @@ private object KuduConnection { */ private val ShutdownHookPriority = 100 - def getAsyncClient(kuduMaster: String): AsyncKuduClient = { - asyncCache.synchronized { - if (!asyncCache.contains(kuduMaster)) { - val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build() + private val clientCache = new mutable.HashMap[CacheKey, AsyncKuduClient]() + + // Visible for testing. 
+ private[kudu] def clearCacheForTests() = clientCache.clear() + + def getAsyncClient(kuduMaster: String, socketReadTimeoutMs: Option[Long]): AsyncKuduClient = { + val cacheKey = CacheKey(kuduMaster, socketReadTimeoutMs) + clientCache.synchronized { + if (!clientCache.contains(cacheKey)) { + val builder = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster) + socketReadTimeoutMs match { + case Some(timeout) => builder.defaultSocketReadTimeoutMs(timeout) + case None => + } + + val asyncClient = builder.build() ShutdownHookManager.get().addShutdownHook( new Runnable { override def run(): Unit = asyncClient.close() }, ShutdownHookPriority) - asyncCache.put(kuduMaster, asyncClient) + clientCache.put(cacheKey, asyncClient) } - return asyncCache(kuduMaster) + return clientCache(cacheKey) } } } http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala index 7117983..4817da6 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala @@ -28,7 +28,7 @@ import org.apache.kudu.{Type, client} /** * A Resilient Distributed Dataset backed by a Kudu table. * - * To construct a KuduRDD, use {@link KuduContext#kuduRdd} or a Kudu DataSource. + * To construct a KuduRDD, use [[KuduContext#kuduRDD]] or a Kudu DataSource. 
*/ class KuduRDD private[kudu] (val kuduContext: KuduContext, @transient val batchSize: Integer, @@ -38,6 +38,7 @@ class KuduRDD private[kudu] (val kuduContext: KuduContext, @transient val isFaultTolerant: Boolean, @transient val scanLocality: ReplicaSelection, @transient val scanRequestTimeoutMs: Option[Long], + @transient val socketReadTimeoutMs: Option[Long], @transient val sc: SparkContext) extends RDD[Row](sc, Nil) { override protected def getPartitions: Array[Partition] = { http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala index 701e398..f01a211 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala @@ -734,4 +734,19 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfterEac val kuduRelation = kuduRelationFromDataFrame(dataFrame) assert(kuduRelation.scanRequestTimeoutMs == Some(1)) } + + /** + * Verify that the kudu.socketReadTimeoutMs parameter is parsed by the + * DefaultSource and makes it into the KuduRelation as a configuration + * parameter. 
+ */ + test("socket read timeout propagation") { + kuduOptions = Map( + "kudu.table" -> tableName, + "kudu.master" -> miniCluster.getMasterAddresses, + "kudu.socketReadTimeoutMs" -> "1") + val dataFrame = sqlContext.read.options(kuduOptions).kudu + val kuduRelation = kuduRelationFromDataFrame(dataFrame) + assert(kuduRelation.socketReadTimeoutMs == Some(1)) + } } http://git-wip-us.apache.org/repos/asf/kudu/blob/eee82d90/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala index 4915002..47d4519 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala @@ -52,7 +52,7 @@ class KuduContextTest extends FunSuite with TestContext with Matchers { test("Test KuduContext serialization") { val serialized = serialize(kuduContext) - KuduConnection.asyncCache.clear() + KuduClientCache.clearCacheForTests() val deserialized = deserialize(serialized).asInstanceOf[KuduContext] assert(deserialized.authnCredentials != null) // Make a nonsense call just to make sure the re-hydrated client works.
