add partition options for source df

Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/488a56d7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/488a56d7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/488a56d7

Branch: refs/heads/master
Commit: 488a56d7685ce916e0e37fb65b77c43d5e4583fc
Parents: d93e5de
Author: Chul Kang <[email protected]>
Authored: Fri Jun 8 00:13:37 2018 +0900
Committer: Chul Kang <[email protected]>
Committed: Fri Jun 8 00:13:37 2018 +0900

----------------------------------------------------------------------
 .../org/apache/s2graph/s2jobs/task/Source.scala    | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/488a56d7/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
----------------------------------------------------------------------
diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
index 9c80e58..259cfc0 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/Source.scala
@@ -40,19 +40,30 @@ class KafkaSource(conf:TaskConf) extends Source(conf) {
   val DEFAULT_FORMAT = "raw"
   override def mandatoryOptions: Set[String] = Set("kafka.bootstrap.servers", "subscribe")
 
+  def repartition(df: DataFrame, defaultParallelism: Int) = {
+    conf.options.get("numPartitions").map(n => Integer.parseInt(n)) match {
+      case Some(numOfPartitions: Int) =>
+        logger.info(s"[repartition] $numOfPartitions ($defaultParallelism)")
+        if (numOfPartitions >= defaultParallelism) df.repartition(numOfPartitions)
+        else df.coalesce(numOfPartitions)
+      case None => df
+    }
+  }
+
   override def toDF(ss:SparkSession):DataFrame = {
     logger.info(s"${LOG_PREFIX} options: ${conf.options}")
 
     val format = conf.options.getOrElse("format", "raw")
     val df = ss.readStream.format("kafka").options(conf.options).load()
 
+    val partitionedDF = repartition(df, df.sparkSession.sparkContext.defaultParallelism)
     format match {
-      case "raw" => df
-      case "json" => parseJsonSchema(ss, df)
+      case "raw" => partitionedDF
+      case "json" => parseJsonSchema(ss, partitionedDF)
 //      case "custom" => parseCustomSchema(df)
       case _ =>
         logger.warn(s"${LOG_PREFIX} unsupported format '$format'.. use default schema ")
-        df
+        partitionedDF
     }
   }
 

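For reference, the decision rule introduced above can be exercised outside the
task plumbing. The sketch below is illustrative only: the option key
"numPartitions" matches the commit, but the object name (PartitionHint) and the
standalone batch usage are assumptions, not part of the change.

----------------------------------------------------------------------
import org.apache.spark.sql.DataFrame

object PartitionHint {
  // Mirrors the commit's rule: growing the partition count needs a full
  // shuffle (repartition), while shrinking can avoid one (coalesce).
  def apply(df: DataFrame, options: Map[String, String]): DataFrame =
    options.get("numPartitions").map(_.toInt) match {
      case Some(n) =>
        val parallelism = df.sparkSession.sparkContext.defaultParallelism
        if (n >= parallelism) df.repartition(n) else df.coalesce(n)
      case None => df  // no hint given; leave the DataFrame untouched
    }
}

// Hypothetical batch usage (the commit applies this to a streaming source):
// val ss = org.apache.spark.sql.SparkSession.builder()
//   .master("local[4]").getOrCreate()
// val tuned = PartitionHint(ss.range(1000).toDF("id"),
//   Map("numPartitions" -> "2"))
// tuned.rdd.getNumPartitions  // 2
----------------------------------------------------------------------

Note that repartition(n) always shuffles, whereas coalesce(n) narrows existing
partitions without a shuffle, which is why the helper picks coalesce when the
requested count is below the current default parallelism.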