Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/19433#discussion_r144790384 --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala --- @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.Split + +/** + * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training. + */ +private[impl] object AggUpdateUtils { + + /** + * Updates the parent node stats of the passed-in impurity aggregator with the labels + * corresponding to the feature values at indices [from, to). + */ + private[impl] def updateParentImpurity( + statsAggregator: DTStatsAggregator, + col: FeatureVector, + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): Unit = { + from.until(to).foreach { idx => + val rowIndex = col.indices(idx) + val label = labels(rowIndex) + statsAggregator.updateParent(label, instanceWeights(rowIndex)) + } + } + + /** + * Update aggregator for an (unordered feature, label) pair + * @param splits Array of arrays of splits for each feature; splits(i) = splits for feature i. 
+ */ + private[impl] def updateUnorderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndex: Int, + featureIndexIdx: Int, + splits: Array[Array[Split]], --- End diff -- You only need to pass in `featureSplit: Array[Split]` (the splits for the current feature); don't pass the splits for all features.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org