http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java new file mode 100644 index 0000000..e98bb72 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +/** + * Information about region used by continuous features. + */ +public class ContinuousRegionInfo extends RegionInfo { + /** + * Count of samples in this region. + */ + private int size; + + /** + * @param impurity Impurity of the region. + * @param size Size of this region + */ + public ContinuousRegionInfo(double impurity, int size) { + super(impurity); + this.size = size; + } + + /** + * No-op constructor for serialization/deserialization. 
+ */ + public ContinuousRegionInfo() { + // No-op + } + + /** + * Get the size of region. + */ + public int getSize() { + return size; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "ContinuousRegionInfo [" + + "size=" + size + + ']'; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + out.writeInt(size); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + super.readExternal(in); + size = in.readInt(); + } +} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java new file mode 100644 index 0000000..f9b81d0 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.util.stream.DoubleStream; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; + +/** + * This class is used for calculation of best split by continuous feature. + * + * @param <C> Class in which information about region will be stored. + */ +public interface ContinuousSplitCalculator<C extends ContinuousRegionInfo> { + /** + * Calculate region info 'from scratch'. + * + * @param s Stream of labels in this region. + * @param l Index of sample projection on this feature in array sorted by this projection value and intervals + * bitsets. 
({@link org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor}). + * @return Region info. + */ + C calculateRegionInfo(DoubleStream s, int l); + + /** + * Calculate split info of best split of region given information about this region. + * + * @param sampleIndexes Indexes of samples of this region. + * @param values All values of this feature. + * @param labels All labels of this feature. + * @param regionIdx Index of region being split. + * @param data Information about region being split which can be used for computations. + * @return Information about best split of region with index given by regionIdx. + */ + SplitInfo<C> splitRegion(Integer[] sampleIndexes, double[] values, double[] labels, int regionIdx, C data); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java new file mode 100644 index 0000000..8ec7db3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +/** Class containing information about region. */ +public class RegionInfo implements Externalizable { + /** Impurity in this region. */ + private double impurity; + + /** + * @param impurity Impurity of this region. + */ + public RegionInfo(double impurity) { + this.impurity = impurity; + } + + /** + * No-op constructor for serialization/deserialization. + */ + public RegionInfo() { + // No-op + } + + /** + * Get impurity in this region. + * + * @return Impurity of this region. + */ + public double impurity() { + return impurity; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeDouble(impurity); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + impurity = in.readDouble(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java new file mode 100644 index 0000000..86e9326 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.models; + +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.trees.nodes.DecisionTreeNode; + +/** + * Model for decision tree. + */ +public class DecisionTreeModel implements Model<Vector, Double> { + /** Root node of the decision tree. */ + private final DecisionTreeNode root; + + /** + * Construct decision tree model. + * + * @param root Root of decision tree. + */ + public DecisionTreeModel(DecisionTreeNode root) { + this.root = root; + } + + /** {@inheritDoc} */ + @Override public Double predict(Vector val) { + return root.process(val); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java new file mode 100644 index 0000000..ce8418e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains decision tree models. + */ +package org.apache.ignite.ml.trees.models; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java new file mode 100644 index 0000000..cae6d4a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.nodes; + +import java.util.BitSet; +import org.apache.ignite.ml.math.Vector; + +/** + * Split node by categorical feature. + */ +public class CategoricalSplitNode extends SplitNode { + /** Bitset specifying which categories belong to left subregion. */ + private final BitSet bs; + + /** + * Construct categorical split node. + * + * @param featureIdx Index of feature by which split is done. + * @param bs Bitset specifying which categories go to the left subtree. + */ + public CategoricalSplitNode(int featureIdx, BitSet bs) { + super(featureIdx); + this.bs = bs; + } + + /** {@inheritDoc} */ + @Override public boolean goLeft(Vector v) { + return bs.get((int)v.getX(featureIdx)); + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "CategoricalSplitNode [bs=" + bs + ']'; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java new file mode 100644 index 0000000..285cfcd --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.nodes; + +import org.apache.ignite.ml.math.Vector; + +/** + * Split node representing split of continuous feature. + */ +public class ContinuousSplitNode extends SplitNode { + /** Threshold. Values which are less than or equal to the threshold are assigned to the left subregion. */ + private final double threshold; + + /** + * Construct ContinuousSplitNode by threshold and feature index. + * + * @param threshold Threshold. + * @param featureIdx Feature index. + */ + public ContinuousSplitNode(double threshold, int featureIdx) { + super(featureIdx); + this.threshold = threshold; + } + + /** {@inheritDoc} */ + @Override public boolean goLeft(Vector v) { + return v.getX(featureIdx) <= threshold; + } + + /** Threshold. Values which are less than or equal to the threshold are assigned to the left subregion.
*/ + public double threshold() { + return threshold; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "ContinuousSplitNode [" + + "threshold=" + threshold + + ']'; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java new file mode 100644 index 0000000..d31623d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.nodes; + +import org.apache.ignite.ml.math.Vector; + +/** + * Node of decision tree. + */ +public interface DecisionTreeNode { + /** + * Assign the double value to the given vector. + * + * @param v Vector. + * @return Value assigned to the given vector. 
+ */ + double process(Vector v); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java new file mode 100644 index 0000000..79b441f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.nodes; + +import org.apache.ignite.ml.math.Vector; + +/** + * Terminal node of the decision tree. + */ +public class Leaf implements DecisionTreeNode { + /** + * Value in subregion represented by this node. + */ + private final double val; + + /** + * Construct the leaf of decision tree. + * + * @param val Value in subregion represented by this node. + */ + public Leaf(double val) { + this.val = val; + } + + /** + * Return value in subregion represented by this node. + * + * @param v Vector. + * @return Value in subregion represented by this node. 
+ */ + @Override public double process(Vector v) { + return val; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java new file mode 100644 index 0000000..4c258d1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.nodes; + +import org.apache.ignite.ml.math.Vector; + +/** + * Node in decision tree representing a split. + */ +public abstract class SplitNode implements DecisionTreeNode { + /** Left subtree. */ + protected DecisionTreeNode l; + + /** Right subtree. */ + protected DecisionTreeNode r; + + /** Feature index. */ + protected final int featureIdx; + + /** + * Constructs SplitNode with a given feature index. + * + * @param featureIdx Feature index. 
+ */ + public SplitNode(int featureIdx) { + this.featureIdx = featureIdx; + } + + /** + * Indicates whether the given vector should be routed to the left subtree. + * + * @param v Vector. + * @return {@code true} if the given vector goes to the left subtree, {@code false} otherwise. + */ + abstract boolean goLeft(Vector v); + + /** + * Left subtree. + * + * @return Left subtree. + */ + public DecisionTreeNode left() { + return l; + } + + /** + * Right subtree. + * + * @return Right subtree. + */ + public DecisionTreeNode right() { + return r; + } + + /** + * Set the left subtree. + * + * @param n left subtree. + */ + public void setLeft(DecisionTreeNode n) { + l = n; + } + + /** + * Set the right subtree. + * + * @param n right subtree. + */ + public void setRight(DecisionTreeNode n) { + r = n; + } + + /** + * Delegates processing to the left subtree if it is set and {@link #goLeft(Vector)} holds, otherwise to the right subtree. + * + * @param v Vector. + * @return Value assigned to the given vector. + */ + @Override public double process(Vector v) { + if (left() != null && goLeft(v)) + return left().process(v); + else + return right().process(v); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java new file mode 100644 index 0000000..d6deb9d --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains classes representing decision tree nodes. + */ +package org.apache.ignite.ml.trees.nodes; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java new file mode 100644 index 0000000..b07ba4a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains decision tree algorithms. 
+ */ +package org.apache.ignite.ml.trees; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java new file mode 100644 index 0000000..0d27c8a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.cache.affinity.AffinityKeyMapped; + +/** + * Class representing a simple index in 2d matrix in the form (row, col). + */ +public class BiIndex implements Externalizable { + /** Row. */ + private int row; + + /** Column. */ + @AffinityKeyMapped + private int col; + + /** + * No-op constructor for serialization/deserialization. 
+ */ + public BiIndex() { + // No-op. + } + + /** + * Construct BiIndex from row and column. + * + * @param row Row. + * @param col Column. + */ + public BiIndex(int row, int col) { + this.row = row; + this.col = col; + } + + /** + * Returns row. + * + * @return Row. + */ + public int row() { + return row; + } + + /** + * Returns column. + * + * @return Column. + */ + public int col() { + return col; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + BiIndex idx = (BiIndex)o; + + if (row != idx.row) + return false; + return col == idx.col; + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = row; + res = 31 * res + col; + return res; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "BiIndex [" + + "row=" + row + + ", col=" + col + + ']'; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(row); + out.writeInt(col); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + row = in.readInt(); + col = in.readInt(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java new file mode 100644 index 0000000..04281fb --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.util.Map; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.lang.IgniteBiTuple; + +/** + * Adapter for column decision tree trainer for bi-indexed cache. + */ +public class BiIndexedCacheColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<BiIndex, Double> { + /** + * Construct an input for {@link ColumnDecisionTreeTrainer}. + * + * @param cache Bi-indexed cache. + * @param catFeaturesInfo Information about categorical feature in the form (feature index -> number of + * categories). + * @param samplesCnt Count of samples. + * @param featuresCnt Count of features. 
+ */ + public BiIndexedCacheColumnDecisionTreeTrainerInput(IgniteCache<BiIndex, Double> cache, + Map<Integer, Integer> catFeaturesInfo, int samplesCnt, int featuresCnt) { + super(cache, + () -> IntStream.range(0, samplesCnt).mapToObj(s -> new BiIndex(s, featuresCnt)), + e -> Stream.of(new IgniteBiTuple<>(e.getKey().row(), e.getValue())), + DoubleStream::of, + fIdx -> IntStream.range(0, samplesCnt).mapToObj(s -> new BiIndex(s, fIdx)), + catFeaturesInfo, + featuresCnt, + samplesCnt); + } + + /** {@inheritDoc} */ + @Override public Object affinityKey(int idx, Ignite ignite) { + return idx; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java new file mode 100644 index 0000000..9518caf --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.internal.processors.cache.CacheEntryImpl; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; + +/** + * Adapter of a given cache to {@see CacheColumnDecisionTreeTrainerInput} + * + * @param <K> Class of keys of the cache. + * @param <V> Class of values of the cache. + */ +public abstract class CacheColumnDecisionTreeTrainerInput<K, V> implements ColumnDecisionTreeTrainerInput { + /** Supplier of labels key. */ + private final IgniteSupplier<Stream<K>> labelsKeys; + + /** Count of features. */ + private final int featuresCnt; + + /** Function which maps feature index to Stream of keys corresponding to this feature index. */ + private final IgniteFunction<Integer, Stream<K>> keyMapper; + + /** Information about which features are categorical in form of feature index -> number of categories. */ + private final Map<Integer, Integer> catFeaturesInfo; + + /** Cache name. */ + private final String cacheName; + + /** Count of samples. */ + private final int samplesCnt; + + /** Function used for mapping cache values to stream of tuples. */ + private final IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper; + + /** + * Function which map value of entry with label key to DoubleStream. + * Look at {@code CacheColumnDecisionTreeTrainerInput::labels} for understanding how {@code labelsKeys} and + * {@code labelsMapper} interact. 
+ */ + private final IgniteFunction<V, DoubleStream> labelsMapper; + + /** + * Constructs input for {@see org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer}. + * + * @param c Cache. + * @param valuesMapper Function for mapping cache entry to stream used by {@link + * org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer}. + * @param labelsMapper Function used for mapping cache value to labels array. + * @param keyMapper Function used for mapping feature index to the cache key. + * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of + * categories. + * @param featuresCnt Count of features. + * @param samplesCnt Count of samples. + */ + // TODO: IGNITE-5724 think about boxing/unboxing + public CacheColumnDecisionTreeTrainerInput(IgniteCache<K, V> c, + IgniteSupplier<Stream<K>> labelsKeys, + IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper, + IgniteFunction<V, DoubleStream> labelsMapper, + IgniteFunction<Integer, Stream<K>> keyMapper, + Map<Integer, Integer> catFeaturesInfo, + int featuresCnt, int samplesCnt) { + + cacheName = c.getName(); + this.labelsKeys = labelsKeys; + this.valuesMapper = valuesMapper; + this.labelsMapper = labelsMapper; + this.keyMapper = keyMapper; + this.catFeaturesInfo = catFeaturesInfo; + this.samplesCnt = samplesCnt; + this.featuresCnt = featuresCnt; + } + + /** {@inheritDoc} */ + @Override public Stream<IgniteBiTuple<Integer, Double>> values(int idx) { + return cache(Ignition.localIgnite()).getAll(keyMapper.apply(idx).collect(Collectors.toSet())). + entrySet(). + stream(). 
+ flatMap(ent -> valuesMapper.apply(new CacheEntryImpl<>(ent.getKey(), ent.getValue()))); + } + + /** {@inheritDoc} */ + @Override public double[] labels(Ignite ignite) { + return labelsKeys.get().map(k -> get(k, ignite)).flatMapToDouble(labelsMapper).toArray(); + } + + /** {@inheritDoc} */ + @Override public Map<Integer, Integer> catFeaturesInfo() { + return catFeaturesInfo; + } + + /** {@inheritDoc} */ + @Override public int featuresCount() { + return featuresCnt; + } + + /** {@inheritDoc} */ + @Override public Object affinityKey(int idx, Ignite ignite) { + return ignite.affinity(cacheName).affinityKey(keyMapper.apply(idx)); + } + + /** */ + private V get(K k, Ignite ignite) { + V res = cache(ignite).localPeek(k); + + if (res == null) + res = cache(ignite).get(k); + + return res; + } + + /** */ + private IgniteCache<K, V> cache(Ignite ignite) { + return ignite.getOrCreateCache(cacheName); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java new file mode 100644 index 0000000..32e33f3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java @@ -0,0 +1,557 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import com.zaxxer.sparsebits.SparseBitSet; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.CachePeekMode; +import org.apache.ignite.cache.affinity.Affinity; +import org.apache.ignite.cluster.ClusterNode; +import org.apache.ignite.internal.processors.cache.CacheEntryImpl; +import org.apache.ignite.internal.util.typedef.X; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distributed.CacheUtils; +import org.apache.ignite.ml.math.functions.Functions; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.trees.ContinuousRegionInfo; +import 
org.apache.ignite.ml.trees.ContinuousSplitCalculator; +import org.apache.ignite.ml.trees.models.DecisionTreeModel; +import org.apache.ignite.ml.trees.nodes.DecisionTreeNode; +import org.apache.ignite.ml.trees.nodes.Leaf; +import org.apache.ignite.ml.trees.nodes.SplitNode; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; +import org.jetbrains.annotations.NotNull; + +import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey; + +/** + * This trainer stores observations as columns and features as rows. + * Ideas from https://github.com/fabuzaid21/yggdrasil are used here. + */ +public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implements + Trainer<DecisionTreeModel, ColumnDecisionTreeTrainerInput> { + /** + * Function used to assign a value to a region. + */ + private final IgniteFunction<DoubleStream, Double> regCalc; + + /** + * Function used to calculate impurity in regions used by categorical features. + */ + private final IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider; + + /** + * Categorical calculator provider. 
+ **/ + private final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider; + + /** + * Cache used for storing data for training. + */ + private IgniteCache<RegionKey, List<RegionProjection>> prjsCache; + + /** + * Minimal information gain. + */ + private static final double MIN_INFO_GAIN = 1E-10; + + /** + * Maximal depth of the decision tree. + */ + private final int maxDepth; + + /** + * Size of block which is used for storing regions in cache. + */ + private static final int BLOCK_SIZE = 1 << 4; + + /** Ignite instance. */ + private final Ignite ignite; + + /** + * Construct {@link ColumnDecisionTreeTrainer}. + * + * @param maxDepth Maximal depth of the decision tree. + * @param continuousCalculatorProvider Provider of calculator of splits for region projection on continuous + * features. + * @param categoricalCalculatorProvider Provider of calculator of splits for region projection on categorical + * features. + * @param regCalc Function used to assign a value to a region. + */ + public ColumnDecisionTreeTrainer(int maxDepth, + IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider, + IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider, + IgniteFunction<DoubleStream, Double> regCalc, + Ignite ignite) { + this.maxDepth = maxDepth; + this.continuousCalculatorProvider = continuousCalculatorProvider; + this.categoricalCalculatorProvider = categoricalCalculatorProvider; + this.regCalc = regCalc; + this.ignite = ignite; + } + + /** + * Utility class used to get index of feature by which split is done and split info. + */ + private static class IndexAndSplitInfo { + /** + * Index of feature by which split is done. + */ + private final int featureIdx; + + /** + * Split information. 
+ */ + private final SplitInfo info; + + /** + * @param featureIdx Index of feature by which split is done. + * @param info Split information. + */ + IndexAndSplitInfo(int featureIdx, SplitInfo info) { + this.featureIdx = featureIdx; + this.info = info; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "IndexAndSplitInfo [featureIdx=" + featureIdx + ", info=" + info + ']'; + } + } + + /** + * Utility class used to build decision tree. Basically it is pointer to leaf node. + */ + private static class TreeTip { + /** */ + private Consumer<DecisionTreeNode> leafSetter; + + /** */ + private int depth; + + /** */ + TreeTip(Consumer<DecisionTreeNode> leafSetter, int depth) { + this.leafSetter = leafSetter; + this.depth = depth; + } + } + + /** + * Utility class used as decision tree root node. + */ + private static class RootNode implements DecisionTreeNode { + /** */ + private DecisionTreeNode s; + + /** + * {@inheritDoc} + */ + @Override public double process(Vector v) { + return s.process(v); + } + + /** */ + void setSplit(DecisionTreeNode s) { + this.s = s; + } + } + + /** + * {@inheritDoc} + */ + @Override public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) { + prjsCache = ProjectionsCache.getOrCreate(ignite); + IgniteCache<UUID, TrainingContext<D>> ctxtCache = ContextCache.getOrCreate(ignite); + SplitCache.getOrCreate(ignite); + + UUID trainingUUID = UUID.randomUUID(); + + TrainingContext<D> ct = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite); + ctxtCache.put(trainingUUID, ct); + + CacheUtils.bcast(prjsCache.getName(), ignite, () -> { + Ignite ignite = Ignition.localIgnite(); + IgniteCache<RegionKey, List<RegionProjection>> projCache = ProjectionsCache.getOrCreate(ignite); + IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite); + + Affinity<RegionKey> targetAffinity = ignite.affinity(ProjectionsCache.CACHE_NAME); + + 
ClusterNode locNode = ignite.cluster().localNode(); + + Map<FeatureKey, double[]> fm = new ConcurrentHashMap<>(); + Map<RegionKey, List<RegionProjection>> pm = new ConcurrentHashMap<>(); + + targetAffinity. + mapKeysToNodes(IntStream.range(0, i.featuresCount()). + mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, ignite), trainingUUID)). + collect(Collectors.toSet())).getOrDefault(locNode, Collections.emptyList()). + forEach(k -> { + FeatureProcessor vec; + + int featureIdx = k.featureIdx(); + + IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite); + TrainingContext ctx = ctxCache.get(trainingUUID); + double[] vals = new double[ctx.labels().length]; + + vec = ctx.featureProcessor(featureIdx); + i.values(featureIdx).forEach(t -> vals[t.get1()] = t.get2()); + + fm.put(getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, ignite)), vals); + + List<RegionProjection> newReg = new ArrayList<>(BLOCK_SIZE); + newReg.add(vec.createInitialRegion(getSamples(i.values(featureIdx), ctx.labels().length), vals, ctx.labels())); + pm.put(k, newReg); + }); + + featuresCache.putAll(fm); + projCache.putAll(pm); + + return null; + }); + + return doTrain(i, trainingUUID); + } + + /** + * Get samples array. + * + * @param values Stream of tuples in the form of (index, value). + * @param size size of stream. + * @return Samples array. + */ + private Integer[] getSamples(Stream<IgniteBiTuple<Integer, Double>> values, int size) { + Integer[] res = new Integer[size]; + + values.forEach(v -> res[v.get1()] = v.get1()); + + return res; + } + + /** */ + @NotNull + private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) { + RootNode root = new RootNode(); + + // List containing setters of leaves of the tree. 
+ List<TreeTip> tips = new LinkedList<>(); + tips.add(new TreeTip(root::setSplit, 0)); + + int curDepth = 0; + int regsCnt = 1; + + int featuresCnt = input.featuresCount(); + IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)). + forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0))); + updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid); + + // TODO: IGNITE-5893 Currently if the best split makes tree deeper than max depth process will be terminated, but actually we should + // only stop when *any* improving split makes tree deeper than max depth. Can be fixed if we will store which + // regions cannot be split more and split only those that can. + while (true) { + long before = System.currentTimeMillis(); + + IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid); + + long findBestRegIdx = System.currentTimeMillis() - before; + + Integer bestFeatureIdx = b.get1(); + + Integer regIdx = b.get2().get1(); + Double bestInfoGain = b.get2().get2(); + + if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) { + before = System.currentTimeMillis(); + + SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, + input.affinityKey(bestFeatureIdx, ignite), + () -> { + TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid); + Ignite ignite = Ignition.localIgnite(); + RegionKey key = ProjectionsCache.key(bestFeatureIdx, + regIdx / BLOCK_SIZE, + input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), + uuid); + RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE); + return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx); + }); + + long findBestSplit = System.currentTimeMillis() - before; + + IndexAndSplitInfo best = new 
IndexAndSplitInfo(bestFeatureIdx, bi); + + regsCnt++; + + X.println(">>> Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt); + // Request bitset for split region. + int ind = best.info.regionIndex(); + + SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, + input.affinityKey(bestFeatureIdx, ignite), + () -> { + Ignite ignite = Ignition.localIgnite(); + IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite); + IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite); + TrainingContext ctx = ctxCache.localPeek(uuid); + + double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()))); + RegionKey key = ProjectionsCache.key(bestFeatureIdx, + regIdx / BLOCK_SIZE, + input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), + uuid); + RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE); + return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info); + + }); + + SplitNode sn = best.info.createSplitNode(best.featureIdx); + + TreeTip tipToSplit = tips.get(ind); + tipToSplit.leafSetter.accept(sn); + tipToSplit.leafSetter = sn::setLeft; + int d = tipToSplit.depth++; + tips.add(new TreeTip(sn::setRight, d)); + + if (d > curDepth) { + curDepth = d; + X.println(">>> Depth: " + curDepth); + X.println(">>> Cache size: " + prjsCache.size(CachePeekMode.PRIMARY)); + } + + before = System.currentTimeMillis(); + // Perform split on all feature vectors. + IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt). + mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)). + collect(Collectors.toSet()); + + int rc = regsCnt; + + // Perform split. 
+ CacheUtils.update(prjsCache.getName(), ignite, + (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> { + RegionKey k = e.getKey(); + + List<RegionProjection> leftBlock = e.getValue(); + + int fIdx = k.featureIdx(); + int idxInBlock = ind % BLOCK_SIZE; + + IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign); + TrainingContext<D> ctx = ctxCache.get(uuid); + + RegionProjection targetRegProj = leftBlock.get(idxInBlock); + + IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx. + performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign); + + RegionProjection left = regs.get1(); + RegionProjection right = regs.get2(); + + leftBlock.set(idxInBlock, left); + RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid); + + IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign); + + List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey); + + if (rightBlock == null) { + List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE); + newBlock.add(right); + return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock)); + } + else { + rightBlock.add(right); + return rightBlock.equals(k) ? 
+ Stream.of(new CacheEntryImpl<>(k, leftBlock)) : + Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock)); + } + }, + bestRegsKeys); + + X.println(">>> Update of projs cache took " + (System.currentTimeMillis() - before)); + + before = System.currentTimeMillis(); + + updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid); + + X.println(">>> Update of split cache took " + (System.currentTimeMillis() - before)); + } + else { + X.println(">>> Best feature index: " + bestFeatureIdx + ", best infoGain " + bestInfoGain); + break; + } + } + + int rc = regsCnt; + + IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> { + IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite()); + + return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1). + mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)). 
+ map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator(); + }; + + Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite, + (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> { + int regBlockIdx = e.getKey().regionBlockIndex(); + + if (e.getValue() != null) { + for (int i = 0; i < e.getValue().size(); i++) { + int regIdx = regBlockIdx * BLOCK_SIZE + i; + RegionProjection reg = e.getValue().get(i); + + Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s])); + m.put(regIdx, res); + } + } + + return m; + }, + () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), + featZeroRegs, + (infos, infos2) -> { + Map<Integer, Double> res = new HashMap<>(); + res.putAll(infos); + res.putAll(infos2); + return res; + }, + HashMap::new + ); + + int i = 0; + for (TreeTip tip : tips) { + tip.leafSetter.accept(new Leaf(vals.get(i))); + i++; + } + + ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite); + ContextCache.getOrCreate(ignite).remove(uuid); + FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite); + SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite); + + return new DecisionTreeModel(root.s); + } + + /** + * Find the best split in the form (feature index, (index of region with the best split, impurity of region with the + * best split)). + * + * @param featuresCnt Count of features. + * @param affinity Affinity function. + * @param trainingUUID UUID of training. + * @return Best split in the form (feature index, (index of region with the best split, impurity of region with the + * best split)). 
+ */ + private IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> findBestSplitIndexForFeatures(int featuresCnt, + IgniteBiFunction<Integer, Ignite, Object> affinity, + UUID trainingUUID) { + Set<Integer> featureIndexes = IntStream.range(0, featuresCnt).boxed().collect(Collectors.toSet()); + + return CacheUtils.reduce(SplitCache.CACHE_NAME, ignite, + (Object ctx, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e, IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> r) -> + Functions.MAX_GENERIC(new IgniteBiTuple<>(e.getKey().featureIdx(), e.getValue()), r, comparator()), + () -> null, + () -> SplitCache.localEntries(featureIndexes, affinity, trainingUUID), + (i1, i2) -> Functions.MAX_GENERIC(i1, i2, Comparator.comparingDouble(bt -> bt.get2().get2())), + () -> new IgniteBiTuple<>(-1, new IgniteBiTuple<>(-1, Double.NEGATIVE_INFINITY)) + ); + } + + /** */ + private static Comparator<IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>>> comparator() { + return Comparator.comparingDouble(bt -> bt != null && bt.get2() != null ? bt.get2().get2() : Double.NEGATIVE_INFINITY); + } + + /** + * Update split cache. + * + * @param lastSplitRegionIdx Index of region which had last best split. + * @param regsCnt Count of regions. + * @param featuresCnt Count of features. + * @param affinity Affinity function. + * @param trainingUUID UUID of current training. + */ + private void updateSplitCache(int lastSplitRegionIdx, int regsCnt, int featuresCnt, + IgniteCurriedBiFunction<Ignite, Integer, Object> affinity, + UUID trainingUUID) { + CacheUtils.update(SplitCache.CACHE_NAME, ignite, + (Ignite ign, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e) -> { + Integer bestRegIdx = e.getValue().get1(); + int fIdx = e.getKey().featureIdx(); + TrainingContext ctx = ContextCache.getOrCreate(ign).get(trainingUUID); + + Map<Integer, RegionProjection> toCompare; + + // Fully recalculate best. 
+ if (bestRegIdx == lastSplitRegionIdx) + toCompare = ProjectionsCache.projectionsOfFeature(fIdx, maxDepth, regsCnt, BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign); + // Just compare previous best and two regions which are produced by split. + else + toCompare = ProjectionsCache.projectionsOfRegions(fIdx, maxDepth, + IntStream.of(bestRegIdx, lastSplitRegionIdx, regsCnt - 1), BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign); + + double[] values = ctx.values(fIdx, ign); + double[] labels = ctx.labels(); + + IgniteBiTuple<Integer, Double> max = toCompare.entrySet().stream(). + map(ent -> { + SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey()); + return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY); + }). + max(Comparator.comparingDouble(IgniteBiTuple::get2)). + get(); + + return Stream.of(new CacheEntryImpl<>(e.getKey(), max)); + }, + () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet()) + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java new file mode 100644 index 0000000..94331f7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.util.Map; +import java.util.stream.Stream; +import org.apache.ignite.Ignite; +import org.apache.ignite.lang.IgniteBiTuple; + +/** + * Input for {@see ColumnDecisionTreeTrainer}. + */ +public interface ColumnDecisionTreeTrainerInput { + /** + * Projection of data on feature with the given index. + * + * @param idx Feature index. + * @return Projection of data on feature with the given index. + */ + Stream<IgniteBiTuple<Integer, Double>> values(int idx); + + /** + * Labels. + * + * @param ignite Ignite instance. + */ + double[] labels(Ignite ignite); + + /** Information about which features are categorical in the form of feature index -> number of categories. */ + Map<Integer, Integer> catFeaturesInfo(); + + /** Number of features. */ + int featuresCount(); + + /** + * Get affinity key for the given column index. + * Affinity key should be pure-functionally dependent from idx. 
+ */ + Object affinityKey(int idx, Ignite ignite); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java new file mode 100644 index 0000000..9a11902 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey; +import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage; +import org.jetbrains.annotations.NotNull; + +/** + * Adapter of {@link SparseDistributedMatrix} to {@link ColumnDecisionTreeTrainerInput}. + * The matrix should be in {@link org.apache.ignite.ml.math.StorageConstants#COLUMN_STORAGE_MODE} and + * should contain samples in rows, with the last position of each row being the label of that sample. + */ +public class MatrixColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<RowColMatrixKey, Map<Integer, Double>> { + /** + * Creates input over the cache of the storage of the given matrix. + * + * @param m SparseDistributedMatrix in {@link org.apache.ignite.ml.math.StorageConstants#COLUMN_STORAGE_MODE} + * containing samples in rows, with the last position of each row being the label of that sample. + * @param catFeaturesInfo Information about which features are categorical in form of feature index -&gt; number of + * categories. + */ + public MatrixColumnDecisionTreeTrainerInput(SparseDistributedMatrix m, Map<Integer, Integer> catFeaturesInfo) { + super(((SparseDistributedMatrixStorage)m.getStorage()).cache(), + () -> Stream.of(new SparseMatrixKey(m.columnSize() - 1, m.getUUID(), m.columnSize() - 1)), + valuesMapper(m), + labels(m), + keyMapper(m), + catFeaturesInfo, + m.columnSize() - 1, + m.rowSize()); + } + + /** + * Values mapper: turns a cache entry into a stream of (index, value) tuples for every index from 0 (inclusive) to + * {@code m.rowSize()} (exclusive); indexes absent from the entry's map get 0.0 and a null entry value is treated + * as an empty map.
See {@link CacheColumnDecisionTreeTrainerInput#valuesMapper}. + * + * @param m Matrix whose row size bounds the produced indexes. + * @return Values mapper. + */ + @NotNull + private static IgniteFunction<Cache.Entry<RowColMatrixKey, Map<Integer, Double>>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper( + SparseDistributedMatrix m) { + return ent -> { + Map<Integer, Double> map = ent.getValue() != null ? ent.getValue() : new HashMap<>(); + return IntStream.range(0, m.rowSize()).mapToObj(k -> new IgniteBiTuple<>(k, map.getOrDefault(k, 0.0))); + }; + } + + /** + * Key mapper: maps index {@code i} to a stream holding the single {@link SparseMatrixKey} built from {@code i}, + * the UUID of the matrix storage and {@code i} again as the last constructor argument. + * See {@link CacheColumnDecisionTreeTrainerInput#keyMapper}. + * + * @param m Matrix whose storage UUID is put into the keys. + * @return Key mapper. + */ + @NotNull private static IgniteFunction<Integer, Stream<RowColMatrixKey>> keyMapper(SparseDistributedMatrix m) { + return i -> Stream.of(new SparseMatrixKey(i, ((SparseDistributedMatrixStorage)m.getStorage()).getUUID(), i)); + } + + /** + * Labels mapper: turns a map into a stream of {@code m.rowSize()} doubles read from the map for indexes starting + * at 0, with 0.0 used for absent indexes. See {@link CacheColumnDecisionTreeTrainerInput#labelsMapper}. + * + * @param m Matrix whose row size gives the number of produced labels. + * @return Labels mapper. + */ + @NotNull private static IgniteFunction<Map<Integer, Double>, DoubleStream> labels(SparseDistributedMatrix m) { + return mp -> IntStream.range(0, m.rowSize()).mapToDouble(k -> mp.getOrDefault(k, 0.0)); + } + + /** + * {@inheritDoc} + * <p> + * This implementation returns the given index itself (boxed) and does not use the Ignite instance. + */ + @Override public Object affinityKey(int idx, Ignite ignite) { + return idx; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java new file mode 100644 index 0000000..e95f57b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.ml.trees.RegionInfo; + +/** + * Projection of region on given feature. + * Holds the sample indexes, the region data and the region depth and is serialized through {@link Externalizable}. + * + * @param <D> Data of region. + */ +public class RegionProjection<D extends RegionInfo> implements Externalizable { + /** Sample indexes. */ + protected Integer[] sampleIndexes; + + /** Region data. */ + protected D data; + + /** Depth of this region. */ + protected int depth; + + /** + * Creates projection from the given values; the array of sample indexes is stored as is, without copying. + * + * @param sampleIndexes Sample indexes. + * @param data Region data. + * @param depth Depth of this region. + */ + public RegionProjection(Integer[] sampleIndexes, D data, int depth) { + this.data = data; + this.depth = depth; + this.sampleIndexes = sampleIndexes; + } + + /** + * No-op constructor used for serialization/deserialization. + */ + public RegionProjection() { + // No-op. + } + + /** + * Get sample indexes. + * + * @return Sample indexes (the array held by this object, not a copy). + */ + public Integer[] sampleIndexes() { + return sampleIndexes; + } + + /** + * Get region data. + * + * @return Region data. + */ + public D data() { + return data; + } + + /** + * Get region depth. + * + * @return Region depth.
+ */ + public int depth() { + return depth; + } + + /** + * {@inheritDoc} + * <p> + * Writes the number of sample indexes, then every index as an int, then the region data object and finally the depth. + * The array of sample indexes and its elements must be non-null here: the array is dereferenced and every element + * is unboxed unconditionally. + */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(sampleIndexes.length); + + for (Integer sampleIndex : sampleIndexes) + out.writeInt(sampleIndex); + + out.writeObject(data); + out.writeInt(depth); + } + + /** + * {@inheritDoc} + * <p> + * Reads the fields in the same order in which {@link #writeExternal(ObjectOutput)} writes them; the region data + * object is cast to {@code D} without a runtime check. + */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + int size = in.readInt(); + + sampleIndexes = new Integer[size]; + + for (int i = 0; i < size; i++) + sampleIndexes[i] = in.readInt(); + + data = (D)in.readObject(); + depth = in.readInt(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java new file mode 100644 index 0000000..6415dab --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.trainers.columnbased; + +import com.zaxxer.sparsebits.SparseBitSet; +import java.util.Map; +import java.util.UUID; +import java.util.stream.DoubleStream; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.trees.ContinuousRegionInfo; +import org.apache.ignite.ml.trees.ContinuousSplitCalculator; +import org.apache.ignite.ml.trees.RegionInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.CategoricalFeatureProcessor; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor; + +import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME; + +/** + * Context of training with {@link ColumnDecisionTreeTrainer}. + * Holds the trainer input, the labels read from it, the split calculators and the training UUID, and creates + * feature processors on demand. + * + * @param <D> Class for storing of information used in calculation of impurity of continuous feature region. + */ +public class TrainingContext<D extends ContinuousRegionInfo> { + /** Input for training with {@link ColumnDecisionTreeTrainer}. */ + private final ColumnDecisionTreeTrainerInput input; + + /** Labels, obtained from the input when this context is constructed. */ + private final double[] labels; + + /** Calculator used for finding splits of region of continuous features. */ + private final ContinuousSplitCalculator<D> continuousSplitCalculator; + + /** Calculator used for finding splits of region of categorical feature. */ + private final IgniteFunction<DoubleStream, Double> categoricalSplitCalculator; + + /** UUID of current training.
*/ + private final UUID trainingUUID; + + /** + * Construct context for training with {@link ColumnDecisionTreeTrainer}. + * + * @param input Input for training. + * @param continuousSplitCalculator Calculator used for calculations of splits of continuous features regions. + * @param categoricalSplitCalculator Calculator used for calculations of splits of categorical features regions. + * @param trainingUUID UUID of the current training. + * @param ignite Ignite instance, passed to the input to obtain the labels. + */ + public TrainingContext(ColumnDecisionTreeTrainerInput input, + ContinuousSplitCalculator<D> continuousSplitCalculator, + IgniteFunction<DoubleStream, Double> categoricalSplitCalculator, + UUID trainingUUID, + Ignite ignite) { + this.input = input; + this.labels = input.labels(ignite); + this.continuousSplitCalculator = continuousSplitCalculator; + this.categoricalSplitCalculator = categoricalSplitCalculator; + this.trainingUUID = trainingUUID; + } + + /** + * Get processor used for calculating splits of categorical features; a new processor is created on every call. + * + * @param catsCnt Count of categories. + * @return Processor used for calculating splits of categorical features. + */ + public CategoricalFeatureProcessor categoricalFeatureProcessor(int catsCnt) { + return new CategoricalFeatureProcessor(categoricalSplitCalculator, catsCnt); + } + + /** + * Get processor used for calculating splits of continuous features; a new processor is created on every call. + * + * @return Processor used for calculating splits of continuous features. + */ + public ContinuousFeatureProcessor<D> continuousFeatureProcessor() { + return new ContinuousFeatureProcessor<>(continuousSplitCalculator); + } + + /** + * Get labels. + * + * @return Labels (the array held by this context, not a copy). + */ + public double[] labels() { + return labels; + } + + /** + * Get values of feature with given index. + * <p> + * The values are looked up with {@code localPeek} in the cache named + * {@code COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME}, under a key built from the feature index, the training UUID + * and the affinity key that the input reports for this feature. + * + * @param featIdx Feature index. + * @param ignite Ignite instance used to access the features cache. + * @return Values of feature with given index.
+ */ + public double[] values(int featIdx, Ignite ignite) { + IgniteCache<FeaturesCache.FeatureKey, double[]> featuresCache = ignite.getOrCreateCache(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME); + return featuresCache.localPeek(FeaturesCache.getFeatureCacheKey(featIdx, trainingUUID, input.affinityKey(featIdx, ignite))); + } + + /** + * Perform best split on the given region projection. + * <p> + * A feature is treated as categorical when the categorical features info of {@code input} contains its index. + * If neither the target feature nor the best feature is categorical, the split is delegated to {@code performSplit} + * of the continuous feature processor. Otherwise, if the target feature is categorical, it is delegated to + * {@code performSplitGeneric} of the categorical feature processor together with the values of the target feature. + * In the remaining case it is delegated to {@code performSplitGeneric} of the continuous feature processor, which + * is given the labels array. + * NOTE(review): the last case passes labels where the categorical case passes feature values; confirm this is intended. + * + * @param input Input of {@link ColumnDecisionTreeTrainer} performing split; this parameter shadows the field of the + * same name and only its categorical features info is read here. + * @param bitSet Bit set specifying split. + * @param targetFeatIdx Index of feature for performing split. + * @param bestFeatIdx Index of feature with best split. + * @param targetRegionPrj Projection of region to split on feature with index {@code targetFeatIdx}. + * @param leftData Data of left region of split. + * @param rightData Data of right region of split. + * @param ignite Ignite instance. + * @return Tuple of region projections produced by the split. + */ + public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(ColumnDecisionTreeTrainerInput input, + SparseBitSet bitSet, int targetFeatIdx, int bestFeatIdx, RegionProjection targetRegionPrj, RegionInfo leftData, + RegionInfo rightData, Ignite ignite) { + + Map<Integer, Integer> catFeaturesInfo = input.catFeaturesInfo(); + + if (!catFeaturesInfo.containsKey(targetFeatIdx) && !catFeaturesInfo.containsKey(bestFeatIdx)) + return continuousFeatureProcessor().performSplit(bitSet, targetRegionPrj, (D)leftData, (D)rightData); + else if (catFeaturesInfo.containsKey(targetFeatIdx)) + return categoricalFeatureProcessor(catFeaturesInfo.get(targetFeatIdx)).performSplitGeneric(bitSet, values(targetFeatIdx, ignite), targetRegionPrj, leftData, rightData); + return continuousFeatureProcessor().performSplitGeneric(bitSet, labels, targetRegionPrj, leftData, rightData); + } + + /** + * Processor used for calculating splits for feature with the given index. + * <p> + * A categorical processor is returned if the categorical features info of the input contains the index, + * a continuous one otherwise. + * + * @param featureIdx Index of feature to process.
+ * @return Processor used for calculating splits for feature with the given index. + */ + public FeatureProcessor featureProcessor(int featureIdx) { + return input.catFeaturesInfo().containsKey(featureIdx) ? categoricalFeatureProcessor(input.catFeaturesInfo().get(featureIdx)) : continuousFeatureProcessor(); + } + + /** + * Shortcut for affinity key: delegates to the input, passing the instance returned by {@link Ignition#localIgnite()}. + * + * @param idx Feature index. + * @return Affinity key. + */ + public Object affinityKey(int idx) { + return input.affinityKey(idx, Ignition.localIgnite()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java new file mode 100644 index 0000000..51ea359 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.trees.trainers.columnbased.caches; + +import java.util.UUID; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.CacheAtomicityMode; +import org.apache.ignite.cache.CacheMode; +import org.apache.ignite.cache.CacheWriteSynchronizationMode; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.trees.ContinuousRegionInfo; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; +import org.apache.ignite.ml.trees.trainers.columnbased.TrainingContext; + +/** + * Class for operations related to cache containing training context for {@link ColumnDecisionTreeTrainer}. + * The cache maps a {@link UUID} to a {@link TrainingContext}. + */ +public class ContextCache { + /** + * Name of cache containing training context for {@link ColumnDecisionTreeTrainer}. + */ + public static final String COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME"; + + /** + * Get or create cache for training context. + * <p> + * The configuration passed to {@code getOrCreateCache} is replicated and atomic with full-sync write + * synchronization; on-heap caching and reading from backups are enabled, copy-on-read is disabled and the eviction + * policy is set to {@code null}. + * + * @param ignite Ignite instance. + * @param <D> Class storing information about continuous regions. + * @return Cache for training context. + */ + public static <D extends ContinuousRegionInfo> IgniteCache<UUID, TrainingContext<D>> getOrCreate(Ignite ignite) { + CacheConfiguration<UUID, TrainingContext<D>> cfg = new CacheConfiguration<>(); + + cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC); + + cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); + + cfg.setEvictionPolicy(null); + + cfg.setCopyOnRead(false); + + cfg.setCacheMode(CacheMode.REPLICATED); + + cfg.setOnheapCacheEnabled(true); + + cfg.setReadFromBackup(true); + + cfg.setName(COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME); + + return ignite.getOrCreateCache(cfg); + } +}
