EBernhardson has uploaded a new change for review. ( https://gerrit.wikimedia.org/r/394741 )
Change subject: [WIP] Bad ideas for improved DBN performance ...................................................................... [WIP] Bad ideas for improved DBN performance I'm not sure this is a particularly great idea, but I wanted to explore the performance limits of the JVM based DBN implementation. This brings the original benchmark (90s in java, 3-4s in prior patch) to ~900ms. To get a better idea on performance I increased the size of the benchmark: * python: 616s - only ran once * orig jvm: min: 21.7, max: 24.1 mean: 23.5s - 5 runs - 25x- 28x faster than python * optimized jvm: min: 5.0s max: 5.3s mean: 5.2s - 5 runs - 116x - 123x faster than python - 4x - 5x faster than orig jvm The improvements made were guided by profiling in visualvm and aren't all that numerous: * We were thrashing memory pretty hard at >1GB/sec. To reduce this add caches of our intermediate arrays. We are still thrashing memory pretty hard but not as bad. * The caches of the intermediate arrays in scala Maps brought those maps up high in the profiler. Replace with arrays of queues. The backing linked list still shows up in profiling, but not as bad. * DefaultMap.apply gets hit *a lot* and was showing up in profiling. Replacing inner scala maps with java maps helped some. Further replacing java maps with trove4j primitive maps helped significantly. * Find places where we were repeatedly hitting an array for the same item (for example getting something by s.queryId in a loop on the urls) and fetch it into a local var. Not sure this made much difference. visualvm now reports 80% of cpu time is spent in our own functions, whereas before it was significantly lower. Mostly I just kept looking for places where the supporting machinery was taking up cpu instead of our calculations and kept replacing them until it was better. 
Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e --- M jvm/pom.xml M jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala M jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala 3 files changed, 246 insertions(+), 58 deletions(-) git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR refs/changes/41/394741/1 diff --git a/jvm/pom.xml b/jvm/pom.xml index b2a7f71..f405975 100644 --- a/jvm/pom.xml +++ b/jvm/pom.xml @@ -141,6 +141,11 @@ <version>3.0.1</version> <scope>test</scope> </dependency> + <dependency> + <groupId>net.sf.trove4j</groupId> + <artifactId>trove4j</artifactId> + <version>3.0.3</version> + </dependency> </dependencies> <repositories> <repository> diff --git a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala index 12ef975..cda6778 100644 --- a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala +++ b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala @@ -9,6 +9,9 @@ * A Dynamic Bayesian Network Click Model for Web Search Ranking - Olivier Chapelle and * Ya Zang - http://olivier.chapelle.cc/pub/DBN_www2009.pdf */ +import gnu.trove.iterator.TIntObjectIterator +import gnu.trove.map.hash.TIntObjectHashMap + import scala.collection.mutable import scala.util.parsing.json.JSON @@ -19,15 +22,22 @@ // This bit maps input queryies/results to array indexes to be used while calculating private var currentUrlId: Int = 0 // TODO: Why is first returned value 1 instead of 0? 
+ private val urlToIdMap: mutable.Map[String, Int] = mutable.Map() + def urlToId(key: String): Int = { + urlToIdMap.getOrElseUpdate(key, { + currentUrlId += 1 + currentUrlId + }) + } + private var currentQueryId: Int = -1 - private val urlToId: DefaultMap[String, Int] = new DefaultMap({ _ => - currentUrlId += 1 - currentUrlId - }) - private val queryToId: DefaultMap[(String, String), Int] = new DefaultMap({ _ => - currentQueryId += 1 - currentQueryId - }) + private val queryToIdMap: mutable.Map[(String, String), Int] = mutable.Map() + def queryToId(key: (String, String)): Int = { + queryToIdMap.getOrElseUpdate(key, { + currentQueryId += 1 + currentQueryId + }) + } def maxQueryId: Int = currentQueryId + 2 @@ -91,8 +101,8 @@ } def toRelevances(urlRelevances: Array[Map[Int, UrlRel]]): Seq[RelevanceResult] = { - val idToUrl = urlToId.asMap.map(_.swap) - val idToQuery = queryToId.asMap.map(_.swap) + val idToUrl = urlToIdMap.map(_.swap) + val idToQuery = queryToIdMap.map(_.swap) urlRelevances.zipWithIndex.flatMap { case (d, queryId) => val (query, region) = idToQuery(queryId) @@ -101,6 +111,127 @@ RelevanceResult(query, region, url, urlRel.a * urlRel.s) } } + } +} + +class ArrayCache { + val QUEUE_1D_MAX = 20 + private val queueMap1d: Array[mutable.Queue[Array[Double]]] = Array.fill(QUEUE_1D_MAX + 1){ mutable.Queue() } + + def get1d(n: Int): Array[Double] = { + if (n > QUEUE_1D_MAX) { + new Array[Double](n) + } else { + val queue = queueMap1d(n) + if (queue.isEmpty) { + new Array[Double](n) + } else { + queue.dequeue() + } + } + } + + def get1d(n: Int, default: Double): Array[Double] = { + if (n > QUEUE_1D_MAX) { + Array.fill(n)(default) + } else { + val queue = queueMap1d(n) + if (queue.isEmpty) { + Array.fill(n)(default) + } else { + val arr = queue.dequeue() + var i = 0 + while (i < n) { + arr(i) = default + i += 1 + } + arr + } + } + } + + def put1d(arr: Array[Double]): Unit = { + if (arr.length <= QUEUE_1D_MAX) { + queueMap1d(arr.length) += arr + } + } + + private 
val ALPHA_BETA_MAX = 21 + private val alphaBetaMap: Array[mutable.Queue[Array[Array[Double]]]] = Array.fill(ALPHA_BETA_MAX + 1){ + mutable.Queue() + } + + def getAlphaBeta(n: Int): Array[Array[Double]] = { + if (n > ALPHA_BETA_MAX) { + Array.ofDim(n, 2) + } else { + val queue = alphaBetaMap(n) + if (queue.isEmpty) { + println(s"Allocate alpha/beta (n=$n)") + Array.ofDim(n, 2) + } else { + queue.dequeue() + } + } + } + + def putAlphaBeta(arr: Array[Array[Double]]): Unit = { + if (arr.length <= ALPHA_BETA_MAX) { + alphaBetaMap(arr.length) += arr + } + } + + private val POSITION_REL_MAX = 20 + private val positionRelMap: Array[mutable.Queue[PositionRel]] = Array.fill(POSITION_REL_MAX + 1){ mutable.Queue() } + + def getPositionRel(n: Int): PositionRel = { + if (n > POSITION_REL_MAX) { + new PositionRel(new Array[Double](n), new Array[Double](n)) + } else { + val queue = positionRelMap(n) + if (queue.isEmpty) { + println(s"Allocate position rel (n=$n)") + new PositionRel(new Array[Double](n), new Array[Double](n)) + } else { + queue.dequeue() + } + } + } + + def putPositionRel(rel: PositionRel): Unit = { + if (rel.a.length <= POSITION_REL_MAX) { + positionRelMap(rel.a.length) += rel + } + } + + private val UPDATE_MATRIX_MAX = 20 + private val updateMatrixMap: Array[mutable.Queue[Array[Array[Array[Double]]]]] = Array.fill(UPDATE_MATRIX_MAX + 1){ mutable.Queue() } + + def getUpateMatrix(n: Int): Array[Array[Array[Double]]] = { + if (n > UPDATE_MATRIX_MAX) { + Array.ofDim(n, 2, 2) + } else { + val queue = updateMatrixMap(n) + if (queue.isEmpty) { + println(s"Allocate update matrix (n=$n)") + Array.ofDim(n, 2, 2) + } else { + queue.dequeue() + } + } + } + + def putUpdateMatrix(arr: Array[Array[Array[Double]]]): Unit = { + if (arr.length <= UPDATE_MATRIX_MAX) { + updateMatrixMap(arr.length) += arr + } + } + + def clear(): Unit = { + queueMap1d.foreach(_.clear) + alphaBetaMap.foreach(_.clear) + positionRelMap.foreach(_.clear) + updateMatrixMap.foreach(_.clear) } } @@ -155,26 
+286,36 @@ // it so requesting an item not in the map gets set to a default // value and then returned. This differs from withDefault which // expects to return an immutable value so doesn't set it into the map. -class DefaultMap[K, V](default: K => V) extends Iterable[(K, V)] { - private val map = mutable.Map[K,V]() +class DefaultMap[V](default: => V) { + private val map = new TIntObjectHashMap[V]() - def apply(key: K): V = { - map.get(key) match { - case Some(value) => value - case None => - val value = default(key) - map.update(key, value) - value + def apply(key: Int): V = { + if (map.containsKey(key)) { + map.get(key) + } else { + val value: V = default + map.put(key, value) + value } } - override def iterator: Iterator[(K, V)] = map.iterator - // converts to immutable scala Map - def asMap: Map[K, V] = map.toMap + def asMap: Map[Int, V] = { + val x = mutable.Map[Int, V]() + val iter = map.iterator() + while (iter.hasNext) { + iter.advance() + x.put(iter.key(), iter.value()) + } + x.toMap + } + + def iterator(): TIntObjectIterator[V] = map.iterator() } object DbnModel { + val doubleArrCache = new ArrayCache() + /** * The forward-backward algorithm is used to to compute the posterior probabilities of the hidden variables. * @@ -241,26 +382,30 @@ def getForwardBackwardEstimates(rel: PositionRel, gamma: Double, clicks: Array[Boolean]): (Array[Array[Double]], Array[Array[Double]]) = { val N = clicks.length // alpha(k)(e) = P(C_1,...C_k-1,E_i=e|a_u, s_u, G) calculated forwards for C_1, then C_1,C_2, ... - val alpha = Array.ofDim[Double](N + 1, 2) + val alpha = doubleArrCache.getAlphaBeta(N + 1) // beta(k)(e) = P(C_k,...C_N|E_i=e, a_u, s_u, G) calculated backwards for C_10, then C_9, C_10, ... - val beta = Array.ofDim[Double](N + 1, 2) + val beta = doubleArrCache.getAlphaBeta(N + 1) + alpha(0)(0) = 0D alpha(0)(1) = 1D beta(N)(0) = 1D beta(N)(1) = 1D // Forwards (alpha) and backwards (beta) need the same probabilities as inputs so pre-calculate them. 
var k = 0 - val updateMatrix: Array[Array[Array[Double]]] = Array.ofDim[Double](clicks.length, 2, 2) + val updateMatrix = doubleArrCache.getUpateMatrix(clicks.length) while (k < N) { val a_u = rel.a(k) val s_u = rel.s(k) if (clicks(k)) { + updateMatrix(k)(0)(0) = 0D updateMatrix(k)(0)(1) = (s_u + (1 - gamma) * (1 - s_u)) * a_u + updateMatrix(k)(1)(0) = 0D updateMatrix(k)(1)(1) = gamma * (1 - s_u) * a_u } else { updateMatrix(k)(0)(0) = 1D updateMatrix(k)(0)(1) = (1D - gamma) * (1D - a_u) + updateMatrix(k)(1)(0) = 0D updateMatrix(k)(1)(1) = gamma * (1D - a_u) } k += 1 @@ -286,6 +431,8 @@ k += 1 } + doubleArrCache.putUpdateMatrix(updateMatrix) + (alpha, beta) } @@ -301,7 +448,7 @@ // varphi is the smoothing of the forwards and backwards. I think, based on wiki page on forward/backwards // algorithm, that varphi is then P(E_k|C_1,...,C_N,a_u,s_u,G) but not 100% sure... var k = 0 - val varphi: Array[Double] = new Array(alpha.length) + val varphi: Array[Double] = doubleArrCache.get1d(alpha.length) while (k < alpha.length) { val a = alpha(k) val b = beta(k) @@ -311,7 +458,10 @@ k += 1 } - val sessionEstimate = new PositionRel(new Array[Double](N), new Array[Double](N)) + doubleArrCache.putAlphaBeta(alpha) + doubleArrCache.putAlphaBeta(beta) + + val sessionEstimate = doubleArrCache.getPositionRel(N) k = 0 while (k < N) { val a_u = rel.a(k) @@ -331,6 +481,9 @@ } k += 1 } + + doubleArrCache.put1d(varphi) + sessionEstimate } } @@ -342,59 +495,89 @@ // dimension and urlId in the second dimension. Because queries only reference // a subset of the known urls we use a map at the second level instead of // creating the entire matrix. 
- val urlRelevances: Array[DefaultMap[Int, UrlRel]] = Array.fill(config.maxQueryId) { - new DefaultMap[Int, UrlRel]({ - _ => new UrlRel(config.defaultRel, config.defaultRel) + val urlRelevances: Array[DefaultMap[UrlRel]] = Array.fill(config.maxQueryId) { + new DefaultMap[UrlRel]({ + new UrlRel(config.defaultRel, config.defaultRel) }) } for (_ <- 0 until config.maxIterations) { - for ((d, queryId) <- eStep(urlRelevances, sessions).view.zipWithIndex) { + val urlRelFractions = eStep(urlRelevances, sessions) + var queryId = 0 + while (queryId < urlRelFractions.length) { + val d = urlRelFractions(queryId) // M step - for ((urlId, relFractions) <- d) { - val rel = urlRelevances(queryId)(urlId) + val queryUrlRelevances = urlRelevances(queryId) + val iter = d.iterator() + while (iter.hasNext) { + iter.advance() + val urlId = iter.key() + val relFractions = iter.value() + val rel = queryUrlRelevances(urlId) // Convert our sums of per-session a_u and s_u into probabilities (domain of [0,1]) // attracted / (attracted + not-attracted) rel.a = relFractions.a(1) / (relFractions.a(1) + relFractions.a(0)) // satisfied / (satisfied + not-satisfied) rel.s = relFractions.s(1) / (relFractions.s(1) + relFractions.s(0)) + + DbnModel.doubleArrCache.put1d(relFractions.a) + DbnModel.doubleArrCache.put1d(relFractions.s) } + queryId += 1 } } + DbnModel.doubleArrCache.clear() urlRelevances.map(_.asMap) } // E step - private def eStep(urlRelevances: Array[DefaultMap[Int, UrlRel]], sessions: Seq[SessionItem]) - : Array[DefaultMap[Int, UrlRelFrac]] = { + private def eStep(urlRelevances: Array[DefaultMap[UrlRel]], sessions: Seq[SessionItem]) + : Array[DefaultMap[UrlRelFrac]] = { // urlRelFraction(queryId)(urlId) - val urlRelFractions: Array[DefaultMap[Int, UrlRelFrac]] = Array.fill(config.maxQueryId) { - new DefaultMap[Int, UrlRelFrac]({ - _ => new UrlRelFrac(Array.fill(2)(1D), Array.fill(2)(1D)) + val urlRelFractions: Array[DefaultMap[UrlRelFrac]] = Array.fill(config.maxQueryId) { + new 
DefaultMap[UrlRelFrac]({ + new UrlRelFrac( + DbnModel.doubleArrCache.get1d(2, 1D), + DbnModel.doubleArrCache.get1d(2, 1D) + ) }) } - for (s <- sessions) { - val positionRelevances = new PositionRel( - s.urlIds.map(urlRelevances(s.queryId)(_).a), - s.urlIds.map(urlRelevances(s.queryId)(_).s) - ) + var sidx = 0 + while (sidx < sessions.length) { + val s = sessions(sidx) + val positionRelevances = DbnModel.doubleArrCache.getPositionRel(s.urlIds.length) + var i = 0 + val urlRelQuery = urlRelevances(s.queryId) + while (i < s.urlIds.length) { + val urlRel = urlRelQuery(s.urlIds(i)) + positionRelevances.a(i) = urlRel.a + positionRelevances.s(i) = urlRel.s + i += 1 + } val sessionEstimate = DbnModel.getSessionEstimate(positionRelevances, gamma, s.clicks) - for ((urlId, k) <- s.urlIds.view.zipWithIndex) { + DbnModel.doubleArrCache.putPositionRel(positionRelevances) + val queryUrlRelFrac = urlRelFractions(s.queryId) + var k = 0 + while (k < s.urlIds.length) { + var urlId = s.urlIds(k) // update attraction - val rel = urlRelFractions(s.queryId)(urlId) + val rel = queryUrlRelFrac(urlId) val estA = sessionEstimate.a(k) - rel.a(0) += (1 - estA) + rel.a(0) += 1 - estA rel.a(1) += estA if (s.clicks(k)) { // update satisfaction val estS = sessionEstimate.s(k) - rel.s(0) += (1 - estS) + rel.s(0) += 1 - estS rel.s(1) += estS } + k += 1 } + DbnModel.doubleArrCache.putPositionRel(sessionEstimate) + sidx += 1 } urlRelFractions } diff --git a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala index 9d883a2..efaf8d8 100644 --- a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala +++ b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala @@ -129,17 +129,17 @@ println(s"Took ${took/1000000}ms") // Create a datafile that python clickmodels can read in to have fair comparison - //import java.io.File - //import java.io.PrintWriter + import java.io.File + import java.io.PrintWriter - //val writer 
= new PrintWriter(new File("/tmp/dbn.clickmodels")) - //for ( s <- sessions) { - // // poor mans json serialization - // val layout = Array.fill(s.urlIds.length)("false").mkString("[", ",", "]") - // val clicks = s.clicks.map(_.toString).mkString("[", ",", "]") - // val urls = s.urlIds.map(_.toString).mkString("[\"", "\",\"", "\"]") - // writer.write(s"0\t${s.queryId}\tregion\t0\t$urls\t$layout\t$clicks\n") - //} - //writer.close() + val writer = new PrintWriter(new File("/tmp/dbn.clickmodels")) + for ( s <- sessions) { + // poor mans json serialization + val layout = Array.fill(s.urlIds.length)("false").mkString("[", ",", "]") + val clicks = s.clicks.map(_.toString).mkString("[", ",", "]") + val urls = s.urlIds.map(_.toString).mkString("[\"", "\",\"", "\"]") + writer.write(s"0\t${s.queryId}\tregion\t0\t$urls\t$layout\t$clicks\n") + } + writer.close() } } -- To view, visit https://gerrit.wikimedia.org/r/394741 To unsubscribe, visit https://gerrit.wikimedia.org/r/settings Gerrit-MessageType: newchange Gerrit-Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e Gerrit-PatchSet: 1 Gerrit-Project: search/MjoLniR Gerrit-Branch: master Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org> _______________________________________________ MediaWiki-commits mailing list MediaWiki-commits@lists.wikimedia.org https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits