http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java deleted file mode 100644 index fec0a83..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java +++ /dev/null @@ -1,568 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import com.zaxxer.sparsebits.SparseBitSet; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteLogger; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.CachePeekMode; -import org.apache.ignite.cache.affinity.Affinity; -import org.apache.ignite.cluster.ClusterNode; -import org.apache.ignite.internal.processors.cache.CacheEntryImpl; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.distributed.CacheUtils; -import org.apache.ignite.ml.math.functions.Functions; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trees.ContinuousRegionInfo; -import org.apache.ignite.ml.trees.ContinuousSplitCalculator; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.nodes.DecisionTreeNode; -import org.apache.ignite.ml.trees.nodes.Leaf; -import org.apache.ignite.ml.trees.nodes.SplitNode; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache; -import 
org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache; -import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; -import org.jetbrains.annotations.NotNull; - -import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey; - -/** - * This trainer stores observations as columns and features as rows. - * Ideas from https://github.com/fabuzaid21/yggdrasil are used here. - */ -public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implements - Trainer<DecisionTreeModel, ColumnDecisionTreeTrainerInput> { - /** - * Function used to assign a value to a region. - */ - private final IgniteFunction<DoubleStream, Double> regCalc; - - /** - * Function used to calculate impurity in regions used by categorical features. - */ - private final IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider; - - /** - * Categorical calculator provider. - **/ - private final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider; - - /** - * Cache used for storing data for training. - */ - private IgniteCache<RegionKey, List<RegionProjection>> prjsCache; - - /** - * Minimal information gain. - */ - private static final double MIN_INFO_GAIN = 1E-10; - - /** - * Maximal depth of the decision tree. 
- */ - private final int maxDepth; - - /** - * Size of block which is used for storing regions in cache. - */ - private static final int BLOCK_SIZE = 1 << 4; - - /** Ignite instance. */ - private final Ignite ignite; - - /** Logger */ - private final IgniteLogger log; - - /** - * Construct {@link ColumnDecisionTreeTrainer}. - * - * @param maxDepth Maximal depth of the decision tree. - * @param continuousCalculatorProvider Provider of calculator of splits for region projection on continuous - * features. - * @param categoricalCalculatorProvider Provider of calculator of splits for region projection on categorical - * features. - * @param regCalc Function used to assign a value to a region. - */ - public ColumnDecisionTreeTrainer(int maxDepth, - IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider, - IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider, - IgniteFunction<DoubleStream, Double> regCalc, - Ignite ignite) { - this.maxDepth = maxDepth; - this.continuousCalculatorProvider = continuousCalculatorProvider; - this.categoricalCalculatorProvider = categoricalCalculatorProvider; - this.regCalc = regCalc; - this.ignite = ignite; - this.log = ignite.log(); - } - - /** - * Utility class used to get index of feature by which split is done and split info. - */ - private static class IndexAndSplitInfo { - /** - * Index of feature by which split is done. - */ - private final int featureIdx; - - /** - * Split information. - */ - private final SplitInfo info; - - /** - * @param featureIdx Index of feature by which split is done. - * @param info Split information. 
- */ - IndexAndSplitInfo(int featureIdx, SplitInfo info) { - this.featureIdx = featureIdx; - this.info = info; - } - - /** {@inheritDoc} */ - @Override public String toString() { - return "IndexAndSplitInfo [featureIdx=" + featureIdx + ", info=" + info + ']'; - } - } - - /** - * Utility class used to build decision tree. Basically it is pointer to leaf node. - */ - private static class TreeTip { - /** */ - private Consumer<DecisionTreeNode> leafSetter; - - /** */ - private int depth; - - /** */ - TreeTip(Consumer<DecisionTreeNode> leafSetter, int depth) { - this.leafSetter = leafSetter; - this.depth = depth; - } - } - - /** - * Utility class used as decision tree root node. - */ - private static class RootNode implements DecisionTreeNode { - /** */ - private DecisionTreeNode s; - - /** - * {@inheritDoc} - */ - @Override public double process(Vector v) { - return s.process(v); - } - - /** */ - void setSplit(DecisionTreeNode s) { - this.s = s; - } - } - - /** - * {@inheritDoc} - */ - @Override public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) { - prjsCache = ProjectionsCache.getOrCreate(ignite); - IgniteCache<UUID, TrainingContext<D>> ctxtCache = ContextCache.getOrCreate(ignite); - SplitCache.getOrCreate(ignite); - - UUID trainingUUID = UUID.randomUUID(); - - TrainingContext<D> ct = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite); - ctxtCache.put(trainingUUID, ct); - - CacheUtils.bcast(prjsCache.getName(), ignite, () -> { - Ignite ignite = Ignition.localIgnite(); - IgniteCache<RegionKey, List<RegionProjection>> projCache = ProjectionsCache.getOrCreate(ignite); - IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite); - - Affinity<RegionKey> targetAffinity = ignite.affinity(ProjectionsCache.CACHE_NAME); - - ClusterNode locNode = ignite.cluster().localNode(); - - Map<FeatureKey, double[]> fm = new ConcurrentHashMap<>(); - Map<RegionKey, 
List<RegionProjection>> pm = new ConcurrentHashMap<>(); - - targetAffinity. - mapKeysToNodes(IntStream.range(0, i.featuresCount()). - mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, ignite), trainingUUID)). - collect(Collectors.toSet())).getOrDefault(locNode, Collections.emptyList()). - forEach(k -> { - FeatureProcessor vec; - - int featureIdx = k.featureIdx(); - - IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite); - TrainingContext ctx = ctxCache.get(trainingUUID); - double[] vals = new double[ctx.labels().length]; - - vec = ctx.featureProcessor(featureIdx); - i.values(featureIdx).forEach(t -> vals[t.get1()] = t.get2()); - - fm.put(getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, ignite)), vals); - - List<RegionProjection> newReg = new ArrayList<>(BLOCK_SIZE); - newReg.add(vec.createInitialRegion(getSamples(i.values(featureIdx), ctx.labels().length), vals, ctx.labels())); - pm.put(k, newReg); - }); - - featuresCache.putAll(fm); - projCache.putAll(pm); - - return null; - }); - - return doTrain(i, trainingUUID); - } - - /** - * Get samples array. - * - * @param values Stream of tuples in the form of (index, value). - * @param size size of stream. - * @return Samples array. - */ - private Integer[] getSamples(Stream<IgniteBiTuple<Integer, Double>> values, int size) { - Integer[] res = new Integer[size]; - - values.forEach(v -> res[v.get1()] = v.get1()); - - return res; - } - - /** */ - @NotNull - private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) { - RootNode root = new RootNode(); - - // List containing setters of leaves of the tree. - List<TreeTip> tips = new LinkedList<>(); - tips.add(new TreeTip(root::setSplit, 0)); - - int curDepth = 0; - int regsCnt = 1; - - int featuresCnt = input.featuresCount(); - IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)). 
- forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0))); - updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid); - - // TODO: IGNITE-5893 Currently if the best split makes tree deeper than max depth process will be terminated, but actually we should - // only stop when *any* improving split makes tree deeper than max depth. Can be fixed if we will store which - // regions cannot be split more and split only those that can. - while (true) { - long before = System.currentTimeMillis(); - - IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid); - - long findBestRegIdx = System.currentTimeMillis() - before; - - Integer bestFeatureIdx = b.get1(); - - Integer regIdx = b.get2().get1(); - Double bestInfoGain = b.get2().get2(); - - if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) { - before = System.currentTimeMillis(); - - SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, - input.affinityKey(bestFeatureIdx, ignite), - () -> { - TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid); - Ignite ignite = Ignition.localIgnite(); - RegionKey key = ProjectionsCache.key(bestFeatureIdx, - regIdx / BLOCK_SIZE, - input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), - uuid); - RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE); - return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx); - }); - - long findBestSplit = System.currentTimeMillis() - before; - - IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi); - - regsCnt++; - - if (log.isDebugEnabled()) - log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt); - // Request bitset for split region. 
- int ind = best.info.regionIndex(); - - SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, - input.affinityKey(bestFeatureIdx, ignite), - () -> { - Ignite ignite = Ignition.localIgnite(); - IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite); - IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite); - TrainingContext ctx = ctxCache.localPeek(uuid); - - double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()))); - RegionKey key = ProjectionsCache.key(bestFeatureIdx, - regIdx / BLOCK_SIZE, - input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), - uuid); - RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE); - return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info); - - }); - - SplitNode sn = best.info.createSplitNode(best.featureIdx); - - TreeTip tipToSplit = tips.get(ind); - tipToSplit.leafSetter.accept(sn); - tipToSplit.leafSetter = sn::setLeft; - int d = tipToSplit.depth++; - tips.add(new TreeTip(sn::setRight, d)); - - if (d > curDepth) { - curDepth = d; - if (log.isDebugEnabled()) { - log.debug("Depth: " + curDepth); - log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY)); - } - } - - before = System.currentTimeMillis(); - // Perform split on all feature vectors. - IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt). - mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)). - collect(Collectors.toSet()); - - int rc = regsCnt; - - // Perform split. 
- CacheUtils.update(prjsCache.getName(), ignite, - (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> { - RegionKey k = e.getKey(); - - List<RegionProjection> leftBlock = e.getValue(); - - int fIdx = k.featureIdx(); - int idxInBlock = ind % BLOCK_SIZE; - - IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign); - TrainingContext<D> ctx = ctxCache.get(uuid); - - RegionProjection targetRegProj = leftBlock.get(idxInBlock); - - IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx. - performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign); - - RegionProjection left = regs.get1(); - RegionProjection right = regs.get2(); - - leftBlock.set(idxInBlock, left); - RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid); - - IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign); - - List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey); - - if (rightBlock == null) { - List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE); - newBlock.add(right); - return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock)); - } - else { - rightBlock.add(right); - return rightBlock.equals(k) ? 
- Stream.of(new CacheEntryImpl<>(k, leftBlock)) : - Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock)); - } - }, - bestRegsKeys); - - if (log.isDebugEnabled()) - log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before)); - - before = System.currentTimeMillis(); - - updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid); - - if (log.isDebugEnabled()) - log.debug("Update of split cache time: " + (System.currentTimeMillis() - before)); - } - else { - if (log.isDebugEnabled()) - log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]"); - break; - } - } - - int rc = regsCnt; - - IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> { - IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite()); - - return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1). - mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)). 
- map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator(); - }; - - Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite, - (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> { - int regBlockIdx = e.getKey().regionBlockIndex(); - - if (e.getValue() != null) { - for (int i = 0; i < e.getValue().size(); i++) { - int regIdx = regBlockIdx * BLOCK_SIZE + i; - RegionProjection reg = e.getValue().get(i); - - Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s])); - m.put(regIdx, res); - } - } - - return m; - }, - () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), - featZeroRegs, - (infos, infos2) -> { - Map<Integer, Double> res = new HashMap<>(); - res.putAll(infos); - res.putAll(infos2); - return res; - }, - HashMap::new - ); - - int i = 0; - for (TreeTip tip : tips) { - tip.leafSetter.accept(new Leaf(vals.get(i))); - i++; - } - - ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite); - ContextCache.getOrCreate(ignite).remove(uuid); - FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite); - SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite); - - return new DecisionTreeModel(root.s); - } - - /** - * Find the best split in the form (feature index, (index of region with the best split, impurity of region with the - * best split)). - * - * @param featuresCnt Count of features. - * @param affinity Affinity function. - * @param trainingUUID UUID of training. - * @return Best split in the form (feature index, (index of region with the best split, impurity of region with the - * best split)). 
- */ - private IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> findBestSplitIndexForFeatures(int featuresCnt, - IgniteBiFunction<Integer, Ignite, Object> affinity, - UUID trainingUUID) { - Set<Integer> featureIndexes = IntStream.range(0, featuresCnt).boxed().collect(Collectors.toSet()); - - return CacheUtils.reduce(SplitCache.CACHE_NAME, ignite, - (Object ctx, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e, IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> r) -> - Functions.MAX_GENERIC(new IgniteBiTuple<>(e.getKey().featureIdx(), e.getValue()), r, comparator()), - () -> null, - () -> SplitCache.localEntries(featureIndexes, affinity, trainingUUID), - (i1, i2) -> Functions.MAX_GENERIC(i1, i2, Comparator.comparingDouble(bt -> bt.get2().get2())), - () -> new IgniteBiTuple<>(-1, new IgniteBiTuple<>(-1, Double.NEGATIVE_INFINITY)) - ); - } - - /** */ - private static Comparator<IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>>> comparator() { - return Comparator.comparingDouble(bt -> bt != null && bt.get2() != null ? bt.get2().get2() : Double.NEGATIVE_INFINITY); - } - - /** - * Update split cache. - * - * @param lastSplitRegionIdx Index of region which had last best split. - * @param regsCnt Count of regions. - * @param featuresCnt Count of features. - * @param affinity Affinity function. - * @param trainingUUID UUID of current training. - */ - private void updateSplitCache(int lastSplitRegionIdx, int regsCnt, int featuresCnt, - IgniteCurriedBiFunction<Ignite, Integer, Object> affinity, - UUID trainingUUID) { - CacheUtils.update(SplitCache.CACHE_NAME, ignite, - (Ignite ign, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e) -> { - Integer bestRegIdx = e.getValue().get1(); - int fIdx = e.getKey().featureIdx(); - TrainingContext ctx = ContextCache.getOrCreate(ign).get(trainingUUID); - - Map<Integer, RegionProjection> toCompare; - - // Fully recalculate best. 
- if (bestRegIdx == lastSplitRegionIdx) - toCompare = ProjectionsCache.projectionsOfFeature(fIdx, maxDepth, regsCnt, BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign); - // Just compare previous best and two regions which are produced by split. - else - toCompare = ProjectionsCache.projectionsOfRegions(fIdx, maxDepth, - IntStream.of(bestRegIdx, lastSplitRegionIdx, regsCnt - 1), BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign); - - double[] values = ctx.values(fIdx, ign); - double[] labels = ctx.labels(); - - Optional<IgniteBiTuple<Integer, Double>> max = toCompare.entrySet().stream(). - map(ent -> { - SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey()); - return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY); - }). - max(Comparator.comparingDouble(IgniteBiTuple::get2)); - - return max.<Stream<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>>> - map(objects -> Stream.of(new CacheEntryImpl<>(e.getKey(), objects))).orElseGet(Stream::empty); - }, - () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet()) - ); - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java deleted file mode 100644 index bf8790b..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.util.Map; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.lang.IgniteBiTuple; - -/** - * Input for {@link ColumnDecisionTreeTrainer}. - */ -public interface ColumnDecisionTreeTrainerInput { - /** - * Projection of data on feature with the given index. - * - * @param idx Feature index. - * @return Projection of data on feature with the given index. 
- */ - Stream<IgniteBiTuple<Integer, Double>> values(int idx); - - /** - * Labels. - * - * @param ignite Ignite instance. - */ - double[] labels(Ignite ignite); - - /** Information about which features are categorical in the form of feature index -> number of categories. */ - Map<Integer, Integer> catFeaturesInfo(); - - /** Number of features. */ - int featuresCount(); - - /** - * Get affinity key for the given column index. - * Affinity key should be pure-functionally dependent from idx. - */ - Object affinityKey(int idx, Ignite ignite); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java deleted file mode 100644 index 3da6bad..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.util.HashMap; -import java.util.Map; -import java.util.stream.DoubleStream; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import javax.cache.Cache; -import org.apache.ignite.Ignite; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey; -import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage; -import org.apache.ignite.ml.math.StorageConstants; -import org.jetbrains.annotations.NotNull; - -/** - * Adapter of SparseDistributedMatrix to ColumnDecisionTreeTrainerInput. - * Sparse SparseDistributedMatrix should be in {@link StorageConstants#COLUMN_STORAGE_MODE} and - * should contain samples in rows last position in row being label of this sample. - */ -public class MatrixColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<RowColMatrixKey, Map<Integer, Double>> { - /** - * @param m Sparse SparseDistributedMatrix should be in {@link StorageConstants#COLUMN_STORAGE_MODE} - * containing samples in rows last position in row being label of this sample. - * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of - * categories. 
- */ - public MatrixColumnDecisionTreeTrainerInput(SparseDistributedMatrix m, Map<Integer, Integer> catFeaturesInfo) { - super(((SparseDistributedMatrixStorage)m.getStorage()).cache(), - () -> Stream.of(new SparseMatrixKey(m.columnSize() - 1, m.getUUID(), m.columnSize() - 1)), - valuesMapper(m), - labels(m), - keyMapper(m), - catFeaturesInfo, - m.columnSize() - 1, - m.rowSize()); - } - - /** Values mapper. See {@link CacheColumnDecisionTreeTrainerInput#valuesMapper} */ - @NotNull - private static IgniteFunction<Cache.Entry<RowColMatrixKey, Map<Integer, Double>>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper( - SparseDistributedMatrix m) { - return ent -> { - Map<Integer, Double> map = ent.getValue() != null ? ent.getValue() : new HashMap<>(); - return IntStream.range(0, m.rowSize()).mapToObj(k -> new IgniteBiTuple<>(k, map.getOrDefault(k, 0.0))); - }; - } - - /** Key mapper. See {@link CacheColumnDecisionTreeTrainerInput#keyMapper} */ - @NotNull private static IgniteFunction<Integer, Stream<RowColMatrixKey>> keyMapper(SparseDistributedMatrix m) { - return i -> Stream.of(new SparseMatrixKey(i, ((SparseDistributedMatrixStorage)m.getStorage()).getUUID(), i)); - } - - /** Labels mapper. 
See {@link CacheColumnDecisionTreeTrainerInput#labelsMapper} */ - @NotNull private static IgniteFunction<Map<Integer, Double>, DoubleStream> labels(SparseDistributedMatrix m) { - return mp -> IntStream.range(0, m.rowSize()).mapToDouble(k -> mp.getOrDefault(k, 0.0)); - } - - /** {@inheritDoc} */ - @Override public Object affinityKey(int idx, Ignite ignite) { - return idx; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java deleted file mode 100644 index e95f57b..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import org.apache.ignite.ml.trees.RegionInfo; - -/** - * Projection of region on given feature. - * - * @param <D> Data of region. - */ -public class RegionProjection<D extends RegionInfo> implements Externalizable { - /** Samples projections. */ - protected Integer[] sampleIndexes; - - /** Region data */ - protected D data; - - /** Depth of this region. */ - protected int depth; - - /** - * @param sampleIndexes Samples indexes. - * @param data Region data. - * @param depth Depth of this region. - */ - public RegionProjection(Integer[] sampleIndexes, D data, int depth) { - this.data = data; - this.depth = depth; - this.sampleIndexes = sampleIndexes; - } - - /** - * No-op constructor used for serialization/deserialization. - */ - public RegionProjection() { - // No-op. - } - - /** - * Get samples indexes. - * - * @return Samples indexes. - */ - public Integer[] sampleIndexes() { - return sampleIndexes; - } - - /** - * Get region data. - * - * @return Region data. - */ - public D data() { - return data; - } - - /** - * Get region depth. - * - * @return Region depth. 
- */ - public int depth() { - return depth; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(sampleIndexes.length); - - for (Integer sampleIndex : sampleIndexes) - out.writeInt(sampleIndex); - - out.writeObject(data); - out.writeInt(depth); - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - int size = in.readInt(); - - sampleIndexes = new Integer[size]; - - for (int i = 0; i < size; i++) - sampleIndexes[i] = in.readInt(); - - data = (D)in.readObject(); - depth = in.readInt(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java deleted file mode 100644 index 6415dab..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased;

import com.zaxxer.sparsebits.SparseBitSet;
import java.util.Map;
import java.util.UUID;
import java.util.stream.DoubleStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.trees.ContinuousRegionInfo;
import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
import org.apache.ignite.ml.trees.RegionInfo;
import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache;
import org.apache.ignite.ml.trees.trainers.columnbased.vectors.CategoricalFeatureProcessor;
import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor;
import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor;

import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME;

/**
 * Context of training with {@link ColumnDecisionTreeTrainer}.
 *
 * @param <D> Class for storing of information used in calculation of impurity of continuous feature region.
 */
public class TrainingContext<D extends ContinuousRegionInfo> {
    /** Input for training with {@link ColumnDecisionTreeTrainer}. */
    private final ColumnDecisionTreeTrainerInput input;

    /** Labels. */
    private final double[] labels;

    /** Calculator used for finding splits of region of continuous features. */
    private final ContinuousSplitCalculator<D> continuousSplitCalculator;

    /** Calculator used for finding splits of region of categorical feature. */
    private final IgniteFunction<DoubleStream, Double> categoricalSplitCalculator;

    /** UUID of current training. */
    private final UUID trainingUUID;

    /**
     * Construct context for training with {@link ColumnDecisionTreeTrainer}.
     * <p>
     * Labels are fetched from {@code input} once, at construction time.
     *
     * @param input Input for training.
     * @param continuousSplitCalculator Calculator used for calculations of splits of continuous features regions.
     * @param categoricalSplitCalculator Calculator used for calculations of splits of categorical features regions.
     * @param trainingUUID UUID of the current training.
     * @param ignite Ignite instance.
     */
    public TrainingContext(ColumnDecisionTreeTrainerInput input,
        ContinuousSplitCalculator<D> continuousSplitCalculator,
        IgniteFunction<DoubleStream, Double> categoricalSplitCalculator,
        UUID trainingUUID,
        Ignite ignite) {
        this.input = input;
        this.labels = input.labels(ignite);
        this.continuousSplitCalculator = continuousSplitCalculator;
        this.categoricalSplitCalculator = categoricalSplitCalculator;
        this.trainingUUID = trainingUUID;
    }

    /**
     * Get processor used for calculating splits of categorical features.
     *
     * @param catsCnt Count of categories.
     * @return Processor used for calculating splits of categorical features.
     */
    public CategoricalFeatureProcessor categoricalFeatureProcessor(int catsCnt) {
        return new CategoricalFeatureProcessor(categoricalSplitCalculator, catsCnt);
    }

    /**
     * Get processor used for calculating splits of continuous features.
     *
     * @return Processor used for calculating splits of continuous features.
     */
    public ContinuousFeatureProcessor<D> continuousFeatureProcessor() {
        return new ContinuousFeatureProcessor<>(continuousSplitCalculator);
    }

    /**
     * Get labels.
     *
     * @return Labels.
     */
    public double[] labels() {
        return labels;
    }

    /**
     * Get values of feature with given index.
     * <p>
     * Values are read with {@code localPeek}, so only data stored on the local node is visible. The result is
     * {@code null} when the entry is not present locally.
     *
     * @param featIdx Feature index.
     * @param ignite Ignite instance.
     * @return Values of feature with given index.
     */
    public double[] values(int featIdx, Ignite ignite) {
        // Go through FeaturesCache so that a missing cache is created with the trainer's configuration
        // instead of Ignite's default one.
        IgniteCache<FeaturesCache.FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);

        return featuresCache.localPeek(
            FeaturesCache.getFeatureCacheKey(featIdx, trainingUUID, input.affinityKey(featIdx, ignite)));
    }

    /**
     * Perform best split on the given region projection.
     *
     * @param input Input of {@link ColumnDecisionTreeTrainer} performing split (its categorical features info is
     * used instead of this context's own input).
     * @param bitSet Bit set specifying split.
     * @param targetFeatIdx Index of feature for performing split.
     * @param bestFeatIdx Index of feature with best split.
     * @param targetRegionPrj Projection of region to split on feature with index {@code targetFeatIdx}.
     * @param leftData Data of left region of split.
     * @param rightData Data of right region of split.
     * @param ignite Ignite instance.
     * @return Pair of projections of the regions produced by the split (presumably left, then right).
     */
    @SuppressWarnings("unchecked")
    public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(ColumnDecisionTreeTrainerInput input,
        SparseBitSet bitSet, int targetFeatIdx, int bestFeatIdx, RegionProjection targetRegionPrj, RegionInfo leftData,
        RegionInfo rightData, Ignite ignite) {
        Map<Integer, Integer> catFeaturesInfo = input.catFeaturesInfo();

        Integer targetCatsCnt = catFeaturesInfo.get(targetFeatIdx);

        // Target feature is categorical: split it using its own values.
        if (targetCatsCnt != null) {
            return categoricalFeatureProcessor(targetCatsCnt).performSplitGeneric(bitSet, values(targetFeatIdx, ignite),
                targetRegionPrj, leftData, rightData);
        }

        // Both features are continuous: region data is presumably of type D (unchecked casts).
        if (!catFeaturesInfo.containsKey(bestFeatIdx))
            return continuousFeatureProcessor().performSplit(bitSet, targetRegionPrj, (D)leftData, (D)rightData);

        // Target is continuous, but the best split was found on a categorical feature.
        return continuousFeatureProcessor().performSplitGeneric(bitSet, labels, targetRegionPrj, leftData, rightData);
    }

    /**
     * Processor used for calculating splits for feature with the given index.
     *
     * @param featureIdx Index of feature to process.
     * @return Categorical processor if the feature is listed in categorical features info, continuous otherwise.
     */
    public FeatureProcessor featureProcessor(int featureIdx) {
        Integer catsCnt = input.catFeaturesInfo().get(featureIdx);

        return catsCnt != null ? categoricalFeatureProcessor(catsCnt) : continuousFeatureProcessor();
    }

    /**
     * Shortcut for affinity key.
     *
     * @param idx Feature index.
     * @return Affinity key.
     */
    public Object affinityKey(int idx) {
        return input.affinityKey(idx, Ignition.localIgnite());
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java deleted file mode 100644 index 51ea359..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
*/

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.UUID;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.trees.ContinuousRegionInfo;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
import org.apache.ignite.ml.trees.trainers.columnbased.TrainingContext;

/**
 * Helper for the cache that keeps training contexts of {@link ColumnDecisionTreeTrainer}, keyed by training UUID.
 */
public class ContextCache {
    /** Name of the cache that keeps training contexts of {@link ColumnDecisionTreeTrainer}. */
    public static final String COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME";

    /**
     * Returns the training context cache, creating it if it does not exist yet.
     * <p>
     * The cache is replicated and atomic, fully synchronous on writes, kept on heap, and has no eviction policy.
     *
     * @param ignite Ignite instance.
     * @param <D> Class storing information about continuous regions.
     * @return Cache for training context.
     */
    public static <D extends ContinuousRegionInfo> IgniteCache<UUID, TrainingContext<D>> getOrCreate(Ignite ignite) {
        CacheConfiguration<UUID, TrainingContext<D>> ctxCacheCfg = new CacheConfiguration<>();

        ctxCacheCfg.setName(COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME);

        // Every node keeps a full copy, and reads may be served from any copy.
        ctxCacheCfg.setCacheMode(CacheMode.REPLICATED);
        ctxCacheCfg.setReadFromBackup(true);

        // Non-transactional updates that wait for all copies to be written.
        ctxCacheCfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
        ctxCacheCfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC);

        // Values stay on heap, are never evicted and are not copied on read.
        ctxCacheCfg.setOnheapCacheEnabled(true);
        ctxCacheCfg.setEvictionPolicy(null);
        ctxCacheCfg.setCopyOnRead(false);

        return ignite.getOrCreateCache(ctxCacheCfg);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java deleted file mode 100644 index fcc1f16..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;

/**
 * Cache storing features for {@link ColumnDecisionTreeTrainer}.
 */
public class FeaturesCache {
    /**
     * Name of cache which is used for storing features for {@link ColumnDecisionTreeTrainer}.
     */
    public static final String COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME";

    /**
     * Key of features cache.
     * <p>
     * Equality and hash code use only the feature index and the training UUID. {@code parentColKey} is not part of
     * equality; it is the {@link AffinityKeyMapped} field.
     */
    public static class FeatureKey {
        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Index of feature. */
        private final int featureIdx;

        /** UUID of training. */
        private final UUID trainingUUID;

        /**
         * Construct FeatureKey.
         *
         * @param featureIdx Feature index.
         * @param trainingUUID UUID of training.
         * @param parentColKey Column key of cache used as input.
         */
        public FeatureKey(int featureIdx, UUID trainingUUID, Object parentColKey) {
            this.featureIdx = featureIdx;
            this.trainingUUID = trainingUUID;
            this.parentColKey = parentColKey;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            FeatureKey key = (FeatureKey)o;

            if (featureIdx != key.featureIdx)
                return false;
            return trainingUUID != null ? trainingUUID.equals(key.trainingUUID) : key.trainingUUID == null;
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
            res = 31 * res + featureIdx;
            return res;
        }

        /** {@inheritDoc} */
        @Override public String toString() {
            return "FeatureKey [" +
                "parentColKey=" + parentColKey +
                ", featureIdx=" + featureIdx +
                ", trainingUUID=" + trainingUUID +
                ']';
        }
    }

    /**
     * Get the features cache for {@link ColumnDecisionTreeTrainer}, creating it if needed.
     *
     * @param ignite Ignite instance.
     * @return Features cache.
     */
    public static IgniteCache<FeatureKey, double[]> getOrCreate(Ignite ignite) {
        CacheConfiguration<FeatureKey, double[]> cfg = new CacheConfiguration<>();

        // Write to primary.
        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Atomic transactions only.
        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);

        // No eviction.
        cfg.setEvictionPolicy(null);

        // No copying of values.
        cfg.setCopyOnRead(false);

        // Cache is partitioned.
        cfg.setCacheMode(CacheMode.PARTITIONED);

        cfg.setOnheapCacheEnabled(true);

        cfg.setBackups(0);

        cfg.setName(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME);

        return ignite.getOrCreateCache(cfg);
    }

    /**
     * Construct FeatureKey from index, uuid and affinity key.
     *
     * @param idx Feature index.
     * @param uuid UUID of training.
     * @param aff Affinity key.
     * @return FeatureKey.
     */
    public static FeatureKey getFeatureCacheKey(int idx, UUID uuid, Object aff) {
        return new FeatureKey(idx, uuid, aff);
    }

    /**
     * Clear all data from features cache related to given training.
     *
     * @param featuresCnt Count of features.
     * @param affinity Affinity function.
     * @param uuid Training uuid.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
        Ignite ignite) {
        Set<FeatureKey> toRmv = IntStream.range(0, featuresCnt).boxed()
            .map(fIdx -> getFeatureCacheKey(fIdx, uuid, affinity.apply(fIdx, ignite)))
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(toRmv);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java deleted file mode 100644 index 080cb66..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
*/

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;

/**
 * Cache used for storing data of region projections on features.
 * <p>
 * Projections are stored in blocks. One entry holds the list of projections of a single feature for consecutive
 * region indexes; region {@code i} lives in block {@code i / blockSize} at position {@code i % blockSize}.
 */
public class ProjectionsCache {
    /**
     * Name of cache which is used for storing data of region projections on features of {@link
     * ColumnDecisionTreeTrainer}.
     */
    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_PROJECTIONS_CACHE_NAME";

    /**
     * Key of region projections cache.
     * <p>
     * Equality and hash code use feature index, block index and training UUID. {@code parentColKey} is not part of
     * equality; it is the {@link AffinityKeyMapped} field.
     */
    public static class RegionKey {
        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Feature index. */
        private final int featureIdx;

        /** Region index. */
        private final int regBlockIdx;

        /** Training UUID. */
        private final UUID trainingUUID;

        /**
         * Construct a RegionKey from feature index, index of block, key of column in input cache and UUID of training.
         *
         * @param featureIdx Feature index.
         * @param regBlockIdx Index of block.
         * @param parentColKey Key of column in input cache.
         * @param trainingUUID UUID of training.
         */
        public RegionKey(int featureIdx, int regBlockIdx, Object parentColKey, UUID trainingUUID) {
            this.featureIdx = featureIdx;
            this.regBlockIdx = regBlockIdx;
            this.trainingUUID = trainingUUID;
            this.parentColKey = parentColKey;
        }

        /**
         * Feature index.
         *
         * @return Feature index.
         */
        public int featureIdx() {
            return featureIdx;
        }

        /**
         * Region block index.
         *
         * @return Region block index.
         */
        public int regionBlockIndex() {
            return regBlockIdx;
        }

        /**
         * UUID of training.
         *
         * @return UUID of training.
         */
        public UUID trainingUUID() {
            return trainingUUID;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            RegionKey key = (RegionKey)o;

            if (featureIdx != key.featureIdx)
                return false;
            if (regBlockIdx != key.regBlockIdx)
                return false;
            return trainingUUID != null ? trainingUUID.equals(key.trainingUUID) : key.trainingUUID == null;
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
            res = 31 * res + featureIdx;
            res = 31 * res + regBlockIdx;
            return res;
        }

        /** {@inheritDoc} */
        @Override public String toString() {
            return "RegionKey [" +
                "parentColKey=" + parentColKey +
                ", featureIdx=" + featureIdx +
                ", regBlockIdx=" + regBlockIdx +
                ", trainingUUID=" + trainingUUID +
                ']';
        }
    }

    /**
     * Affinity service for region projections cache.
     *
     * @return Affinity service for region projections cache.
     */
    public static Affinity<RegionKey> affinity() {
        return Ignition.localIgnite().affinity(CACHE_NAME);
    }

    /**
     * Get or create region projections cache.
     *
     * @param ignite Ignite instance.
     * @return Region projections cache.
     */
    public static IgniteCache<RegionKey, List<RegionProjection>> getOrCreate(Ignite ignite) {
        CacheConfiguration<RegionKey, List<RegionProjection>> cfg = new CacheConfiguration<>();

        // Write to primary.
        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Atomic transactions only.
        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);

        // No eviction.
        cfg.setEvictionPolicy(null);

        // No copying of values.
        cfg.setCopyOnRead(false);

        // Cache is partitioned.
        cfg.setCacheMode(CacheMode.PARTITIONED);

        cfg.setBackups(0);

        cfg.setOnheapCacheEnabled(true);

        cfg.setName(CACHE_NAME);

        return ignite.getOrCreateCache(cfg);
    }

    /**
     * Get region projections in the form of map (regionIndex -> regionProjections).
     * <p>
     * Blocks are read with {@code localPeek}, so every requested block must be present on the local node. Only
     * projections whose depth is below {@code maxDepth} are returned.
     *
     * @param featureIdx Feature index.
     * @param maxDepth Max depth of decision tree.
     * @param regionIndexes Indexes of regions for which we want get projections.
     * @param blockSize Size of regions block.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @param ignite Ignite instance.
     * @return Region projections in the form of map (regionIndex -> regionProjections).
     * @throws IllegalStateException If a block containing a requested region is not found locally.
     */
    public static Map<Integer, RegionProjection> projectionsOfRegions(int featureIdx, int maxDepth,
        IntStream regionIndexes, int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID,
        Ignite ignite) {
        Map<Integer, RegionProjection> regsForSearch = new HashMap<>();
        IgniteCache<RegionKey, List<RegionProjection>> cache = getOrCreate(ignite);

        PrimitiveIterator.OfInt itr = regionIndexes.iterator();

        int curBlockIdx = -1;
        List<RegionProjection> block = null;

        Object affinityKey = affinity.apply(featureIdx);

        while (itr.hasNext()) {
            int i = itr.nextInt();

            int blockIdx = i / blockSize;

            // Consecutive indexes usually fall into the same block, so re-read only when the block changes.
            if (blockIdx != curBlockIdx) {
                block = cache.localPeek(key(featureIdx, blockIdx, affinityKey, trainingUUID));
                curBlockIdx = blockIdx;
            }

            if (block == null)
                throw new IllegalStateException("Unexpected null block at index " + i);

            RegionProjection reg = block.get(i % blockSize);

            if (reg.depth() < maxDepth)
                regsForSearch.put(i, reg);
        }

        return regsForSearch;
    }

    /**
     * Returns projections of regions on given feature filtered by maximal depth in the form of (region index -> region
     * projection).
     *
     * @param featureIdx Feature index.
     * @param maxDepth Maximal depth of the tree.
     * @param regsCnt Count of regions.
     * @param blockSize Size of regions blocks.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @param ignite Ignite instance.
     * @return Projections of regions on given feature filtered by maximal depth in the form of (region index -> region
     * projection).
     */
    public static Map<Integer, RegionProjection> projectionsOfFeature(int featureIdx, int maxDepth, int regsCnt,
        int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID, Ignite ignite) {
        return projectionsOfRegions(featureIdx, maxDepth, IntStream.range(0, regsCnt), blockSize, affinity, trainingUUID, ignite);
    }

    /**
     * Construct key for projections cache.
     *
     * @param featureIdx Feature index.
     * @param regBlockIdx Region block index.
     * @param parentColKey Column key of cache used as input for {@link ColumnDecisionTreeTrainer}.
     * @param uuid UUID of training.
     * @return Key for projections cache.
     */
    public static RegionKey key(int featureIdx, int regBlockIdx, Object parentColKey, UUID uuid) {
        return new RegionKey(featureIdx, regBlockIdx, parentColKey, uuid);
    }

    /**
     * Clear data from projections cache related to given training.
     * <p>
     * NOTE(review): the second component of {@link RegionKey} is a region <i>block</i> index, but keys are generated
     * for block indexes {@code 0..regs-1}. This presumably over-covers the stored blocks, which is harmless because
     * absent keys are simply not removed. Confirm against the caller.
     *
     * @param featuresCnt Features count.
     * @param regs Regions count.
     * @param aff Affinity function.
     * @param uuid UUID of training.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, int regs, IgniteBiFunction<Integer, Ignite, Object> aff, UUID uuid,
        Ignite ignite) {
        Set<RegionKey> toRmv = IntStream.range(0, featuresCnt).boxed()
            .flatMap(fIdx -> {
                // The affinity key depends only on the feature, so resolve it once per feature.
                Object affKey = aff.apply(fIdx, ignite);

                return IntStream.range(0, regs).mapToObj(blockIdx -> key(fIdx, blockIdx, affKey, uuid));
            })
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(toRmv);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java deleted file mode 100644 index ecbc861..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.Collection;
import java.util.Collections;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;

/**
 * Class for working with cache used for storing of best splits during training with {@link ColumnDecisionTreeTrainer}.
 */
public class SplitCache {
    /** Name of splits cache. */
    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_SPLIT_CACHE_NAME";

    /**
     * Class used for keys in the splits cache.
     * <p>
     * Equality and hash code use the training UUID and feature index. {@code parentColKey} is not part of equality;
     * it is the {@link AffinityKeyMapped} field.
     */
    public static class SplitKey {
        /** UUID of current training. */
        private final UUID trainingUUID;

        /** Affinity key of input data. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Index of feature by which the split is made. */
        private final int featureIdx;

        /**
         * Construct SplitKey.
         *
         * @param trainingUUID UUID of the training.
         * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node
         * as column with that feature in input.
         * @param featureIdx Feature index.
         */
        public SplitKey(UUID trainingUUID, Object parentColKey, int featureIdx) {
            this.trainingUUID = trainingUUID;
            this.featureIdx = featureIdx;
            this.parentColKey = parentColKey;
        }

        /**
         * Get UUID of current training.
         *
         * @return UUID of current training.
         */
        public UUID trainingUUID() {
            return trainingUUID;
        }

        /**
         * Get feature index.
         *
         * @return Feature index.
         */
        public int featureIdx() {
            return featureIdx;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            SplitKey splitKey = (SplitKey)o;

            if (featureIdx != splitKey.featureIdx)
                return false;
            return trainingUUID != null ? trainingUUID.equals(splitKey.trainingUUID) : splitKey.trainingUUID == null;
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
            res = 31 * res + featureIdx;
            return res;
        }

        /** {@inheritDoc} */
        @Override public String toString() {
            return "SplitKey [" +
                "trainingUUID=" + trainingUUID +
                ", parentColKey=" + parentColKey +
                ", featureIdx=" + featureIdx +
                ']';
        }
    }

    /**
     * Construct the key for splits cache.
     *
     * @param featureIdx Feature index.
     * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node as
     * column with that feature in input.
     * @param uuid UUID of current training.
     * @return Key for splits cache.
     */
    public static SplitKey key(int featureIdx, Object parentColKey, UUID uuid) {
        return new SplitKey(uuid, parentColKey, featureIdx);
    }

    /**
     * Get or create splits cache.
     *
     * @param ignite Ignite instance.
     * @return Splits cache.
     */
    public static IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> getOrCreate(Ignite ignite) {
        CacheConfiguration<SplitKey, IgniteBiTuple<Integer, Double>> cfg = new CacheConfiguration<>();

        // Write to primary.
        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Atomic transactions only.
        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);

        // No eviction.
        cfg.setEvictionPolicy(null);

        // No copying of values.
        cfg.setCopyOnRead(false);

        // Cache is partitioned.
        cfg.setCacheMode(CacheMode.PARTITIONED);

        cfg.setBackups(0);

        cfg.setOnheapCacheEnabled(true);

        cfg.setName(CACHE_NAME);

        return ignite.getOrCreateCache(cfg);
    }

    /**
     * Affinity function used in splits cache.
     *
     * @return Affinity function used in splits cache.
     */
    public static Affinity<SplitKey> affinity() {
        return Ignition.localIgnite().affinity(CACHE_NAME);
    }

    /**
     * Returns local entries for keys corresponding to {@code featureIndexes}.
     * <p>
     * Keys are filtered to those mapped to the local node at call time. Values are read lazily with
     * {@code localPeek} each time the returned iterable is iterated.
     *
     * @param featureIndexes Index of features.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @return local entries for keys corresponding to {@code featureIndexes}.
     */
    public static Iterable<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> localEntries(
        Set<Integer> featureIndexes,
        IgniteBiFunction<Integer, Ignite, Object> affinity,
        UUID trainingUUID) {
        Ignite ignite = Ignition.localIgnite();

        Set<SplitKey> keys = featureIndexes.stream()
            .map(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), trainingUUID))
            .collect(Collectors.toSet());

        Collection<SplitKey> locKeys = affinity().mapKeysToNodes(keys)
            .getOrDefault(ignite.cluster().localNode(), Collections.emptyList());

        return () -> {
            // Resolve the cache once per iteration instead of building a new configuration for every entry.
            IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> cache = getOrCreate(ignite);

            Function<SplitKey, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> toEntry =
                k -> new CacheEntryImpl<>(k, cache.localPeek(k));

            return locKeys.stream().map(toEntry).iterator();
        };
    }

    /**
     * Clears data related to the given training from the splits cache.
     *
     * @param featuresCnt Count of features.
     * @param affinity Affinity function.
     * @param uuid UUID of the given training.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
        Ignite ignite) {
        Set<SplitKey> toRmv = IntStream.range(0, featuresCnt).boxed()
            .map(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), uuid))
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(toRmv);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java deleted file mode 100644 index 0a488ab..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - /** - * <!-- Package description. --> - * Contains cache configurations for columnbased decision tree trainer with some related logic.
- */ -package org.apache.ignite.ml.trees.trainers.columnbased.caches; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java deleted file mode 100644 index 9fd4c66..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; - -import org.apache.ignite.Ignite; -import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput; - -/** Continuous Split Calculators. 
*/ -public class ContinuousSplitCalculators { - /** Variance split calculator. */ - public static IgniteFunction<ColumnDecisionTreeTrainerInput, VarianceSplitCalculator> VARIANCE = input -> - new VarianceSplitCalculator(); - - /** Gini split calculator. */ - public static IgniteCurriedBiFunction<Ignite, ColumnDecisionTreeTrainerInput, GiniSplitCalculator> GINI = ignite -> - input -> new GiniSplitCalculator(input.labels(ignite)); -} http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java deleted file mode 100644 index 259c84c..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; - -import it.unimi.dsi.fastutil.doubles.Double2IntArrayMap; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.HashMap; -import java.util.Map; -import java.util.PrimitiveIterator; -import java.util.stream.DoubleStream; -import org.apache.ignite.ml.trees.ContinuousRegionInfo; -import org.apache.ignite.ml.trees.ContinuousSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo; -import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; - -/** - * Calculator for Gini impurity. - */ -public class GiniSplitCalculator implements ContinuousSplitCalculator<GiniSplitCalculator.GiniData> { - /** Mapping assigning index to each member value */ - private final Map<Double, Integer> mapping = new Double2IntArrayMap(); - - /** - * Create Gini split calculator from labels. - * - * @param labels Labels. - */ - public GiniSplitCalculator(double[] labels) { - int i = 0; - - for (double label : labels) { - if (!mapping.containsKey(label)) { - mapping.put(label, i); - i++; - } - } - } - - /** {@inheritDoc} */ - @Override public GiniData calculateRegionInfo(DoubleStream s, int l) { - PrimitiveIterator.OfDouble itr = s.iterator(); - - Map<Double, Integer> m = new HashMap<>(); - - int size = 0; - - while (itr.hasNext()) { - size++; - m.compute(itr.next(), (a, i) -> i != null ? i + 1 : 1); - } - - double c2 = m.values().stream().mapToDouble(v -> v * v).sum(); - - int[] cnts = new int[mapping.size()]; - - m.forEach((key, value) -> cnts[mapping.get(key)] = value); - - return new GiniData(size != 0 ? 
1 - c2 / (size * size) : 0.0, size, cnts, c2); - } - - /** {@inheritDoc} */ - @Override public SplitInfo<GiniData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx, - GiniData d) { - int size = d.getSize(); - - double lg = 0.0; - double rg = d.impurity(); - - double lc2 = 0.0; - double rc2 = d.c2; - int lSize = 0; - - double minImpurity = d.impurity() * size; - double curThreshold; - double curImpurity; - double threshold = Double.NEGATIVE_INFINITY; - - int i = 0; - int nextIdx = s[0]; - i++; - double[] lrImps = new double[] {0.0, d.impurity(), lc2, rc2}; - - int[] lMapCur = new int[d.counts().length]; - int[] rMapCur = new int[d.counts().length]; - - System.arraycopy(d.counts(), 0, rMapCur, 0, d.counts().length); - - int[] lMap = new int[d.counts().length]; - int[] rMap = new int[d.counts().length]; - - System.arraycopy(d.counts(), 0, rMap, 0, d.counts().length); - - do { - // Process all values equal to prev. - while (i < s.length) { - moveLeft(labels[nextIdx], i, size - i, lMapCur, rMapCur, lrImps); - curImpurity = (i * lrImps[0] + (size - i) * lrImps[1]); - curThreshold = values[nextIdx]; - - if (values[nextIdx] != values[(nextIdx = s[i++])]) { - if (curImpurity < minImpurity) { - lSize = i - 1; - - lg = lrImps[0]; - rg = lrImps[1]; - - lc2 = lrImps[2]; - rc2 = lrImps[3]; - - System.arraycopy(lMapCur, 0, lMap, 0, lMapCur.length); - System.arraycopy(rMapCur, 0, rMap, 0, rMapCur.length); - - minImpurity = curImpurity; - threshold = curThreshold; - } - - break; - } - } - } - while (i < s.length - 1); - - if (lSize == size || lSize == 0) - return null; - - GiniData lData = new GiniData(lg, lSize, lMap, lc2); - int rSize = size - lSize; - GiniData rData = new GiniData(rg, rSize, rMap, rc2); - - return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData); - } - - /** - * Add point to the left interval and remove it from the right interval and calculate necessary statistics on - * intervals with new bounds. 
- */ - private void moveLeft(double x, int lSize, int rSize, int[] lMap, int[] rMap, double[] data) { - double lc2 = data[2]; - double rc2 = data[3]; - - Integer idx = mapping.get(x); - - int cxl = lMap[idx]; - int cxr = rMap[idx]; - - lc2 += 2 * cxl + 1; - rc2 -= 2 * cxr - 1; - - lMap[idx] += 1; - rMap[idx] -= 1; - - data[0] = 1 - lc2 / (lSize * lSize); - data[1] = 1 - rc2 / (rSize * rSize); - - data[2] = lc2; - data[3] = rc2; - } - - /** - * Data used for gini impurity calculations. - */ - public static class GiniData extends ContinuousRegionInfo { - /** Sum of squares of counts of each label. */ - private double c2; - - /** Counts of each label. On i-th position there is count of label which is mapped to index i. */ - private int[] m; - - /** - * Create Gini data. - * - * @param impurity Impurity (i.e. Gini impurity). - * @param size Count of samples. - * @param m Counts of each label. - * @param c2 Sum of squares of counts of each label. - */ - public GiniData(double impurity, int size, int[] m, double c2) { - super(impurity, size); - this.m = m; - this.c2 = c2; - } - - /** - * No-op constructor for serialization/deserialization.. - */ - public GiniData() { - // No-op. - } - - /** Get counts of each label. */ - public int[] counts() { - return m; - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - super.writeExternal(out); - out.writeDouble(c2); - out.writeInt(m.length); - for (int i : m) - out.writeInt(i); - - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - super.readExternal(in); - - c2 = in.readDouble(); - int size = in.readInt(); - m = new int[size]; - - for (int i = 0; i < size; i++) - m[i] = in.readInt(); - } - } -}