/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;

/**
 * Cache storing features for {@link ColumnDecisionTreeTrainer}.
 */
public class FeaturesCache {
    /**
     * Name of cache which is used for storing features for {@link ColumnDecisionTreeTrainer}.
     */
    public static final String COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME =
        "COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME";

    /**
     * Key of features cache.
     * <p>
     * Equality and hash code are defined by the feature index and the training UUID only; the affinity key
     * ({@code parentColKey}) does not take part in either.
     */
    public static class FeatureKey {
        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Index of feature. */
        private final int featureIdx;

        /** UUID of training. */
        private final UUID trainingUUID;

        /**
         * Construct FeatureKey.
         *
         * @param featureIdx Feature index.
         * @param trainingUUID UUID of training.
         * @param parentColKey Column key of cache used as input.
         */
        public FeatureKey(int featureIdx, UUID trainingUUID, Object parentColKey) {
            this.parentColKey = parentColKey;
            this.featureIdx = featureIdx;
            this.trainingUUID = trainingUUID;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            FeatureKey key = (FeatureKey)o;

            if (featureIdx != key.featureIdx)
                return false;
            return trainingUUID != null ? trainingUUID.equals(key.trainingUUID) : key.trainingUUID == null;
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
            res = 31 * res + featureIdx;
            return res;
        }
    }

    /**
     * Get the features cache for {@link ColumnDecisionTreeTrainer}, creating it if it does not exist yet.
     * <p>
     * The cache is configured as partitioned, atomic, without backups, with on-heap caching enabled,
     * no eviction policy, no copy-on-read and {@code PRIMARY_SYNC} write synchronization.
     *
     * @param ignite Ignite instance.
     * @return Features cache.
     */
    public static IgniteCache<FeatureKey, double[]> getOrCreate(Ignite ignite) {
        CacheConfiguration<FeatureKey, double[]> cfg = new CacheConfiguration<>();

        // Write to primary.
        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Atomic transactions only.
        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);

        // No eviction.
        cfg.setEvictionPolicy(null);

        // No copying of values.
        cfg.setCopyOnRead(false);

        // Cache is partitioned.
        cfg.setCacheMode(CacheMode.PARTITIONED);

        cfg.setOnheapCacheEnabled(true);

        cfg.setBackups(0);

        cfg.setName(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME);

        return ignite.getOrCreateCache(cfg);
    }

    /**
     * Construct FeatureKey from index, uuid and affinity key.
     *
     * @param idx Feature index.
     * @param uuid UUID of training.
     * @param aff Affinity key.
     * @return FeatureKey.
     */
    public static FeatureKey getFeatureCacheKey(int idx, UUID uuid, Object aff) {
        return new FeatureKey(idx, uuid, aff);
    }

    /**
     * Clear all data from features cache related to given training.
     * <p>
     * Builds one key per feature index in {@code [0, featuresCnt)} and removes all of them in a single
     * {@code removeAll} call.
     *
     * @param featuresCnt Count of features.
     * @param affinity Affinity function; applied to each feature index and {@code ignite} to get the affinity key.
     * @param uuid Training uuid.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
        Ignite ignite) {
        Set<FeatureKey> toRmv = IntStream.range(0, featuresCnt)
            .mapToObj(fIdx -> getFeatureCacheKey(fIdx, uuid, affinity.apply(fIdx, ignite)))
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(toRmv);
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;

/**
 * Cache holding blocks of region projections on features, keyed by {@link RegionKey}.
 */
public class ProjectionsCache {
    /**
     * Name of the cache that stores region projections on features of {@link ColumnDecisionTreeTrainer}.
     */
    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_PROJECTIONS_CACHE_NAME";

    /**
     * Key of the region projections cache.
     * <p>
     * Two keys are equal when feature index, block index and training UUID match; the affinity key
     * ({@code parentColKey}) is not compared and not hashed.
     */
    public static class RegionKey {
        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Feature index. */
        private final int featureIdx;

        /** Index of the block of regions. */
        private final int regBlockIdx;

        /** Training UUID. */
        private final UUID trainingUUID;

        /**
         * Create a key.
         *
         * @param featureIdx Feature index.
         * @param regBlockIdx Index of block.
         * @param parentColKey Key of column in input cache.
         * @param trainingUUID UUID of training.
         */
        public RegionKey(int featureIdx, int regBlockIdx, Object parentColKey, UUID trainingUUID) {
            this.parentColKey = parentColKey;
            this.trainingUUID = trainingUUID;
            this.regBlockIdx = regBlockIdx;
            this.featureIdx = featureIdx;
        }

        /**
         * @return Feature index.
         */
        public int featureIdx() {
            return featureIdx;
        }

        /**
         * @return Region block index.
         */
        public int regionBlockIndex() {
            return regBlockIdx;
        }

        /**
         * @return UUID of training.
         */
        public UUID trainingUUID() {
            return trainingUUID;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;

            if (o == null || getClass() != o.getClass())
                return false;

            RegionKey that = (RegionKey)o;

            return featureIdx == that.featureIdx
                && regBlockIdx == that.regBlockIdx
                && (trainingUUID == null ? that.trainingUUID == null : trainingUUID.equals(that.trainingUUID));
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int hash = trainingUUID == null ? 0 : trainingUUID.hashCode();

            hash = 31 * hash + featureIdx;

            return 31 * hash + regBlockIdx;
        }

        /** {@inheritDoc} */
        @Override public String toString() {
            return "RegionKey [" + "parentColKey=" + parentColKey + ", featureIdx=" + featureIdx +
                ", regBlockIdx=" + regBlockIdx + ", trainingUUID=" + trainingUUID + ']';
        }
    }

    /**
     * Look up the affinity of the cache named {@link #CACHE_NAME} on the local Ignite instance.
     *
     * @return Affinity service for region projections cache.
     */
    public static Affinity<RegionKey> affinity() {
        return Ignition.localIgnite().affinity(CACHE_NAME);
    }

    /**
     * Get the region projections cache, creating it if it does not exist yet.
     *
     * @param ignite Ignite instance.
     * @return Region projections cache.
     */
    public static IgniteCache<RegionKey, List<RegionProjection>> getOrCreate(Ignite ignite) {
        CacheConfiguration<RegionKey, List<RegionProjection>> ccfg = new CacheConfiguration<>();

        ccfg.setName(CACHE_NAME);

        // Partitioned, atomic, no backups.
        ccfg.setCacheMode(CacheMode.PARTITIONED);
        ccfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
        ccfg.setBackups(0);

        // Writes are acknowledged by the primary node.
        ccfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Entries are kept on heap, never evicted and handed out without copying.
        ccfg.setOnheapCacheEnabled(true);
        ccfg.setEvictionPolicy(null);
        ccfg.setCopyOnRead(false);

        return ignite.getOrCreateCache(ccfg);
    }

    /**
     * Collect projections of the requested regions on one feature, keeping only those whose depth is below
     * {@code maxDepth}.
     * <p>
     * Region {@code i} is read from the locally peeked block number {@code i / blockSize} at position
     * {@code i % blockSize}; a block is re-read only when the block number changes between consecutive indexes.
     *
     * @param featureIdx Feature index.
     * @param maxDepth Max depth of decision tree.
     * @param regionIndexes Indexes of regions for which we want get projections.
     * @param blockSize Size of regions block.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @param ignite Ignite instance.
     * @return Region projections in the form of map (regionIndex -> regionProjections).
     * @throws IllegalStateException If the block for some requested region is absent from the local cache.
     */
    public static Map<Integer, RegionProjection> projectionsOfRegions(int featureIdx, int maxDepth,
        IntStream regionIndexes, int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID,
        Ignite ignite) {
        HashMap<Integer, RegionProjection> res = new HashMap<>();
        IgniteCache<RegionKey, List<RegionProjection>> projCache = getOrCreate(ignite);

        PrimitiveIterator.OfInt idxIter = regionIndexes.iterator();

        Object affKey = affinity.apply(featureIdx);

        List<RegionProjection> loadedBlock = null;
        int loadedBlockIdx = -1;

        while (idxIter.hasNext()) {
            int regIdx = idxIter.nextInt();
            int wantedBlockIdx = regIdx / blockSize;

            // Switch to another block only when the region falls outside the one already loaded.
            if (wantedBlockIdx != loadedBlockIdx) {
                loadedBlock = projCache.localPeek(key(featureIdx, wantedBlockIdx, affKey, trainingUUID));
                loadedBlockIdx = wantedBlockIdx;
            }

            if (loadedBlock == null)
                throw new IllegalStateException("Unexpected null block at index " + regIdx);

            RegionProjection proj = loadedBlock.get(regIdx % blockSize);

            if (proj.depth() < maxDepth)
                res.put(regIdx, proj);
        }

        return res;
    }

    /**
     * Same as {@link #projectionsOfRegions} applied to all region indexes in {@code [0, regsCnt)}.
     *
     * @param featureIdx Feature index.
     * @param maxDepth Maximal depth of the tree.
     * @param regsCnt Count of regions.
     * @param blockSize Size of regions blocks.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @param ignite Ignite instance.
     * @return Projections of regions on given feature filtered by maximal depth in the form of
     * (region index -> region projection).
     */
    public static Map<Integer, RegionProjection> projectionsOfFeature(int featureIdx, int maxDepth, int regsCnt,
        int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID, Ignite ignite) {
        IntStream allRegions = IntStream.range(0, regsCnt);

        return projectionsOfRegions(featureIdx, maxDepth, allRegions, blockSize, affinity, trainingUUID, ignite);
    }

    /**
     * Construct key for projections cache.
     *
     * @param featureIdx Feature index.
     * @param regBlockIdx Region block index.
     * @param parentColKey Column key of cache used as input for {@link ColumnDecisionTreeTrainer}.
     * @param uuid UUID of training.
     * @return Key for projections cache.
     */
    public static RegionKey key(int featureIdx, int regBlockIdx, Object parentColKey, UUID uuid) {
        return new RegionKey(featureIdx, regBlockIdx, parentColKey, uuid);
    }

    /**
     * Remove entries of the given training from the projections cache.
     * <p>
     * One key is built for every pair (feature index in {@code [0, featuresCnt)}, block index in
     * {@code [0, regs)}) and all keys are removed with a single {@code removeAll} call.
     *
     * @param featuresCnt Features count.
     * @param regs Regions count.
     * @param aff Affinity function.
     * @param uuid UUID of training.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, int regs, IgniteBiFunction<Integer, Ignite, Object> aff, UUID uuid,
        Ignite ignite) {
        Set<RegionKey> keys = IntStream.range(0, featuresCnt)
            .boxed()
            .flatMap(fIdx -> IntStream.range(0, regs)
                .mapToObj(blockIdx -> key(fIdx, blockIdx, aff.apply(fIdx, ignite), uuid)))
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(keys);
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.caches;

import java.util.Collection;
import java.util.Collections;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CacheAtomicityMode;
import org.apache.ignite.cache.CacheMode;
import org.apache.ignite.cache.CacheWriteSynchronizationMode;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cache.affinity.AffinityKeyMapped;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;

/**
 * Class for working with cache used for storing of best splits during training with {@link ColumnDecisionTreeTrainer}.
 */
public class SplitCache {
    /** Name of splits cache. */
    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_SPLIT_CACHE_NAME";

    /**
     * Class used for keys in the splits cache.
     * <p>
     * Equality and hash code are defined by the training UUID and the feature index only; the affinity key
     * ({@code parentColKey}) does not take part in either.
     */
    public static class SplitKey {
        /** UUID of current training. */
        private final UUID trainingUUID;

        /** Affinity key of input data. */
        @AffinityKeyMapped
        private final Object parentColKey;

        /** Index of feature by which the split is made. */
        private final int featureIdx;

        /**
         * Construct SplitKey.
         *
         * @param trainingUUID UUID of the training.
         * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node
         * as column with that feature in input.
         * @param featureIdx Feature index.
         */
        public SplitKey(UUID trainingUUID, Object parentColKey, int featureIdx) {
            this.trainingUUID = trainingUUID;
            this.featureIdx = featureIdx;
            this.parentColKey = parentColKey;
        }

        /**
         * Get UUID of current training.
         *
         * @return UUID of current training.
         */
        public UUID trainingUUID() {
            return trainingUUID;
        }

        /**
         * Get feature index.
         *
         * @return Feature index.
         */
        public int featureIdx() {
            return featureIdx;
        }

        /** {@inheritDoc} */
        @Override public boolean equals(Object o) {
            if (this == o)
                return true;
            if (o == null || getClass() != o.getClass())
                return false;

            SplitKey splitKey = (SplitKey)o;

            if (featureIdx != splitKey.featureIdx)
                return false;
            return trainingUUID != null ? trainingUUID.equals(splitKey.trainingUUID) : splitKey.trainingUUID == null;
        }

        /** {@inheritDoc} */
        @Override public int hashCode() {
            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
            res = 31 * res + featureIdx;
            return res;
        }
    }

    /**
     * Construct the key for splits cache.
     *
     * @param featureIdx Feature index.
     * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node as
     * column with that feature in input.
     * @param uuid UUID of current training.
     * @return Key for splits cache.
     */
    public static SplitKey key(int featureIdx, Object parentColKey, UUID uuid) {
        return new SplitKey(uuid, parentColKey, featureIdx);
    }

    /**
     * Get or create splits cache.
     * <p>
     * The cache is configured as partitioned, atomic, without backups, with on-heap caching enabled,
     * no eviction policy, no copy-on-read and {@code PRIMARY_SYNC} write synchronization.
     *
     * @param ignite Ignite instance.
     * @return Splits cache.
     */
    public static IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> getOrCreate(Ignite ignite) {
        CacheConfiguration<SplitKey, IgniteBiTuple<Integer, Double>> cfg = new CacheConfiguration<>();

        // Write to primary.
        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);

        // Atomic transactions only.
        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);

        // No eviction.
        cfg.setEvictionPolicy(null);

        // No copying of values.
        cfg.setCopyOnRead(false);

        // Cache is partitioned.
        cfg.setCacheMode(CacheMode.PARTITIONED);

        cfg.setBackups(0);

        cfg.setOnheapCacheEnabled(true);

        cfg.setName(CACHE_NAME);

        return ignite.getOrCreateCache(cfg);
    }

    /**
     * Affinity function used in splits cache.
     *
     * @return Affinity function used in splits cache.
     */
    public static Affinity<SplitKey> affinity() {
        return Ignition.localIgnite().affinity(CACHE_NAME);
    }

    /**
     * Returns local entries for keys corresponding to {@code featureIndexes}.
     * <p>
     * Keys are built for every feature index, mapped to nodes with the cache affinity, and only the keys mapped
     * to the local node are kept. The returned iterable peeks the values lazily: each call to
     * {@code iterator()} resolves the cache once and then performs one {@code localPeek} per local key.
     * Entries whose value is absent locally carry a {@code null} value.
     *
     * @param featureIndexes Index of features.
     * @param affinity Affinity function.
     * @param trainingUUID UUID of training.
     * @return local entries for keys corresponding to {@code featureIndexes}.
     */
    public static Iterable<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> localEntries(
        Set<Integer> featureIndexes,
        IgniteBiFunction<Integer, Ignite, Object> affinity,
        UUID trainingUUID) {
        Ignite ignite = Ignition.localIgnite();

        Set<SplitKey> keys = featureIndexes.stream()
            .map(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), trainingUUID))
            .collect(Collectors.toSet());

        Collection<SplitKey> locKeys = affinity().mapKeysToNodes(keys)
            .getOrDefault(ignite.cluster().localNode(), Collections.emptyList());

        return () -> {
            // Nothing is local: do not touch (and possibly create) the cache at all.
            if (locKeys.isEmpty())
                return Collections.emptyIterator();

            // Resolve the cache once per iteration instead of once per key.
            IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> cache = getOrCreate(ignite);

            Function<SplitKey, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> toEntry =
                k -> new CacheEntryImpl<>(k, cache.localPeek(k));

            return locKeys.stream().map(toEntry).iterator();
        };
    }

    /**
     * Clears data related to the given training from splits cache.
     * <p>
     * Builds one key per feature index in {@code [0, featuresCnt)} and removes all of them in a single
     * {@code removeAll} call.
     *
     * @param featuresCnt Count of features.
     * @param affinity Affinity function.
     * @param uuid UUID of the given training.
     * @param ignite Ignite instance.
     */
    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
        Ignite ignite) {
        Set<SplitKey> toRmv = IntStream.range(0, featuresCnt)
            .mapToObj(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), uuid))
            .collect(Collectors.toSet());

        getOrCreate(ignite).removeAll(toRmv);
    }
}
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; + +import org.apache.ignite.Ignite; +import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput; + +/** Continuous Split Calculators. */ +public class ContinuousSplitCalculators { + /** Variance split calculator. */ + public static IgniteFunction<ColumnDecisionTreeTrainerInput, VarianceSplitCalculator> VARIANCE = input -> + new VarianceSplitCalculator(); + + /** Gini split calculator. */ + public static IgniteCurriedBiFunction<Ignite, ColumnDecisionTreeTrainerInput, GiniSplitCalculator> GINI = ignite -> + input -> new GiniSplitCalculator(input.labels(ignite)); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java new file mode 100644 index 0000000..259c84c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs;

import it.unimi.dsi.fastutil.doubles.Double2IntArrayMap;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashMap;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.stream.DoubleStream;
import org.apache.ignite.ml.trees.ContinuousRegionInfo;
import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo;
import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;

/**
 * Calculator for Gini impurity.
 * <p>
 * Impurity of a set of labels is computed as {@code 1 - c2 / size^2}, where {@code c2} is the sum of squared
 * counts of each distinct label and {@code size} is the number of labels.
 */
public class GiniSplitCalculator implements ContinuousSplitCalculator<GiniSplitCalculator.GiniData> {
    /** Mapping assigning index to each member value. */
    private final Map<Double, Integer> mapping = new Double2IntArrayMap();

    /**
     * Create Gini split calculator from labels.
     * <p>
     * Every distinct label gets the next free index, in order of first appearance.
     *
     * @param labels Labels.
     */
    public GiniSplitCalculator(double[] labels) {
        int i = 0;

        for (double label : labels) {
            if (!mapping.containsKey(label)) {
                mapping.put(label, i);
                i++;
            }
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * The second argument is not used: the size is counted from the stream itself. An empty stream yields
     * impurity {@code 0.0}. Every label in the stream must have been passed to the constructor, otherwise the
     * index lookup fails with a {@link NullPointerException}.
     */
    @Override public GiniData calculateRegionInfo(DoubleStream s, int l) {
        PrimitiveIterator.OfDouble itr = s.iterator();

        Map<Double, Integer> m = new HashMap<>();

        int size = 0;

        while (itr.hasNext()) {
            size++;
            m.merge(itr.next(), 1, Integer::sum);
        }

        // Squares are taken in double arithmetic: int multiplication overflows for counts above 46340.
        double c2 = m.values().stream().mapToDouble(v -> v.doubleValue() * v).sum();

        int[] cnts = new int[mapping.size()];

        m.forEach((key, value) -> cnts[mapping.get(key)] = value);

        return new GiniData(size != 0 ? 1 - c2 / ((double)size * size) : 0.0, size, cnts, c2);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Walks the samples in the order given by {@code s} (elements of {@code s} are used as indexes into
     * {@code values} and {@code labels}; presumably {@code s} is sorted by value — confirm against the caller),
     * moving them one by one from the right part to the left part. Each time the value changes, the size-weighted
     * sum of the two impurities is compared with the best one found so far.
     *
     * @return Best split found, or {@code null} if {@code s} is empty or no split improves on the impurity of
     * the whole region.
     */
    @Override public SplitInfo<GiniData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx,
        GiniData d) {
        // Nothing to split.
        if (s.length == 0)
            return null;

        int size = d.getSize();

        double lg = 0.0;
        double rg = d.impurity();

        double lc2 = 0.0;
        double rc2 = d.c2;
        int lSize = 0;

        double minImpurity = d.impurity() * size;
        double curThreshold;
        double curImpurity;
        double threshold = Double.NEGATIVE_INFINITY;

        int nextIdx = s[0];
        int i = 1;

        // Running statistics: {left impurity, right impurity, left sum of squared counts, right sum of squared counts}.
        double[] lrImps = new double[] {0.0, d.impurity(), lc2, rc2};

        // Label counts of the candidate split being evaluated; initially everything is on the right.
        int[] lMapCur = new int[d.counts().length];
        int[] rMapCur = d.counts().clone();

        // Label counts of the best split found so far.
        int[] lMap = new int[d.counts().length];
        int[] rMap = d.counts().clone();

        // NOTE(review): once i == s.length - 1 after a break, the outer loop stops, so the boundary between the
        // last two elements of s is never evaluated as a split candidate. Looks like an off-by-one; confirm
        // against VarianceSplitCalculator before changing, since both calculators should agree.
        do {
            // Process all values equal to prev.
            while (i < s.length) {
                moveLeft(labels[nextIdx], i, size - i, lMapCur, rMapCur, lrImps);
                curImpurity = (i * lrImps[0] + (size - i) * lrImps[1]);
                curThreshold = values[nextIdx];

                // The comparison also advances nextIdx and i to the next sample.
                if (values[nextIdx] != values[(nextIdx = s[i++])]) {
                    if (curImpurity < minImpurity) {
                        lSize = i - 1;

                        lg = lrImps[0];
                        rg = lrImps[1];

                        lc2 = lrImps[2];
                        rc2 = lrImps[3];

                        System.arraycopy(lMapCur, 0, lMap, 0, lMapCur.length);
                        System.arraycopy(rMapCur, 0, rMap, 0, rMapCur.length);

                        minImpurity = curImpurity;
                        threshold = curThreshold;
                    }

                    break;
                }
            }
        }
        while (i < s.length - 1);

        if (lSize == size || lSize == 0)
            return null;

        GiniData lData = new GiniData(lg, lSize, lMap, lc2);
        int rSize = size - lSize;
        GiniData rData = new GiniData(rg, rSize, rMap, rc2);

        return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData);
    }

    /**
     * Add point to the left interval and remove it from the right interval and calculate necessary statistics on
     * intervals with new bounds.
     *
     * @param x Label of the moved point; must be one of the labels passed to the constructor.
     * @param lSize Size of the left interval after the move.
     * @param rSize Size of the right interval after the move.
     * @param lMap Label counts of the left interval, updated in place.
     * @param rMap Label counts of the right interval, updated in place.
     * @param data Running statistics {left impurity, right impurity, left sum of squared counts, right sum of
     * squared counts}, updated in place.
     */
    private void moveLeft(double x, int lSize, int rSize, int[] lMap, int[] rMap, double[] data) {
        double lc2 = data[2];
        double rc2 = data[3];

        Integer idx = mapping.get(x);

        int cxl = lMap[idx];
        int cxr = rMap[idx];

        // (c + 1)^2 - c^2 = 2c + 1 and c^2 - (c - 1)^2 = 2c - 1.
        lc2 += 2 * cxl + 1;
        rc2 -= 2 * cxr - 1;

        lMap[idx] += 1;
        rMap[idx] -= 1;

        // Squares are taken in double arithmetic: int multiplication overflows for sizes above 46340.
        data[0] = 1 - lc2 / ((double)lSize * lSize);
        data[1] = 1 - rc2 / ((double)rSize * rSize);

        data[2] = lc2;
        data[3] = rc2;
    }

    /**
     * Data used for gini impurity calculations.
     */
    public static class GiniData extends ContinuousRegionInfo {
        /** Sum of squares of counts of each label. */
        private double c2;

        /** Counts of each label. On i-th position there is count of label which is mapped to index i. */
        private int[] m;

        /**
         * Create Gini data.
         *
         * @param impurity Impurity (i.e. Gini impurity).
         * @param size Count of samples.
         * @param m Counts of each label.
         * @param c2 Sum of squares of counts of each label.
         */
        public GiniData(double impurity, int size, int[] m, double c2) {
            super(impurity, size);
            this.m = m;
            this.c2 = c2;
        }

        /**
         * No-op constructor for serialization/deserialization.
         */
        public GiniData() {
            // No-op.
        }

        /**
         * Get counts of each label.
         *
         * @return Counts of each label (the internal array, not a copy).
         */
        public int[] counts() {
            return m;
        }

        /**
         * {@inheritDoc}
         * <p>
         * After the parent state, writes the sum of squared counts, the number of counts and the counts themselves.
         */
        @Override public void writeExternal(ObjectOutput out) throws IOException {
            super.writeExternal(out);
            out.writeDouble(c2);
            out.writeInt(m.length);
            for (int i : m)
                out.writeInt(i);
        }

        /**
         * {@inheritDoc}
         * <p>
         * Reads the fields in the order in which {@link #writeExternal(ObjectOutput)} writes them.
         */
        @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            super.readExternal(in);

            c2 = in.readDouble();
            int size = in.readInt();
            m = new int[size];

            for (int i = 0; i < size; i++)
                m[i] = in.readInt();
        }
    }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.PrimitiveIterator; +import java.util.stream.DoubleStream; +import org.apache.ignite.ml.trees.ContinuousRegionInfo; +import org.apache.ignite.ml.trees.ContinuousSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; + +/** + * Calculator of variance in a given region. + */ +public class VarianceSplitCalculator implements ContinuousSplitCalculator<VarianceSplitCalculator.VarianceData> { + /** + * Data used in variance calculations. + */ + public static class VarianceData extends ContinuousRegionInfo { + /** Mean value in a given region. */ + double mean; + + /** + * @param var Variance in this region. + * @param size Size of data for which variance is calculated. + * @param mean Mean value in this region. + */ + public VarianceData(double var, int size, double mean) { + super(var, size); + this.mean = mean; + } + + /** + * No-op constructor. For serialization/deserialization. + */ + public VarianceData() { + // No-op. 
+ } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + out.writeDouble(mean); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + mean = in.readDouble(); + } + + /** + * Returns mean. + */ + public double mean() { + return mean; + } + } + + /** {@inheritDoc} */ + @Override public VarianceData calculateRegionInfo(DoubleStream s, int size) { + PrimitiveIterator.OfDouble itr = s.iterator(); + int i = 0; + + double mean = 0.0; + double m2 = 0.0; + + // Here we calculate variance and mean by incremental computation. + while (itr.hasNext()) { + i++; + double x = itr.next(); + double delta = x - mean; + mean += delta / i; + double delta2 = x - mean; + m2 += delta * delta2; + } + + return new VarianceData(m2 / i, size, mean); + } + + /** {@inheritDoc} */ + @Override public SplitInfo<VarianceData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx, + VarianceData d) { + int size = d.getSize(); + + double lm2 = 0.0; + double rm2 = d.impurity() * size; + int lSize = size; + + double lMean = 0.0; + double rMean = d.mean; + + double minImpurity = d.impurity() * size; + double curThreshold; + double curImpurity; + double threshold = Double.NEGATIVE_INFINITY; + + int i = 0; + int nextIdx = s[0]; + i++; + double[] lrImps = new double[] {lm2, rm2, lMean, rMean}; + + do { + // Process all values equal to prev. 
+ while (i < s.length) { + moveLeft(labels[nextIdx], lrImps[2], i, lrImps[0], lrImps[3], size - i, lrImps[1], lrImps); + curImpurity = (lrImps[0] + lrImps[1]); + curThreshold = values[nextIdx]; + + if (values[nextIdx] != values[(nextIdx = s[i++])]) { + if (curImpurity < minImpurity) { + lSize = i - 1; + + lm2 = lrImps[0]; + rm2 = lrImps[1]; + + lMean = lrImps[2]; + rMean = lrImps[3]; + + minImpurity = curImpurity; + threshold = curThreshold; + } + + break; + } + } + } + while (i < s.length - 1); + + if (lSize == size) + return null; + + VarianceData lData = new VarianceData(lm2 / (lSize != 0 ? lSize : 1), lSize, lMean); + int rSize = size - lSize; + VarianceData rData = new VarianceData(rm2 / (rSize != 0 ? rSize : 1), rSize, rMean); + + return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData); + } + + /** + * Add point to the left interval and remove it from the right interval and calculate necessary statistics on + * intervals with new bounds. + */ + private void moveLeft(double x, double lMean, int lSize, double lm2, double rMean, int rSize, double rm2, + double[] data) { + // We add point to the left interval. + double lDelta = x - lMean; + double lMeanNew = lMean + lDelta / lSize; + double lm2New = lm2 + lDelta * (x - lMeanNew); + + // We remove point from the right interval. lSize + 1 is the size of right interval before removal. 
+ double rMeanNew = (rMean * (rSize + 1) - x) / rSize; + double rm2New = rm2 - (x - rMean) * (x - rMeanNew); + + data[0] = lm2New; + data[1] = rm2New; + + data[2] = lMeanNew; + data[3] = rMeanNew; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java new file mode 100644 index 0000000..08c8a75 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Calculators of splits by continuous features. 
+ */ +package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java new file mode 100644 index 0000000..8523914 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains column based decision tree algorithms. 
+ */ +package org.apache.ignite.ml.trees.trainers.columnbased; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java new file mode 100644 index 0000000..5c4b354 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs; + +import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.PrimitiveIterator; +import java.util.stream.DoubleStream; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput; + +/** Some commonly used functions for calculations of regions of space which correspond to decision tree leaf nodes. */ +public class RegionCalculators { + /** Mean value in the region. */ + public static final IgniteFunction<DoubleStream, Double> MEAN = s -> s.average().orElse(0.0); + + /** Most common value in the region. */ + public static final IgniteFunction<DoubleStream, Double> MOST_COMMON = + s -> { + PrimitiveIterator.OfDouble itr = s.iterator(); + Map<Double, Integer> voc = new HashMap<>(); + + while (itr.hasNext()) + voc.compute(itr.next(), (d, i) -> i != null ? i + 1 : 0); + + return voc.entrySet().stream().max(Comparator.comparing(Map.Entry::getValue)).map(Map.Entry::getKey).orElse(0.0); + }; + + /** Variance of a region. */ + public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> VARIANCE = input -> + s -> { + PrimitiveIterator.OfDouble itr = s.iterator(); + int i = 0; + + double mean = 0.0; + double m2 = 0.0; + + while (itr.hasNext()) { + i++; + double x = itr.next(); + double delta = x - mean; + mean += delta / i; + double delta2 = x - mean; + m2 += delta * delta2; + } + + return i > 0 ? m2 / i : 0.0; + }; + + /** Gini impurity of a region. 
*/ + public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> GINI = input -> + s -> { + PrimitiveIterator.OfDouble itr = s.iterator(); + + Double2IntOpenHashMap m = new Double2IntOpenHashMap(); + + int size = 0; + + while (itr.hasNext()) { + size++; + m.compute(itr.next(), (a, i) -> i != null ? i + 1 : 1); + } + + double c2 = m.values().stream().mapToDouble(v -> v * v).sum(); + + return size != 0 ? 1 - c2 / (size * size) : 0.0; + }; +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java new file mode 100644 index 0000000..e8edd8f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Region calculators. 
+ */ +package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java new file mode 100644 index 0000000..9469768 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import com.zaxxer.sparsebits.SparseBitSet; +import java.util.Arrays; +import java.util.BitSet; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.trees.CategoricalRegionInfo; +import org.apache.ignite.ml.trees.CategoricalSplitInfo; +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; + +import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet; + +/** + * Categorical feature vector processor implementation used by {@see ColumnDecisionTreeTrainer}. + */ +public class CategoricalFeatureProcessor + implements FeatureProcessor<CategoricalRegionInfo, CategoricalSplitInfo<CategoricalRegionInfo>> { + /** Count of categories for this feature. */ + private final int catsCnt; + + /** Function for calculating impurity of a given region of points. */ + private final IgniteFunction<DoubleStream, Double> calc; + + /** + * @param calc Function for calculating impurity of a given region of points. + * @param catsCnt Number of categories. + */ + public CategoricalFeatureProcessor(IgniteFunction<DoubleStream, Double> calc, int catsCnt) { + this.calc = calc; + this.catsCnt = catsCnt; + } + + /** */ + private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping, + Integer[] sampleIndexes, double[] values, double[] labels, double impurity) { + Map<Boolean, List<Integer>> leftRight = Arrays.stream(sampleIndexes). 
+ collect(Collectors.partitioningBy((smpl) -> leftCats.get(mapping.get((int)values[smpl])))); + + List<Integer> left = leftRight.get(true); + int leftSize = left.size(); + double leftImpurity = calc.apply(left.stream().mapToDouble(s -> labels[s])); + + List<Integer> right = leftRight.get(false); + int rightSize = right.size(); + double rightImpurity = calc.apply(right.stream().mapToDouble(s -> labels[s])); + + int totalSize = leftSize + rightSize; + + // Result of this call will be sent back to trainer node, we do not need vectors inside of sent data. + CategoricalSplitInfo<CategoricalRegionInfo> res = new CategoricalSplitInfo<>(intervalIdx, + new CategoricalRegionInfo(leftImpurity, null), // cats can be computed on the last step. + new CategoricalRegionInfo(rightImpurity, null), + leftCats); + + res.setInfoGain(impurity - (double)leftSize / totalSize * leftImpurity - (double)rightSize / totalSize * rightImpurity); + return res; + } + + /** + * Get a stream of subsets given categories count. + * + * @param catsCnt categories count. + * @return Stream of subsets given categories count. + */ + private Stream<BitSet> powerSet(int catsCnt) { + Iterable<BitSet> iterable = () -> new PSI(catsCnt); + return StreamSupport.stream(iterable.spliterator(), false); + } + + /** {@inheritDoc} */ + @Override public SplitInfo findBestSplit(RegionProjection<CategoricalRegionInfo> regionPrj, double[] values, + double[] labels, int regIdx) { + Map<Integer, Integer> mapping = mapping(regionPrj.data().cats()); + + return powerSet(regionPrj.data().cats().length()). + map(s -> split(s, regIdx, mapping, regionPrj.sampleIndexes(), values, labels, regionPrj.data().impurity())). + max(Comparator.comparingDouble(SplitInfo::infoGain)). 
+ orElse(null); + } + + /** {@inheritDoc} */ + @Override public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, + double[] values, double[] labels) { + BitSet set = new BitSet(); + set.set(0, catsCnt); + + Double impurity = calc.apply(Arrays.stream(labels)); + + return new RegionProjection<>(sampleIndexes, new CategoricalRegionInfo(impurity, set), 0); + } + + /** {@inheritDoc} */ + @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<CategoricalRegionInfo> regionPrj, + double[] values, + CategoricalSplitInfo<CategoricalRegionInfo> s) { + SparseBitSet res = new SparseBitSet(); + Arrays.stream(regionPrj.sampleIndexes()).forEach(smpl -> res.set(smpl, s.bitSet().get((int)values[smpl]))); + return res; + } + + /** {@inheritDoc} */ + @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, + RegionProjection<CategoricalRegionInfo> reg, CategoricalRegionInfo leftData, CategoricalRegionInfo rightData) { + return performSplitGeneric(bs, null, reg, leftData, rightData); + } + + /** {@inheritDoc} */ + @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric( + SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, + RegionInfo rightData) { + int depth = reg.depth(); + + int lSize = bs.cardinality(); + int rSize = reg.sampleIndexes().length - lSize; + IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); + BitSet leftCats = calculateCats(lrSamples.get1(), values); + CategoricalRegionInfo lInfo = new CategoricalRegionInfo(leftData.impurity(), leftCats); + + // TODO: IGNITE-5892 Check how it will work with sparse data. 
+ BitSet rightCats = calculateCats(lrSamples.get2(), values); + CategoricalRegionInfo rInfo = new CategoricalRegionInfo(rightData.impurity(), rightCats); + + RegionProjection<CategoricalRegionInfo> rPrj = new RegionProjection<>(lrSamples.get2(), rInfo, depth + 1); + RegionProjection<CategoricalRegionInfo> lPrj = new RegionProjection<>(lrSamples.get1(), lInfo, depth + 1); + return new IgniteBiTuple<>(lPrj, rPrj); + } + + /** + * Powerset iterator. Iterates not over the whole powerset, but on half of it. + */ + private static class PSI implements Iterator<BitSet> { + + /** Current subset number. */ + private int i = 1; // We are not interested in {emptyset, set} split and therefore start from 1. + + /** Size of set, subsets of which we iterate over. */ + final int size; + + /** + * @param bitCnt Size of set, subsets of which we iterate over. + */ + PSI(int bitCnt) { + this.size = 1 << (bitCnt - 1); + } + + /** {@inheritDoc} */ + @Override public boolean hasNext() { + return i < size; + } + + /** {@inheritDoc} */ + @Override public BitSet next() { + BitSet res = BitSet.valueOf(new long[] {i}); + i++; + return res; + } + } + + /** */ + private Map<Integer, Integer> mapping(BitSet bs) { + int bn = 0; + Map<Integer, Integer> res = new HashMap<>(); + + int i = 0; + while ((bn = bs.nextSetBit(bn)) != -1) { + res.put(bn, i); + i++; + bn++; + } + + return res; + } + + /** Get set of categories of given samples */ + private BitSet calculateCats(Integer[] sampleIndexes, double[] values) { + BitSet res = new BitSet(); + + for (int smpl : sampleIndexes) + res.set((int)values[smpl]); + + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java new file mode 100644 index 0000000..4117993 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import com.zaxxer.sparsebits.SparseBitSet; +import java.util.Arrays; +import java.util.Comparator; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.trees.ContinuousRegionInfo; +import org.apache.ignite.ml.trees.ContinuousSplitCalculator; +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; + +import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet; + +/** + * Container of projection of samples on continuous feature. + * + * @param <D> Information about regions. Designed to contain information which will make computations of impurity + * optimal. 
+ */ +public class ContinuousFeatureProcessor<D extends ContinuousRegionInfo> implements + FeatureProcessor<D, ContinuousSplitInfo<D>> { + /** ContinuousSplitCalculator used for calculating of best split of each region. */ + private final ContinuousSplitCalculator<D> calc; + + /** + * @param splitCalc Calculator used for calculating splits. + */ + public ContinuousFeatureProcessor(ContinuousSplitCalculator<D> splitCalc) { + this.calc = splitCalc; + } + + /** {@inheritDoc} */ + @Override public SplitInfo<D> findBestSplit(RegionProjection<D> ri, double[] values, double[] labels, int regIdx) { + SplitInfo<D> res = calc.splitRegion(ri.sampleIndexes(), values, labels, regIdx, ri.data()); + + if (res == null) + return null; + + double lWeight = (double)res.leftData.getSize() / ri.sampleIndexes().length; + double rWeight = (double)res.rightData.getSize() / ri.sampleIndexes().length; + + double infoGain = ri.data().impurity() - lWeight * res.leftData().impurity() - rWeight * res.rightData().impurity(); + res.setInfoGain(infoGain); + + return res; + } + + /** {@inheritDoc} */ + @Override public RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels) { + Arrays.sort(samples, Comparator.comparingDouble(s -> values[s])); + return new RegionProjection<>(samples, calc.calculateRegionInfo(Arrays.stream(labels), samples.length), 0); + } + + /** {@inheritDoc} */ + @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<D> reg, double[] values, + ContinuousSplitInfo<D> s) { + SparseBitSet res = new SparseBitSet(); + + for (int i = 0; i < s.leftData().getSize(); i++) + res.set(reg.sampleIndexes()[i]); + + return res; + } + + /** {@inheritDoc} */ + @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, + RegionProjection<D> reg, D leftData, D rightData) { + int lSize = leftData.getSize(); + int rSize = rightData.getSize(); + int depth = reg.depth(); + + IgniteBiTuple<Integer[], Integer[]> 
lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); + + RegionProjection<D> left = new RegionProjection<>(lrSamples.get1(), leftData, depth + 1); + RegionProjection<D> right = new RegionProjection<>(lrSamples.get2(), rightData, depth + 1); + + return new IgniteBiTuple<>(left, right); + } + + /** {@inheritDoc} */ + @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, + double[] labels, RegionProjection<D> reg, RegionInfo leftData, RegionInfo rightData) { + int lSize = bs.cardinality(); + int rSize = reg.sampleIndexes().length - lSize; + int depth = reg.depth(); + + IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs); + + D ld = calc.calculateRegionInfo(Arrays.stream(lrSamples.get1()).mapToDouble(s -> labels[s]), lSize); + D rd = calc.calculateRegionInfo(Arrays.stream(lrSamples.get2()).mapToDouble(s -> labels[s]), rSize); + + return new IgniteBiTuple<>(new RegionProjection<>(lrSamples.get1(), ld, depth + 1), new RegionProjection<>(lrSamples.get2(), rd, depth + 1)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java new file mode 100644 index 0000000..d6f2847 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java @@ -0,0 +1,54 @@ +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.nodes.ContinuousSplitNode; +import org.apache.ignite.ml.trees.nodes.SplitNode; + +/** + * 
Information about split of continuous region. + * + * @param <D> Class encapsulating information about the region. + */ +public class ContinuousSplitInfo<D extends RegionInfo> extends SplitInfo<D> { + /** + * Threshold used for split. + * Samples with values less or equal than this go to left region, others go to the right region. + */ + private final double threshold; + + /** + * @param regionIdx Index of region being split. + * @param threshold Threshold used for split. Samples with values less or equal than this go to left region, others + * go to the right region. + * @param leftData Information about left subregion. + * @param rightData Information about right subregion. + */ + public ContinuousSplitInfo(int regionIdx, double threshold, D leftData, D rightData) { + super(regionIdx, leftData, rightData); + this.threshold = threshold; + } + + /** {@inheritDoc} */ + @Override public SplitNode createSplitNode(int featureIdx) { + return new ContinuousSplitNode(threshold, featureIdx); + } + + /** + * Threshold used for splits. + * Samples with values less or equal than this go to left region, others go to the right region. 
+ */ + public double threshold() { + return threshold; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "ContinuousSplitInfo [" + + "threshold=" + threshold + + ", infoGain=" + infoGain + + ", regionIdx=" + regionIdx + + ", leftData=" + leftData + + ", rightData=" + rightData + + ']'; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java new file mode 100644 index 0000000..cb8f5c2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import com.zaxxer.sparsebits.SparseBitSet; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection; + +/** + * Base interface for feature processors used in {@see org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer} + * + * @param <D> Class representing data of regions resulted from split. + * @param <S> Class representing data of split. + */ +public interface FeatureProcessor<D extends RegionInfo, S extends SplitInfo<D>> { + /** + * Finds best split by this feature among all splits of all regions. + * + * @return best split by this feature among all splits of all regions. + */ + SplitInfo findBestSplit(RegionProjection<D> regionPrj, double[] values, double[] labels, int regIdx); + + /** + * Creates initial region from samples. + * + * @param samples samples. + * @return region. + */ + RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels); + + /** + * Calculates the bitset mapping each data point to left (corresponding bit is set) or right subregion. + * + * @param s data used for calculating the split. + * @return Bitset mapping each data point to left (corresponding bit is set) or right subregion. + */ + SparseBitSet calculateOwnershipBitSet(RegionProjection<D> regionPrj, double[] values, S s); + + /** + * Splits given region using bitset which maps data point to left or right subregion. + * This method is present for the vectors of the same type to be able to pass between them information about regions + * and therefore used iff the optimal split is received on feature of the same type. + * + * @param bs Bitset which maps data point to left or right subregion. + * @param leftData Data of the left subregion. + * @param rightData Data of the right subregion. + * @return This feature vector. 
+ */ + IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, RegionProjection<D> reg, D leftData, + D rightData); + + /** + * Splits given region using bitset which maps data point to left or right subregion. This method is used iff the + * optimal split is received on feature of different type, therefore information about regions is limited to the + * {@link RegionInfo} class which is base for all classes used to represent region data. + * + * @param bs Bitset which maps data point to left or right subregion. + * @param leftData Data of the left subregion. + * @param rightData Data of the right subregion. + * @return Pair of region projections resulting from the split. + */ + IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, + RegionProjection<D> reg, RegionInfo leftData, + RegionInfo rightData); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java new file mode 100644 index 0000000..69ff019 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import com.zaxxer.sparsebits.SparseBitSet; +import org.apache.ignite.lang.IgniteBiTuple; + +/** Utility class for feature vector processors. */ +public class FeatureVectorProcessorUtils { + /** + * Split target array into two (left and right) arrays by bitset. + * + * @param lSize Left array size. + * @param rSize Right array size. + * @param samples Array to split. + * @param bs Bitset specifying split. + * @return BiTuple containing result of split. + */ + public static IgniteBiTuple<Integer[], Integer[]> splitByBitSet(int lSize, int rSize, Integer[] samples, + SparseBitSet bs) { + Integer[] lArr = new Integer[lSize]; + Integer[] rArr = new Integer[rSize]; + + int lc = 0; + int rc = 0; + + for (int i = 0; i < lSize + rSize; i++) { + int si = samples[i]; + + if (bs.get(si)) { + lArr[lc] = si; + lc++; + } + else { + rArr[rc] = si; + rc++; + } + } + + return new IgniteBiTuple<>(lArr, rArr); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java new file mode 100644 index 0000000..8aa4f79 --- /dev/null +++
b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +/** + * Information about given sample within given fixed feature. + */ +public class SampleInfo implements Externalizable { + /** Value of projection of this sample on given fixed feature. */ + private double val; + + /** Sample index. */ + private int sampleIdx; + + /** + * @param val Value of projection of this sample on given fixed feature. + * @param sampleIdx Sample index. + */ + public SampleInfo(double val, int sampleIdx) { + this.val = val; + this.sampleIdx = sampleIdx; + } + + /** + * No-op constructor used for serialization/deserialization. + */ + public SampleInfo() { + // No-op. + } + + /** + * Get the value of projection of this sample on given fixed feature. + * + * @return Value of projection of this sample on given fixed feature. + */ + public double val() { + return val; + } + + /** + * Get the sample index. + * + * @return Sample index. 
+ */ + public int sampleInd() { + return sampleIdx; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeDouble(val); + out.writeInt(sampleIdx); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + val = in.readDouble(); + sampleIdx = in.readInt(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java new file mode 100644 index 0000000..124e82f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; + +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.nodes.SplitNode; + +/** + * Class encapsulating information about the split. + * + * @param <D> Class representing information of left and right subregions. + */ +public abstract class SplitInfo<D extends RegionInfo> { + /** Information gain of this split. */ + protected double infoGain; + + /** Index of the region to split. */ + protected final int regionIdx; + + /** Data of left subregion. */ + protected final D leftData; + + /** Data of right subregion. */ + protected final D rightData; + + /** + * Construct the split info. + * + * @param regionIdx Index of the region to split. + * @param leftData Data of left subregion. + * @param rightData Data of right subregion. + */ + public SplitInfo(int regionIdx, D leftData, D rightData) { + this.regionIdx = regionIdx; + this.leftData = leftData; + this.rightData = rightData; + } + + /** + * Index of region to split. + * + * @return Index of region to split. + */ + public int regionIndex() { + return regionIdx; + } + + /** + * Information gain of the split. + * + * @return Information gain of the split. + */ + public double infoGain() { + return infoGain; + } + + /** + * Data of right subregion. + * + * @return Data of right subregion. + */ + public D rightData() { + return rightData; + } + + /** + * Data of left subregion. + * + * @return Data of left subregion. + */ + public D leftData() { + return leftData; + } + + /** + * Create SplitNode from this split info. + * + * @param featureIdx Index of the feature on which the split is made. + * @return SplitNode from this split info. + */ + public abstract SplitNode createSplitNode(int featureIdx); + + /** + * Set information gain. + * + * @param infoGain Information gain.
+ */ + public void setInfoGain(double infoGain) { + this.infoGain = infoGain; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java new file mode 100644 index 0000000..0dea204 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains feature containers needed by column based decision tree trainers. + */ +package org.apache.ignite.ml.trees.trainers.columnbased.vectors; \ No newline at end of file
