gaoyunhaii commented on a change in pull request #20: URL: https://github.com/apache/flink-ml/pull/20#discussion_r772116079
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A container class that maintains the execution state of the graph (e.g. which nodes are ready to + * run). + */ +class GraphExecutionHelper { + // A map from tableId to the list of nodes which take this table as input. + private final Map<TableId, List<GraphNode>> consumerNodes = new HashMap<>(); + // A map from tableId to the corresponding table. The table value is null if it has not Review comment: HashMap should not support null values, does this mean we do not store the tables not constructed yet? 
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/Graph.java ########## @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.api.Stage; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.ml.builder.GraphNode.StageType; + +/** + * A Graph acts as an Estimator. A Graph consists of a DAG of stages, each of which could be an + * Estimator, Model, Transformer or AlgoOperator. When `Graph::fit` is called, the stages are + * executed in a topologically-sorted order. 
If a stage is an Estimator, its `Estimator::fit` method + * will be called on the input tables (from the input edges) to fit a Model. Then the Model will be + * used to transform the input tables and produce output tables to the output edges. If a stage is + * an AlgoOperator, its `AlgoOperator::transform` method will be called on the input tables and + * produce output tables to the output edges. The GraphModel fitted from a Graph consists of the + * fitted Models and AlgoOperators, corresponding to the Graph's stages. + */ +@PublicEvolving +public final class Graph implements Estimator<Graph, GraphModel> { + private static final long serialVersionUID = 6354253958813529308L; + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private final List<GraphNode> nodes; + private final TableId[] estimatorInputIds; + private final TableId[] modelInputIds; + private final TableId[] outputIds; + private final TableId[] inputModelDataIds; + private final TableId[] outputModelDataIds; + + public Graph( + List<GraphNode> nodes, + TableId[] estimatorInputIds, + TableId[] modelInputs, + TableId[] outputs, + TableId[] inputModelDataIds, + TableId[] outputModelDataIds) { + this.nodes = nodes; Review comment: checkNotNull ? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A container class that maintains the execution state of the graph (e.g. which nodes are ready to + * run). + */ +class GraphExecutionHelper { + // A map from tableId to the list of nodes which take this table as input. + private final Map<TableId, List<GraphNode>> consumerNodes = new HashMap<>(); + // A map from tableId to the corresponding table. The table value is null if it has not + // been constructed yet. + private final Map<TableId, Table> constructedTables = new HashMap<>(); + // A map that maintains the number of input tables that have not been constructed + // for each node in the graph. + private final Map<GraphNode, Integer> numUnConstructedInputTables = new HashMap<>(); + // An ordered list of nodes whose input tables have all been constructed AND who has not + // been fetch via pollNextReadyNode. + private final Deque<GraphNode> unFetchedReadyNodes = new LinkedList<>(); + + public GraphExecutionHelper(List<GraphNode> nodes) { + // Initializes dependentNodes and numUnConstructedInputs. 
+ for (GraphNode node : nodes) { + List<TableId> inputs = new ArrayList<>(); + inputs.addAll(Arrays.asList(node.algoOpInputIds)); + if (node.estimatorInputIds != null) { + inputs.addAll(Arrays.asList(node.estimatorInputIds)); + } + if (node.inputModelDataIds != null) { + inputs.addAll(Arrays.asList(node.inputModelDataIds)); + } + numUnConstructedInputTables.put(node, inputs.size()); + for (TableId tableId : inputs) { + consumerNodes.putIfAbsent(tableId, new ArrayList<>()); Review comment: Should be chained as `consumerNodes.putIfAbsent(tableId, new ArrayList<>()).add(node)` ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A container class that maintains the execution state of the graph (e.g. which nodes are ready to + * run). 
+ */ +class GraphExecutionHelper { + // A map from tableId to the list of nodes which take this table as input. + private final Map<TableId, List<GraphNode>> consumerNodes = new HashMap<>(); + // A map from tableId to the corresponding table. The table value is null if it has not + // been constructed yet. + private final Map<TableId, Table> constructedTables = new HashMap<>(); + // A map that maintains the number of input tables that have not been constructed + // for each node in the graph. + private final Map<GraphNode, Integer> numUnConstructedInputTables = new HashMap<>(); + // An ordered list of nodes whose input tables have all been constructed AND who has not + // been fetch via pollNextReadyNode. + private final Deque<GraphNode> unFetchedReadyNodes = new LinkedList<>(); + + public GraphExecutionHelper(List<GraphNode> nodes) { + // Initializes dependentNodes and numUnConstructedInputs. + for (GraphNode node : nodes) { + List<TableId> inputs = new ArrayList<>(); + inputs.addAll(Arrays.asList(node.algoOpInputIds)); + if (node.estimatorInputIds != null) { + inputs.addAll(Arrays.asList(node.estimatorInputIds)); + } + if (node.inputModelDataIds != null) { + inputs.addAll(Arrays.asList(node.inputModelDataIds)); + } + numUnConstructedInputTables.put(node, inputs.size()); + for (TableId tableId : inputs) { + consumerNodes.putIfAbsent(tableId, new ArrayList<>()); + consumerNodes.get(tableId).add(node); + } + } + } + + public void setTables(TableId[] tableIds, Table[] tables) { + Preconditions.checkArgument( + tableIds.length >= tables.length, Review comment: When would `tableIds.length > tables.length` ? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A container class that maintains the execution state of the graph (e.g. which nodes are ready to + * run). + */ +class GraphExecutionHelper { + // A map from tableId to the list of nodes which take this table as input. + private final Map<TableId, List<GraphNode>> consumerNodes = new HashMap<>(); + // A map from tableId to the corresponding table. The table value is null if it has not + // been constructed yet. + private final Map<TableId, Table> constructedTables = new HashMap<>(); + // A map that maintains the number of input tables that have not been constructed + // for each node in the graph. + private final Map<GraphNode, Integer> numUnConstructedInputTables = new HashMap<>(); + // An ordered list of nodes whose input tables have all been constructed AND who has not + // been fetch via pollNextReadyNode. + private final Deque<GraphNode> unFetchedReadyNodes = new LinkedList<>(); + + public GraphExecutionHelper(List<GraphNode> nodes) { + // Initializes dependentNodes and numUnConstructedInputs. 
+ for (GraphNode node : nodes) { + List<TableId> inputs = new ArrayList<>(); + inputs.addAll(Arrays.asList(node.algoOpInputIds)); + if (node.estimatorInputIds != null) { + inputs.addAll(Arrays.asList(node.estimatorInputIds)); + } + if (node.inputModelDataIds != null) { + inputs.addAll(Arrays.asList(node.inputModelDataIds)); + } + numUnConstructedInputTables.put(node, inputs.size()); + for (TableId tableId : inputs) { + consumerNodes.putIfAbsent(tableId, new ArrayList<>()); + consumerNodes.get(tableId).add(node); + } + } + } + + public void setTables(TableId[] tableIds, Table[] tables) { + Preconditions.checkArgument( + tableIds.length >= tables.length, + "the length of tablesIds %s is less than the length of tables %s", + tableIds.length, + tables.length); + for (int i = 0; i < tables.length; i++) { + setTable(tableIds[i], tables[i]); + } + } + + private void setTable(TableId tableId, Table table) { + Preconditions.checkArgument( + !constructedTables.containsKey(tableId), + "the table with id=%s has already been constructed", + tableId.toString()); + constructedTables.put(tableId, table); + + for (GraphNode node : consumerNodes.getOrDefault(tableId, new ArrayList<>())) { + int prevNum = numUnConstructedInputTables.get(node); + numUnConstructedInputTables.put(node, prevNum - 1); Review comment: Change to ``` if (prevNum == 1) { unFetchedReadyNodes.addLast(node); numUnConstructedInputTables.remove(node); } else { numUnConstructedInputTables.put(node, prevNum - 1); } ``` ? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java ########## @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.api.Stage; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.ml.builder.GraphNode.StageType; + +/** + * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of + * which could be an Estimator, Model, Transformer or AlgoOperator. + */ +@PublicEvolving +public final class GraphBuilder { + + private int maxOutputLength = 20; Review comment: Could you elaborate a bit on why we need `maxOutputLength = 20` here? Why don't we let users specify the number of outputs when adding nodes? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphBuilder.java ########## @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.api.Stage; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.ml.builder.GraphNode.StageType; + +/** + * A GraphBuilder provides APIs to build Estimator/Model/AlgoOperator from a DAG of stages, each of + * which could be an Estimator, Model, Transformer or AlgoOperator. + */ +@PublicEvolving +public final class GraphBuilder { + + private int maxOutputLength = 20; + private int nextTableId = 0; + private int nextNodeId = 0; + // An ordered list of nodes in the graph. + private final List<GraphNode> nodes = new ArrayList<>(); + // A map from stage instance to the corresponding node in the graph. + private final Map<Stage<?>, GraphNode> existingNodes = new HashMap<>(); + + public GraphBuilder() {} + + /** + * Specifies the loose upper bound of the number of output tables that can be returned by the + * Model::getModelData() and AlgoOperator::transform() methods, for any stage involved in this + * Graph. + * + * <p>The default upper bound is 20. + */ + public GraphBuilder setMaxOutputTableNum(int maxOutputLength) { + this.maxOutputLength = maxOutputLength; + return this; + } + + /** + * Creates a TableId associated with this GraphBuilder. 
It can be used to specify the passing of + * tables between stages, as well as the input/output tables of the Graph/GraphModel generated + * by this builder. + * + * @return A TableId. + */ + public TableId createTableId() { + return new TableId(nextTableId++); + } + + /** + * Adds an AlgoOperator in the graph. + * + * <p>When the graph runs as Estimator, the transform() of the given AlgoOperator would be + * invoked with the given inputs. Then when the GraphModel fitted by this graph runs, the + * transform() of the given AlgoOperator would be invoked with the given inputs. + * + * <p>When the graph runs as AlgoOperator or Model, the transform() of the given AlgoOperator + * would be invoked with the given inputs. + * + * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables + * outputted by transform(). This number could be configured using {@link + * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of + * Tables outputted by transform(). + * + * @param algoOp An AlgoOperator instance. + * @param inputs A list of TableIds which represents inputs to transform() of the given + * AlgoOperator. + * @return A list of TableIds which represents the outputs of transform() of the given + * AlgoOperator. + */ + public TableId[] addAlgoOperator(AlgoOperator<?> algoOp, TableId... inputs) { + return addStage(algoOp, StageType.ALGO_OPERATOR, null, inputs); + } + + /** + * Adds an Estimator in the graph. + * + * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with + * the given inputs. Then when the GraphModel fitted by this graph runs, the transform() of the + * Model fitted by the given Estimator would be invoked with the given inputs. 
+ * + * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be + * invoked with the given inputs, then the transform() of the Model fitted by the given + * Estimator would be invoked with the given inputs. + * + * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables + * outputted by transform(). This number could be configured using {@link + * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of + * Tables outputted by transform(). + * + * @param estimator An Estimator instance. + * @param inputs A list of TableIds which represents inputs to fit() of the given Estimator as + * well as inputs to transform() of the Model fitted by the given Estimator. + * @return A list of TableIds which represents the outputs of transform() of the Model fitted by + * the given Estimator. + */ + public TableId[] addEstimator(Estimator<?, ?> estimator, TableId... inputs) { + return addEstimator(estimator, inputs, inputs); + } + + /** + * Adds an Estimator in the graph. + * + * <p>When the graph runs as Estimator, the fit() of the given Estimator would be invoked with + * estimatorInputs. Then when the GraphModel fitted by this graph runs, the transform() of the + * Model fitted by the given Estimator would be invoked with modelInputs. + * + * <p>When the graph runs as AlgoOperator or Model, the fit() of the given Estimator would be + * invoked with estimatorInputs, then the transform() of the Model fitted by the given Estimator + * would be invoked with modelInputs. + * + * <p>NOTE: the number of the returned TableIds does not represent the actual number of Tables + * outputted by transform(). This number could be configured using {@link + * #setMaxOutputTableNum(int)}. Users should make sure that this number >= the actual number of + * Tables outputted by transform(). + * + * @param estimator An Estimator instance. 
+ * @param estimatorInputs A list of TableIds which represents inputs to fit() of the given + * Estimator. + * @param modelInputs A list of TableIds which represents inputs to transform() of the Model + * fitted by the given Estimator. + * @return A list of TableIds which represents the outputs of transform() of the Model fitted by + * the given Estimator. + */ + public TableId[] addEstimator( + Estimator<?, ?> estimator, TableId[] estimatorInputs, TableId[] modelInputs) { + return addStage(estimator, StageType.ESTIMATOR, estimatorInputs, modelInputs); + } + + /** + * When the graph runs as Estimator, it first generates a GraphModel that contains the Model + * fitted by the given Estimator. Then when this GraphModel runs, the setModelData() of the + * fitted Model would be invoked with the given inputs before its transform() is invoked. + * + * <p>When the graph runs as AlgoOperator or Model, the setModelData() of the Model fitted by + * the given Estimator would be invoked with the given inputs before its transform() is invoked. + * + * @param estimator An Estimator instance. + * @param inputs A list of TableIds which represents inputs to setModelData() of the Model + * fitted by the given Estimator. + */ + public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId... inputs) { + GraphNode node = existingNodes.get(estimator); + if (node == null) { Review comment: Could be simplified with `checkState(boolean, "");`, similar to other checks ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/builder/GraphExecutionHelper.java ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.builder; + +import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A container class that maintains the execution state of the graph (e.g. which nodes are ready to + * run). + */ +class GraphExecutionHelper { + // A map from tableId to the list of nodes which take this table as input. Review comment: Change the comment to `/** */` ? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
