Repository: incubator-predictionio Updated Branches: refs/heads/release/0.11.0 070f1794d -> 9deca1a47
Properly combine spark-submit arguments Project: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/commit/9deca1a4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/tree/9deca1a4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/diff/9deca1a4 Branch: refs/heads/release/0.11.0 Commit: 9deca1a47fccb6c58fba366d6dd1c8bf751b1cf9 Parents: 070f179 Author: Donald Szeto <[email protected]> Authored: Sat Apr 8 22:48:00 2017 -0700 Committer: Donald Szeto <[email protected]> Committed: Sat Apr 8 22:48:00 2017 -0700 ---------------------------------------------------------------------- .../org/apache/predictionio/tools/Runner.scala | 101 ++++++++++++++++++- .../apache/predictionio/tools/RunnerSpec.scala | 73 ++++++++++++++ 2 files changed, 170 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/9deca1a4/tools/src/main/scala/org/apache/predictionio/tools/Runner.scala ---------------------------------------------------------------------- diff --git a/tools/src/main/scala/org/apache/predictionio/tools/Runner.scala b/tools/src/main/scala/org/apache/predictionio/tools/Runner.scala index 70e3837..4a721be 100644 --- a/tools/src/main/scala/org/apache/predictionio/tools/Runner.scala +++ b/tools/src/main/scala/org/apache/predictionio/tools/Runner.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.predictionio.tools.ReturnTypes._ import org.apache.predictionio.workflow.WorkflowUtils +import scala.collection.mutable import scala.sys.process._ case class SparkArgs( @@ -95,6 +96,92 @@ object Runner extends EitherLogging { } } + /** Group argument values by argument names + * + * This only works with long argument names immediately followed by a value + * + * Input: + * Seq("--foo", "bar", "--flag", "--dead", "beef baz", 
"n00b", "--foo", "jeez") + * + * Output: + * Map("--foo" -> Seq("bar", "jeez"), "--dead" -> Seq("beef baz")) + * + * @param arguments Sequence of argument names and values + * @return A map with argument values keyed by the same argument name + */ + def groupByArgumentName(arguments: Seq[String]): Map[String, Seq[String]] = { + val argumentMap = mutable.HashMap.empty[String, Seq[String]] + arguments.foldLeft("") { (prev, current) => + if (prev.startsWith("--") && !current.startsWith("--")) { + if (argumentMap.contains(prev)) { + argumentMap(prev) = argumentMap(prev) :+ current + } else { + argumentMap(prev) = Seq(current) + } + } + current + } + argumentMap.toMap + } + + /** Remove argument names and values + * + * This only works with long argument names immediately followed by a value + * + * Input: + * Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez") + * Set("--flag", "--foo") + * + * Output: + * Seq("--flag", "--dead", "beef baz", "n00b") + * + * @param arguments Sequence of argument names and values + * @param remove Name of argument and associated values to remove + * @return Sequence of argument names and values with targets removed + */ + def removeArguments(arguments: Seq[String], remove: Set[String]): Seq[String] = { + if (remove.isEmpty) { + arguments + } else { + arguments.foldLeft(Seq.empty[String]) { (ongoing, current) => + if (ongoing.isEmpty) { + Seq(current) + } else { + if (remove.contains(ongoing.last) && !current.startsWith("--")) { + ongoing.take(ongoing.length - 1) + } else { + ongoing :+ current + } + } + } + } + } + + /** Combine repeated arguments together + * + * Input: + * Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez") + * Map("--foo" -> ((a, b) => s"$a $b")) + * + * Output: + * Seq("--flag", "--dead", "beef baz", "n00b", "--foo", "bar jeez") + * + * @param arguments Sequence of argument names and values + * @param combinators Map of argument name to combinator function + * @return Sequence of 
argument names and values with specific argument values combined + */ + def combineArguments( + arguments: Seq[String], + combinators: Map[String, (String, String) => String]): Seq[String] = { + val argumentsToCombine: Map[String, Seq[String]] = + groupByArgumentName(arguments).filterKeys(combinators.keySet.contains(_)) + val argumentsMinusToCombine = removeArguments(arguments, combinators.keySet) + val combinedArguments = argumentsToCombine flatMap { kv => + Seq(kv._1, kv._2.reduce(combinators(kv._1))) + } + argumentsMinusToCombine ++ combinedArguments + } + def runOnSpark( className: String, classArgs: Seq[String], @@ -189,17 +276,23 @@ object Runner extends EitherLogging { } val verboseArg = if (verbose) Seq("--verbose") else Nil - val pioLogDir = Option(System.getProperty("pio.log.dir")).getOrElse(s"${pioHome}/log") + val pioLogDir = Option(System.getProperty("pio.log.dir")).getOrElse(s"$pioHome/log") - val sparkSubmit = Seq( - sparkSubmitCommand, + val sparkSubmitArgs = Seq( sa.sparkPassThrough, Seq("--class", className), sparkSubmitJars, sparkSubmitFiles, sparkSubmitExtraClasspaths, sparkSubmitKryo, - Seq("--driver-java-options", s"-Dpio.log.dir=${pioLogDir}"), + Seq("--driver-java-options", s"-Dpio.log.dir=$pioLogDir")).flatten + + val whitespaceCombinator = (a: String, b: String) => s"$a $b" + val combinators = Map("--driver-java-options" -> whitespaceCombinator) + + val sparkSubmit = Seq( + sparkSubmitCommand, + combineArguments(sparkSubmitArgs, combinators), Seq(mainJar), detectFilePaths(fs, sa.scratchUri, classArgs), Seq("--env", pioEnvVars), http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/9deca1a4/tools/src/test/scala/org/apache/predictionio/tools/RunnerSpec.scala ---------------------------------------------------------------------- diff --git a/tools/src/test/scala/org/apache/predictionio/tools/RunnerSpec.scala b/tools/src/test/scala/org/apache/predictionio/tools/RunnerSpec.scala new file mode 100644 index 0000000..92317f8 --- 
/dev/null +++ b/tools/src/test/scala/org/apache/predictionio/tools/RunnerSpec.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.tools + +import org.specs2.mutable.Specification + +class RunnerSpec extends Specification { + "groupByArgumentName" >> { + "test1" >> { + val test = Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez") + Runner.groupByArgumentName(test) must havePairs( + "--foo" -> Seq("bar", "jeez"), + "--dead" -> Seq("beef baz")) + } + + "test2" >> { + val test = + Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--foo", "jeez", "--flag") + Runner.groupByArgumentName(test) must havePairs( + "--foo" -> Seq("jeez"), + "--bar" -> Seq("flag"), + "--dead" -> Seq("beef baz")) + } + } + + "removeArguments" >> { + "test1" >> { + val test = Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez") + val remove = Set("--flag", "--foo") + Runner.removeArguments(test, remove) === Seq("--flag", "--dead", "beef baz", "n00b") + } + + "test2" >> { + val test = + Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--foo", "jeez", "--flag") + val remove = Set("--flag", "--foo") + 
Runner.removeArguments(test, remove) === + Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--flag") + } + } + + "combineArguments" >> { + "test1" >> { + val test = Seq("--foo", "bar", "--flag", "--dead", "beef baz", "n00b", "--foo", "jeez") + val combinators = Map("--foo" -> ((a: String, b: String) => s"$a $b")) + Runner.combineArguments(test, combinators) === + Seq("--flag", "--dead", "beef baz", "n00b", "--foo", "bar jeez") + } + + "test2" >> { + val test = + Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--foo", "jeez", "--flag") + val combinators = Map("--foo" -> ((a: String, b: String) => s"$a $b")) + Runner.combineArguments(test, combinators) === + Seq("--foo", "--bar", "flag", "--dead", "beef baz", "n00b", "--flag", "--foo", "jeez") + } + } +}
