http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/FlinkVertex.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/FlinkVertex.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/FlinkVertex.java new file mode 100644 index 0000000..883acc6 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/FlinkVertex.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.flink.tez.dag;


import org.apache.flink.tez.runtime.TezTaskConfig;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.Vertex;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

/**
 * Base class for a Flink operator that is translated into a Tez {@link Vertex}.
 *
 * <p>Besides the task name, parallelism and {@link TezTaskConfig}, a vertex records Tez-specific
 * bookkeeping: the input positions at which each upstream vertex is connected, and the number of
 * subtasks behind each output position. {@link #writeInputPositionsToConfig()} and
 * {@link #writeSubTasksInOutputToConfig()} copy that bookkeeping into the task configuration and
 * must be called before the task configuration is written to the Tez configuration.
 */
public abstract class FlinkVertex {

	/** The Tez vertex; {@code null} until a subclass assigns it (presumably in {@link #createVertex}). */
	protected Vertex cached;
	private String taskName;
	private int parallelism;
	protected TezTaskConfig taskConfig;

	// Tez-specific bookkeeping

	/** Name that is unique within the DAG: the task name followed by a random UUID. */
	protected String uniqueName;

	/** For each upstream vertex, the input positions at which it is connected to this vertex. */
	private Map<FlinkVertex, ArrayList<Integer>> inputPositions;

	/** Number of subtasks per output position; positions that were never set hold -1. */
	private ArrayList<Integer> numberOfSubTasksInOutputs;

	public FlinkVertex(String taskName, int parallelism, TezTaskConfig taskConfig) {
		this.cached = null;
		this.taskName = taskName;
		this.parallelism = parallelism;
		this.taskConfig = taskConfig;
		this.uniqueName = taskName + UUID.randomUUID().toString();
		this.inputPositions = new HashMap<FlinkVertex, ArrayList<Integer>>();
		this.numberOfSubTasksInOutputs = new ArrayList<Integer>();
	}

	/** Returns the task configuration this vertex was created with. */
	public TezTaskConfig getConfig() {
		return taskConfig;
	}

	public int getParallelism() {
		return parallelism;
	}

	public void setParallelism(int parallelism) {
		this.parallelism = parallelism;
	}

	/**
	 * Creates the Tez vertex for this operator.
	 *
	 * @param conf Tez configuration to base the vertex configuration on
	 * @return the created Tez vertex
	 */
	public abstract Vertex createVertex(TezConfiguration conf);

	/** Returns the Tez vertex, or {@code null} if it has not been assigned yet. */
	public Vertex getVertex() {
		return cached;
	}

	protected String getUniqueName() {
		return uniqueName;
	}

	/**
	 * Records that {@code vertex} is connected to this vertex at input {@code position}.
	 * The same upstream vertex may be registered at several positions.
	 */
	public void addInput(FlinkVertex vertex, int position) {
		ArrayList<Integer> positions = inputPositions.get(vertex);
		if (positions == null) {
			positions = new ArrayList<Integer>();
			inputPositions.put(vertex, positions);
		}
		positions.add(position);
	}

	/**
	 * Records the number of subtasks consuming output {@code position}.
	 * Any gap of lower positions that have not been set yet is filled with -1.
	 *
	 * @throws IndexOutOfBoundsException if {@code position} is negative
	 */
	public void addNumberOfSubTasksInOutput(int subTasks, int position) {
		while (numberOfSubTasksInOutputs.size() <= position) {
			numberOfSubTasksInOutputs.add(-1);
		}
		numberOfSubTasksInOutputs.set(position, subTasks);
	}

	/**
	 * Writes the recorded input positions, keyed by the unique name of each upstream vertex,
	 * into the task configuration.
	 * Must be called before taskConfig is written to Tez configuration.
	 */
	protected void writeInputPositionsToConfig() {
		HashMap<String, ArrayList<Integer>> toWrite = new HashMap<String, ArrayList<Integer>>();
		for (Map.Entry<FlinkVertex, ArrayList<Integer>> entry : inputPositions.entrySet()) {
			List<Integer> positions = entry.getValue();
			// copy so the config does not alias this vertex's mutable bookkeeping
			toWrite.put(entry.getKey().getUniqueName(), new ArrayList<Integer>(positions));
		}
		this.taskConfig.setInputPositions(toWrite);
	}

	/**
	 * Writes the number of subtasks per output position into the task configuration.
	 * Must be called before taskConfig is written to Tez configuration.
	 */
	protected void writeSubTasksInOutputToConfig() {
		// copy (like writeInputPositionsToConfig) so later bookkeeping updates cannot leak into the config
		this.taskConfig.setNumberSubtasksInOutput(new ArrayList<Integer>(this.numberOfSubTasksInOutputs));
	}

}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/TezDAGGenerator.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/TezDAGGenerator.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/TezDAGGenerator.java new file mode 100644 index 0000000..52f39be --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/dag/TezDAGGenerator.java @@ -0,0 +1,460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.tez.dag; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.flink.api.common.distributions.DataDistribution; +import org.apache.flink.api.common.io.InputFormat; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.optimizer.CompilerException; +import org.apache.flink.optimizer.dag.TempMode; +import org.apache.flink.optimizer.plan.BulkIterationPlanNode; +import org.apache.flink.optimizer.plan.Channel; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.NAryUnionPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.PlanNode; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.plan.SourcePlanNode; +import org.apache.flink.optimizer.plan.WorksetIterationPlanNode; +import org.apache.flink.runtime.operators.DriverStrategy; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.apache.flink.runtime.operators.util.LocalStrategy; +import org.apache.flink.runtime.operators.util.TaskConfig; +import org.apache.flink.tez.runtime.TezTaskConfig; +import org.apache.flink.util.Visitor; +import org.apache.tez.dag.api.DAG; +import org.apache.tez.dag.api.TezConfiguration; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + + +public class TezDAGGenerator implements Visitor<PlanNode> { + + private static final Log LOG = LogFactory.getLog(TezDAGGenerator.class); + + private 
Map<PlanNode, FlinkVertex> vertices; // a map from optimizer nodes to Tez vertices + private List<FlinkEdge> edges; + private final int defaultMaxFan; + private final TezConfiguration tezConf; + + private final float defaultSortSpillingThreshold; + + public TezDAGGenerator (TezConfiguration tezConf, Configuration config) { + this.defaultMaxFan = config.getInteger(ConfigConstants.DEFAULT_SPILLING_MAX_FAN_KEY, + ConfigConstants.DEFAULT_SPILLING_MAX_FAN); + this.defaultSortSpillingThreshold = config.getFloat(ConfigConstants.DEFAULT_SORT_SPILLING_THRESHOLD_KEY, + ConfigConstants.DEFAULT_SORT_SPILLING_THRESHOLD); + this.tezConf = tezConf; + } + + public DAG createDAG (OptimizedPlan program) throws Exception { + LOG.info ("Creating Tez DAG"); + this.vertices = new HashMap<PlanNode, FlinkVertex>(); + this.edges = new ArrayList<FlinkEdge>(); + program.accept(this); + + DAG dag = DAG.create(program.getJobName()); + for (FlinkVertex v : vertices.values()) { + dag.addVertex(v.createVertex(new TezConfiguration(tezConf))); + } + for (FlinkEdge e: edges) { + dag.addEdge(e.createEdge(new TezConfiguration(tezConf))); + } + + /* + * Temporarily throw an error until TEZ-1190 has been fixed or a workaround has been created + */ + if (containsSelfJoins()) { + throw new CompilerException("Dual-input operators with the same input (self-joins) are not yet supported"); + } + + this.vertices = null; + this.edges = null; + + LOG.info ("Tez DAG created"); + return dag; + } + + + @Override + public boolean preVisit(PlanNode node) { + if (this.vertices.containsKey(node)) { + // return false to prevent further descend + return false; + } + + if ((node instanceof BulkIterationPlanNode) || (node instanceof WorksetIterationPlanNode)) { + throw new CompilerException("Iterations are not yet supported by the Tez execution environment"); + } + + if ( (node.getBroadcastInputs() != null) && (!node.getBroadcastInputs().isEmpty())) { + throw new CompilerException("Broadcast inputs are not yet supported by 
the Tez execution environment"); + } + + FlinkVertex vertex = null; + + try { + if (node instanceof SourcePlanNode) { + vertex = createDataSourceVertex ((SourcePlanNode) node); + } + else if (node instanceof SinkPlanNode) { + vertex = createDataSinkVertex ((SinkPlanNode) node); + } + else if ((node instanceof SingleInputPlanNode)) { + vertex = createSingleInputVertex((SingleInputPlanNode) node); + } + else if (node instanceof DualInputPlanNode) { + vertex = createDualInputVertex((DualInputPlanNode) node); + } + else if (node instanceof NAryUnionPlanNode) { + vertex = createUnionVertex ((NAryUnionPlanNode) node); + } + else { + throw new CompilerException("Unrecognized node type: " + node.getClass().getName()); + } + + } + catch (Exception e) { + throw new CompilerException("Error translating node '" + node + "': " + e.getMessage(), e); + } + + if (vertex != null) { + this.vertices.put(node, vertex); + } + return true; + } + + @Override + public void postVisit (PlanNode node) { + try { + if (node instanceof SourcePlanNode) { + return; + } + final Iterator<Channel> inConns = node.getInputs().iterator(); + if (!inConns.hasNext()) { + throw new CompilerException("Bug: Found a non-source task with no input."); + } + int inputIndex = 0; + + FlinkVertex targetVertex = this.vertices.get(node); + TezTaskConfig targetVertexConfig = targetVertex.getConfig(); + + + while (inConns.hasNext()) { + Channel input = inConns.next(); + inputIndex += translateChannel(input, inputIndex, targetVertex, targetVertexConfig, false); + } + } + catch (Exception e) { + e.printStackTrace(); + throw new CompilerException( + "An error occurred while translating the optimized plan to a Tez DAG: " + e.getMessage(), e); + } + } + + private FlinkVertex createSingleInputVertex(SingleInputPlanNode node) throws CompilerException, IOException { + + final String taskName = node.getNodeName(); + final DriverStrategy ds = node.getDriverStrategy(); + final int dop = node.getParallelism(); + + final 
TezTaskConfig config= new TezTaskConfig(new Configuration()); + + config.setDriver(ds.getDriverClass()); + config.setDriverStrategy(ds); + config.setStubWrapper(node.getProgramOperator().getUserCodeWrapper()); + config.setStubParameters(node.getProgramOperator().getParameters()); + + for(int i=0;i<ds.getNumRequiredComparators();i++) { + config.setDriverComparator(node.getComparator(i), i); + } + assignDriverResources(node, config); + + return new FlinkProcessorVertex(taskName, dop, config); + } + + private FlinkVertex createDualInputVertex(DualInputPlanNode node) throws CompilerException, IOException { + final String taskName = node.getNodeName(); + final DriverStrategy ds = node.getDriverStrategy(); + final int dop = node.getParallelism(); + + final TezTaskConfig config= new TezTaskConfig(new Configuration()); + + config.setDriver(ds.getDriverClass()); + config.setDriverStrategy(ds); + config.setStubWrapper(node.getProgramOperator().getUserCodeWrapper()); + config.setStubParameters(node.getProgramOperator().getParameters()); + + if (node.getComparator1() != null) { + config.setDriverComparator(node.getComparator1(), 0); + } + if (node.getComparator2() != null) { + config.setDriverComparator(node.getComparator2(), 1); + } + if (node.getPairComparator() != null) { + config.setDriverPairComparator(node.getPairComparator()); + } + + assignDriverResources(node, config); + + LOG.info("Creating processor vertex " + taskName + " with parallelism " + dop); + + return new FlinkProcessorVertex(taskName, dop, config); + } + + private FlinkVertex createDataSinkVertex(SinkPlanNode node) throws CompilerException, IOException { + final String taskName = node.getNodeName(); + final int dop = node.getParallelism(); + + final TezTaskConfig config = new TezTaskConfig(new Configuration()); + + // set user code + config.setStubWrapper(node.getProgramOperator().getUserCodeWrapper()); + config.setStubParameters(node.getProgramOperator().getParameters()); + + LOG.info("Creating data sink 
vertex " + taskName + " with parallelism " + dop); + + return new FlinkDataSinkVertex(taskName, dop, config); + } + + private FlinkVertex createDataSourceVertex(SourcePlanNode node) throws CompilerException, IOException { + final String taskName = node.getNodeName(); + int dop = node.getParallelism(); + + final TezTaskConfig config= new TezTaskConfig(new Configuration()); + + config.setStubWrapper(node.getProgramOperator().getUserCodeWrapper()); + config.setStubParameters(node.getProgramOperator().getParameters()); + + InputFormat format = node.getDataSourceNode().getOperator().getFormatWrapper().getUserCodeObject(); + + config.setInputFormat(format); + + // Create as many data sources as input splits + InputSplit[] splits = format.createInputSplits((dop > 0) ? dop : 1); + dop = splits.length; + + LOG.info("Creating data source vertex " + taskName + " with parallelism " + dop); + + return new FlinkDataSourceVertex(taskName, dop, config); + } + + private FlinkVertex createUnionVertex(NAryUnionPlanNode node) throws CompilerException, IOException { + final String taskName = node.getNodeName(); + final int dop = node.getParallelism(); + final TezTaskConfig config= new TezTaskConfig(new Configuration()); + + LOG.info("Creating union vertex " + taskName + " with parallelism " + dop); + + return new FlinkUnionVertex (taskName, dop, config); + } + + + private void assignDriverResources(PlanNode node, TaskConfig config) { + final double relativeMem = node.getRelativeMemoryPerSubTask(); + if (relativeMem > 0) { + config.setRelativeMemoryDriver(relativeMem); + config.setFilehandlesDriver(this.defaultMaxFan); + config.setSpillingThresholdDriver(this.defaultSortSpillingThreshold); + } + } + + private void assignLocalStrategyResources(Channel c, TaskConfig config, int inputNum) { + if (c.getRelativeMemoryLocalStrategy() > 0) { + config.setRelativeMemoryInput(inputNum, c.getRelativeMemoryLocalStrategy()); + config.setFilehandlesInput(inputNum, this.defaultMaxFan); + 
config.setSpillingThresholdInput(inputNum, this.defaultSortSpillingThreshold); + } + } + + private int translateChannel(Channel input, int inputIndex, FlinkVertex targetVertex, + TezTaskConfig targetVertexConfig, boolean isBroadcast) throws Exception + { + final PlanNode inputPlanNode = input.getSource(); + final Iterator<Channel> allInChannels; + + + allInChannels = Collections.singletonList(input).iterator(); + + + // check that the type serializer is consistent + TypeSerializerFactory<?> typeSerFact = null; + + while (allInChannels.hasNext()) { + final Channel inConn = allInChannels.next(); + + if (typeSerFact == null) { + typeSerFact = inConn.getSerializer(); + } else if (!typeSerFact.equals(inConn.getSerializer())) { + throw new CompilerException("Conflicting types in union operator."); + } + + final PlanNode sourceNode = inConn.getSource(); + FlinkVertex sourceVertex = this.vertices.get(sourceNode); + TezTaskConfig sourceVertexConfig = sourceVertex.getConfig(); //TODO ??? need to create a new TezConfig ??? + + connectJobVertices( + inConn, inputIndex, sourceVertex, sourceVertexConfig, targetVertex, targetVertexConfig, isBroadcast); + } + + // the local strategy is added only once. 
in non-union case that is the actual edge, + // in the union case, it is the edge between union and the target node + addLocalInfoFromChannelToConfig(input, targetVertexConfig, inputIndex, isBroadcast); + return 1; + } + + private void connectJobVertices(Channel channel, int inputNumber, + final FlinkVertex sourceVertex, final TezTaskConfig sourceConfig, + final FlinkVertex targetVertex, final TezTaskConfig targetConfig, boolean isBroadcast) + throws CompilerException { + + // -------------- configure the source task's ship strategy strategies in task config -------------- + final int outputIndex = sourceConfig.getNumOutputs(); + sourceConfig.addOutputShipStrategy(channel.getShipStrategy()); + if (outputIndex == 0) { + sourceConfig.setOutputSerializer(channel.getSerializer()); + } + if (channel.getShipStrategyComparator() != null) { + sourceConfig.setOutputComparator(channel.getShipStrategyComparator(), outputIndex); + } + + if (channel.getShipStrategy() == ShipStrategyType.PARTITION_RANGE) { + + final DataDistribution dataDistribution = channel.getDataDistribution(); + if(dataDistribution != null) { + sourceConfig.setOutputDataDistribution(dataDistribution, outputIndex); + } else { + throw new RuntimeException("Range partitioning requires data distribution"); + // TODO: inject code and configuration for automatic histogram generation + } + } + + // ---------------- configure the receiver ------------------- + if (isBroadcast) { + targetConfig.addBroadcastInputToGroup(inputNumber); + } else { + targetConfig.addInputToGroup(inputNumber); + } + + //----------------- connect source and target with edge ------------------------------ + + FlinkEdge edge; + ShipStrategyType shipStrategy = channel.getShipStrategy(); + TypeSerializer<?> serializer = channel.getSerializer().getSerializer(); + if ((shipStrategy == ShipStrategyType.FORWARD) || (shipStrategy == ShipStrategyType.NONE)) { + edge = new FlinkForwardEdge(sourceVertex, targetVertex, serializer); + // For forward 
edges, create as many tasks in upstream operator as in source operator + targetVertex.setParallelism(sourceVertex.getParallelism()); + } + else if (shipStrategy == ShipStrategyType.BROADCAST) { + edge = new FlinkBroadcastEdge(sourceVertex, targetVertex, serializer); + } + else if (shipStrategy == ShipStrategyType.PARTITION_HASH) { + edge = new FlinkPartitionEdge(sourceVertex, targetVertex, serializer); + } + else { + throw new CompilerException("Ship strategy between nodes " + sourceVertex.getVertex().getName() + " and " + targetVertex.getVertex().getName() + " currently not supported"); + } + + // Tez-specific bookkeeping + // TODO: This probably will not work for vertices with multiple outputs + sourceVertex.addNumberOfSubTasksInOutput(targetVertex.getParallelism(), outputIndex); + targetVertex.addInput(sourceVertex, inputNumber); + + + edges.add(edge); + } + + private void addLocalInfoFromChannelToConfig(Channel channel, TaskConfig config, int inputNum, boolean isBroadcastChannel) { + // serializer + if (isBroadcastChannel) { + config.setBroadcastInputSerializer(channel.getSerializer(), inputNum); + + if (channel.getLocalStrategy() != LocalStrategy.NONE || (channel.getTempMode() != null && channel.getTempMode() != TempMode.NONE)) { + throw new CompilerException("Found local strategy or temp mode on a broadcast variable channel."); + } else { + return; + } + } else { + config.setInputSerializer(channel.getSerializer(), inputNum); + } + + // local strategy + if (channel.getLocalStrategy() != LocalStrategy.NONE) { + config.setInputLocalStrategy(inputNum, channel.getLocalStrategy()); + if (channel.getLocalStrategyComparator() != null) { + config.setInputComparator(channel.getLocalStrategyComparator(), inputNum); + } + } + + assignLocalStrategyResources(channel, config, inputNum); + + // materialization / caching + if (channel.getTempMode() != null) { + final TempMode tm = channel.getTempMode(); + + boolean needsMemory = false; + if (tm.breaksPipeline()) { + 
config.setInputAsynchronouslyMaterialized(inputNum, true); + needsMemory = true; + } + if (tm.isCached()) { + config.setInputCached(inputNum, true); + needsMemory = true; + } + + if (needsMemory) { + // sanity check + if (tm == null || tm == TempMode.NONE || channel.getRelativeTempMemory() <= 0) { + throw new CompilerException("Bug in compiler: Inconsistent description of input materialization."); + } + config.setRelativeInputMaterializationMemory(inputNum, channel.getRelativeTempMemory()); + } + } + } + + private boolean containsSelfJoins () { + for (FlinkVertex v : vertices.values()) { + ArrayList<FlinkVertex> predecessors = new ArrayList<FlinkVertex>(); + for (FlinkEdge e : edges) { + if (e.target == v) { + if (predecessors.contains(e.source)) { + return true; + } + predecessors.add(e.source); + } + } + } + return false; + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ConnectedComponentsStep.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ConnectedComponentsStep.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ConnectedComponentsStep.java new file mode 100644 index 0000000..707fd47 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ConnectedComponentsStep.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.tez.examples;

import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsFirst;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.examples.java.graph.util.ConnectedComponentsData;
import org.apache.flink.util.Collector;


/**
 * One step of the Connected Components algorithm, written without Flink iterations.
 *
 * <p>Every vertex starts in the component equal to its own ID. The step sends each vertex's
 * component ID along the (undirected) edges, keeps the minimum candidate per vertex, and emits
 * only the (vertex-ID, component-ID) pairs whose candidate is smaller than the initial ID.
 *
 * <p>The fourth command-line argument (maximum number of iterations) is parsed but not used by
 * this single-step program.
 */
public class ConnectedComponentsStep implements ProgramDescription {

	// *************************************************************************
	//     PROGRAM
	// *************************************************************************

	public static void main(String... args) throws Exception {
		if (!parseParameters(args)) {
			return;
		}

		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		// vertices, plus edges made undirected by emitting each edge in both directions
		final DataSet<Long> vertices = getVertexDataSet(env);
		final DataSet<Tuple2<Long, Long>> edges = getEdgeDataSet(env).flatMap(new UndirectEdge());

		// initially, every vertex forms its own component
		final DataSet<Tuple2<Long, Long>> verticesWithInitialId = vertices.map(new DuplicateValue<Long>());

		// send component IDs to the neighbours and keep the smallest candidate per vertex
		final DataSet<Tuple2<Long, Long>> minimumCandidates = verticesWithInitialId
				.join(edges)
				.where(0).equalTo(0)
				.with(new NeighborWithComponentIDJoin())
				.groupBy(0).aggregate(Aggregations.MIN, 1);

		// retain only those vertices whose component ID went down
		final DataSet<Tuple2<Long, Long>> nextComponents = minimumCandidates
				.join(verticesWithInitialId)
				.where(0).equalTo(0)
				.with(new ComponentIdFilter());

		if (!fileOutput) {
			nextComponents.print();
		} else {
			nextComponents.writeAsCsv(outputPath, "\n", " ");
		}

		env.execute("Connected Components Example");
	}

	// *************************************************************************
	//     USER FUNCTIONS
	// *************************************************************************

	/**
	 * Maps a value {@code v} to the pair {@code (v, v)}.
	 */
	@ForwardedFields("*->f0")
	public static final class DuplicateValue<T> implements MapFunction<T, Tuple2<T, T>> {

		@Override
		public Tuple2<T, T> map(T value) {
			return new Tuple2<T, T>(value, value);
		}
	}

	/**
	 * Emits every input edge unchanged, followed by its reversed copy.
	 * The reversed copy is a single reused tuple instance.
	 */
	public static final class UndirectEdge implements FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
		Tuple2<Long, Long> invertedEdge = new Tuple2<Long, Long>();

		@Override
		public void flatMap(Tuple2<Long, Long> edge, Collector<Tuple2<Long, Long>> out) {
			final Long source = edge.f0;
			final Long target = edge.f1;
			invertedEdge.f0 = target;
			invertedEdge.f1 = source;
			out.collect(edge);
			out.collect(invertedEdge);
		}
	}

	/**
	 * Joins a (vertex-ID, component-ID) pair with a (source-vertex-ID, target-vertex-ID) edge
	 * and produces the candidate pair (target-vertex-ID, component-ID).
	 */
	@ForwardedFieldsFirst("f1->f1")
	@ForwardedFieldsSecond("f1->f0")
	public static final class NeighborWithComponentIDJoin implements JoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {

		@Override
		public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithComponent, Tuple2<Long, Long> edge) {
			final Long neighbor = edge.f1;
			final Long component = vertexWithComponent.f1;
			return new Tuple2<Long, Long>(neighbor, component);
		}
	}

	/**
	 * Forwards a candidate (vertex-ID, component-ID) pair only if its component ID is strictly
	 * smaller than the one in the matching old pair.
	 */
	@ForwardedFieldsFirst("*")
	public static final class ComponentIdFilter implements FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {

		@Override
		public void join(Tuple2<Long, Long> candidate, Tuple2<Long, Long> old, Collector<Tuple2<Long, Long>> out) {
			final boolean improved = candidate.f1.longValue() < old.f1.longValue();
			if (improved) {
				out.collect(candidate);
			}
		}
	}

	@Override
	public String getDescription() {
		return "Parameters: <vertices-path> <edges-path> <result-path> <max-number-of-iterations>";
	}

	// *************************************************************************
	//     UTIL METHODS
	// *************************************************************************

	private static boolean fileOutput = false;
	private static String verticesPath = null;
	private static String edgesPath = null;
	private static String outputPath = null;
	private static int maxIterations = 10;

	/**
	 * Reads the command line. No arguments selects the built-in data set; exactly four arguments
	 * select file input/output; any other count prints the usage and returns false.
	 */
	private static boolean parseParameters(String[] programArguments) {
		if (programArguments.length == 0) {
			System.out.println("Executing Connected Components example with default parameters and built-in default data.");
			System.out.println(" Provide parameters to read input data from files.");
			System.out.println(" See the documentation for the correct format of input files.");
			System.out.println(" Usage: ConnectedComponents <vertices path> <edges path> <result path> <max number of iterations>");
			return true;
		}

		// any argument switches to file input/output, even if the count turns out to be wrong
		fileOutput = true;
		if (programArguments.length != 4) {
			System.err.println("Usage: ConnectedComponents <vertices path> <edges path> <result path> <max number of iterations>");
			return false;
		}

		verticesPath = programArguments[0];
		edgesPath = programArguments[1];
		outputPath = programArguments[2];
		maxIterations = Integer.parseInt(programArguments[3]);
		return true;
	}

	private static DataSet<Long> getVertexDataSet(ExecutionEnvironment env) {
		if (!fileOutput) {
			return ConnectedComponentsData.getDefaultVertexDataSet(env);
		}
		return env.readCsvFile(verticesPath)
				.types(Long.class)
				.map(new MapFunction<Tuple1<Long>, Long>() {
					@Override
					public Long map(Tuple1<Long> value) {
						return value.f0;
					}
				});
	}

	private static DataSet<Tuple2<Long, Long>> getEdgeDataSet(ExecutionEnvironment env) {
		if (!fileOutput) {
			return ConnectedComponentsData.getDefaultEdgeDataSet(env);
		}
		return env.readCsvFile(edgesPath).fieldDelimiter(' ').types(Long.class, Long.class);
	}


}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ExampleDriver.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ExampleDriver.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ExampleDriver.java new file mode 100644 index
0000000..c65fb69 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/ExampleDriver.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.examples; + +import org.apache.hadoop.util.ProgramDriver; +import org.apache.tez.common.counters.TezCounters; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.dag.api.client.DAGClient; +import org.apache.tez.dag.api.client.DAGStatus; +import org.apache.tez.dag.api.client.Progress; +import org.apache.tez.dag.api.client.StatusGetOpts; +import org.apache.tez.dag.api.client.VertexStatus; + +import java.io.IOException; +import java.text.DecimalFormat; +import java.util.EnumSet; +import java.util.Set; + +public class ExampleDriver { + + private static final DecimalFormat formatter = new DecimalFormat("###.##%"); + + public static void main(String [] args){ + int exitCode = -1; + ProgramDriver pgd = new ProgramDriver(); + try { + pgd.addClass("wc", WordCount.class, + "Wordcount"); + pgd.addClass("tpch3", TPCHQuery3.class, + "Modified TPC-H 3 query"); + pgd.addClass("tc", TransitiveClosureNaiveStep.class, + "One step of transitive closure"); + pgd.addClass("pr", 
PageRankBasicStep.class, + "One step of PageRank"); + pgd.addClass("cc", ConnectedComponentsStep.class, + "One step of connected components"); + exitCode = pgd.run(args); + } catch(Throwable e){ + e.printStackTrace(); + } + System.exit(exitCode); + } + + public static void printDAGStatus(DAGClient dagClient, String[] vertexNames) + throws IOException, TezException { + printDAGStatus(dagClient, vertexNames, false, false); + } + + public static void printDAGStatus(DAGClient dagClient, String[] vertexNames, boolean displayDAGCounters, boolean displayVertexCounters) + throws IOException, TezException { + Set<StatusGetOpts> opts = EnumSet.of(StatusGetOpts.GET_COUNTERS); + DAGStatus dagStatus = dagClient.getDAGStatus( + (displayDAGCounters ? opts : null)); + Progress progress = dagStatus.getDAGProgress(); + double vProgressFloat = 0.0f; + if (progress != null) { + System.out.println(""); + System.out.println("DAG: State: " + + dagStatus.getState() + + " Progress: " + + (progress.getTotalTaskCount() < 0 ? formatter.format(0.0f) : + formatter.format((double)(progress.getSucceededTaskCount()) + /progress.getTotalTaskCount()))); + for (String vertexName : vertexNames) { + VertexStatus vStatus = dagClient.getVertexStatus(vertexName, + (displayVertexCounters ? opts : null)); + if (vStatus == null) { + System.out.println("Could not retrieve status for vertex: " + + vertexName); + continue; + } + Progress vProgress = vStatus.getProgress(); + if (vProgress != null) { + vProgressFloat = 0.0f; + if (vProgress.getTotalTaskCount() == 0) { + vProgressFloat = 1.0f; + } else if (vProgress.getTotalTaskCount() > 0) { + vProgressFloat = (double)vProgress.getSucceededTaskCount() + /vProgress.getTotalTaskCount(); + } + System.out.println("VertexStatus:" + + " VertexName: " + + (vertexName.equals("ivertex1") ? 
"intermediate-reducer" + : vertexName) + + " Progress: " + formatter.format(vProgressFloat)); + } + if (displayVertexCounters) { + TezCounters counters = vStatus.getVertexCounters(); + if (counters != null) { + System.out.println("Vertex Counters for " + vertexName + ": " + + counters); + } + } + } + } + if (displayDAGCounters) { + TezCounters counters = dagStatus.getDAGCounters(); + if (counters != null) { + System.out.println("DAG Counters: " + counters); + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/PageRankBasicStep.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/PageRankBasicStep.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/PageRankBasicStep.java new file mode 100644 index 0000000..031893d --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/PageRankBasicStep.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.tez.examples; + + +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.examples.java.graph.util.PageRankData; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; + +import static org.apache.flink.api.java.aggregation.Aggregations.SUM; + +public class PageRankBasicStep { + + private static final double DAMPENING_FACTOR = 0.85; + private static final double EPSILON = 0.0001; + + // ************************************************************************* + // PROGRAM + // ************************************************************************* + + public static void main(String[] args) throws Exception { + + if(!parseParameters(args)) { + return; + } + + // set up execution environment + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + // get input data + DataSet<Long> pagesInput = getPagesDataSet(env); + DataSet<Tuple2<Long, Long>> linksInput = getLinksDataSet(env); + + // assign initial rank to pages + DataSet<Tuple2<Long, Double>> pagesWithRanks = pagesInput. 
+ map(new RankAssigner((1.0d / numPages))); + + // build adjacency list from link input + DataSet<Tuple2<Long, Long[]>> adjacencyListInput = + linksInput.groupBy(0).reduceGroup(new BuildOutgoingEdgeList()); + + DataSet<Tuple2<Long, Double>> newRanks = pagesWithRanks + .join(adjacencyListInput).where(0).equalTo(0) + .flatMap(new JoinVertexWithEdgesMatch()) + .groupBy(0).aggregate(SUM, 1) + .map(new Dampener(DAMPENING_FACTOR, numPages)); + + + // emit result + if(fileOutput) { + newRanks.writeAsCsv(outputPath, "\n", " "); + } else { + newRanks.print(); + } + + // execute program + env.execute("Basic Page Rank Example"); + + } + + // ************************************************************************* + // USER FUNCTIONS + // ************************************************************************* + + /** + * A map function that assigns an initial rank to all pages. + */ + public static final class RankAssigner implements MapFunction<Long, Tuple2<Long, Double>> { + Tuple2<Long, Double> outPageWithRank; + + public RankAssigner(double rank) { + this.outPageWithRank = new Tuple2<Long, Double>(-1l, rank); + } + + @Override + public Tuple2<Long, Double> map(Long page) { + outPageWithRank.f0 = page; + return outPageWithRank; + } + } + + /** + * A reduce function that takes a sequence of edges and builds the adjacency list for the vertex where the edges + * originate. Run as a pre-processing step. 
+ */ + @ForwardedFields("0") + public static final class BuildOutgoingEdgeList implements GroupReduceFunction<Tuple2<Long, Long>, Tuple2<Long, Long[]>> { + + private final ArrayList<Long> neighbors = new ArrayList<Long>(); + + @Override + public void reduce(Iterable<Tuple2<Long, Long>> values, Collector<Tuple2<Long, Long[]>> out) { + neighbors.clear(); + Long id = 0L; + + for (Tuple2<Long, Long> n : values) { + id = n.f0; + neighbors.add(n.f1); + } + out.collect(new Tuple2<Long, Long[]>(id, neighbors.toArray(new Long[neighbors.size()]))); + } + } + + /** + * Join function that distributes a fraction of a vertex's rank to all neighbors. + */ + public static final class JoinVertexWithEdgesMatch implements FlatMapFunction<Tuple2<Tuple2<Long, Double>, Tuple2<Long, Long[]>>, Tuple2<Long, Double>> { + + @Override + public void flatMap(Tuple2<Tuple2<Long, Double>, Tuple2<Long, Long[]>> value, Collector<Tuple2<Long, Double>> out){ + Long[] neigbors = value.f1.f1; + double rank = value.f0.f1; + double rankToDistribute = rank / ((double) neigbors.length); + + for (int i = 0; i < neigbors.length; i++) { + out.collect(new Tuple2<Long, Double>(neigbors[i], rankToDistribute)); + } + } + } + + /** + * The function that applies the page rank dampening formula + */ + @ForwardedFields("0") + public static final class Dampener implements MapFunction<Tuple2<Long,Double>, Tuple2<Long,Double>> { + + private final double dampening; + private final double randomJump; + + public Dampener(double dampening, double numVertices) { + this.dampening = dampening; + this.randomJump = (1 - dampening) / numVertices; + } + + @Override + public Tuple2<Long, Double> map(Tuple2<Long, Double> value) { + value.f1 = (value.f1 * dampening) + randomJump; + return value; + } + } + + /** + * Filter that filters vertices where the rank difference is below a threshold. 
+ */ + public static final class EpsilonFilter implements FilterFunction<Tuple2<Tuple2<Long, Double>, Tuple2<Long, Double>>> { + + @Override + public boolean filter(Tuple2<Tuple2<Long, Double>, Tuple2<Long, Double>> value) { + return Math.abs(value.f0.f1 - value.f1.f1) > EPSILON; + } + } + + // ************************************************************************* + // UTIL METHODS + // ************************************************************************* + + private static boolean fileOutput = false; + private static String pagesInputPath = null; + private static String linksInputPath = null; + private static String outputPath = null; + private static long numPages = 0; + private static int maxIterations = 10; + + private static boolean parseParameters(String[] args) { + + if(args.length > 0) { + if(args.length == 5) { + fileOutput = true; + pagesInputPath = args[0]; + linksInputPath = args[1]; + outputPath = args[2]; + numPages = Integer.parseInt(args[3]); + maxIterations = Integer.parseInt(args[4]); + } else { + System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>"); + return false; + } + } else { + System.out.println("Executing PageRank Basic example with default parameters and built-in default data."); + System.out.println(" Provide parameters to read input data from files."); + System.out.println(" See the documentation for the correct format of input files."); + System.out.println(" Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>"); + + numPages = PageRankData.getNumberOfPages(); + } + return true; + } + + private static DataSet<Long> getPagesDataSet(ExecutionEnvironment env) { + if(fileOutput) { + return env + .readCsvFile(pagesInputPath) + .fieldDelimiter(' ') + .lineDelimiter("\n") + .types(Long.class) + .map(new MapFunction<Tuple1<Long>, Long>() { + @Override + public Long map(Tuple1<Long> v) { return v.f0; } + }); + } else { + return 
PageRankData.getDefaultPagesDataSet(env); + } + } + + private static DataSet<Tuple2<Long, Long>> getLinksDataSet(ExecutionEnvironment env) { + if(fileOutput) { + return env.readCsvFile(linksInputPath) + .fieldDelimiter(' ') + .lineDelimiter("\n") + .types(Long.class, Long.class); + } else { + return PageRankData.getDefaultEdgeDataSet(env); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TPCHQuery3.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TPCHQuery3.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TPCHQuery3.java new file mode 100644 index 0000000..d61f80e --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TPCHQuery3.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.tez.examples; + +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.aggregation.Aggregations; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.tez.client.RemoteTezEnvironment; + +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; + +public class TPCHQuery3 { + + // ************************************************************************* + // PROGRAM + // ************************************************************************* + + public static void main(String[] args) throws Exception { + + if(!parseParameters(args)) { + return; + } + + final RemoteTezEnvironment env = RemoteTezEnvironment.create(); + env.setParallelism(400); + + + // get input data + DataSet<Lineitem> lineitems = getLineitemDataSet(env); + DataSet<Order> orders = getOrdersDataSet(env); + DataSet<Customer> customers = getCustomerDataSet(env); + + // Filter market segment "AUTOMOBILE" + customers = customers.filter( + new FilterFunction<Customer>() { + @Override + public boolean filter(Customer c) { + return c.getMktsegment().equals("AUTOMOBILE"); + } + }); + + // Filter all Orders with o_orderdate < 12.03.1995 + orders = orders.filter( + new FilterFunction<Order>() { + private final DateFormat format = new SimpleDateFormat("yyyy-MM-dd"); + private final Date date = format.parse("1995-03-12"); + + @Override + public boolean filter(Order o) throws ParseException { + return format.parse(o.getOrderdate()).before(date); + } + }); + + // Filter all Lineitems with l_shipdate > 12.03.1995 + lineitems = lineitems.filter( + new FilterFunction<Lineitem>() { + private final DateFormat format = new SimpleDateFormat("yyyy-MM-dd"); + 
private final Date date = format.parse("1995-03-12"); + + @Override + public boolean filter(Lineitem l) throws ParseException { + return format.parse(l.getShipdate()).after(date); + } + }); + + // Join customers with orders and package them into a ShippingPriorityItem + DataSet<ShippingPriorityItem> customerWithOrders = + customers.join(orders).where(0).equalTo(1) + .with( + new JoinFunction<Customer, Order, ShippingPriorityItem>() { + @Override + public ShippingPriorityItem join(Customer c, Order o) { + return new ShippingPriorityItem(o.getOrderKey(), 0.0, o.getOrderdate(), + o.getShippriority()); + } + }); + + // Join the last join result with Lineitems + DataSet<ShippingPriorityItem> result = + customerWithOrders.join(lineitems).where(0).equalTo(0) + .with( + new JoinFunction<ShippingPriorityItem, Lineitem, ShippingPriorityItem>() { + @Override + public ShippingPriorityItem join(ShippingPriorityItem i, Lineitem l) { + i.setRevenue(l.getExtendedprice() * (1 - l.getDiscount())); + return i; + } + }) + // Group by l_orderkey, o_orderdate and o_shippriority and compute revenue sum + .groupBy(0, 2, 3) + .aggregate(Aggregations.SUM, 1); + + // emit result + result.writeAsCsv(outputPath, "\n", "|"); + + // execute program + env.registerMainClass(TPCHQuery3.class); + env.execute("TPCH Query 3 Example"); + + } + + // ************************************************************************* + // DATA TYPES + // ************************************************************************* + + public static class Lineitem extends Tuple4<Integer, Double, Double, String> { + + public Integer getOrderkey() { return this.f0; } + public Double getDiscount() { return this.f2; } + public Double getExtendedprice() { return this.f1; } + public String getShipdate() { return this.f3; } + } + + public static class Customer extends Tuple2<Integer, String> { + + public Integer getCustKey() { return this.f0; } + public String getMktsegment() { return this.f1; } + } + + public static class 
Order extends Tuple4<Integer, Integer, String, Integer> { + + public Integer getOrderKey() { return this.f0; } + public Integer getCustKey() { return this.f1; } + public String getOrderdate() { return this.f2; } + public Integer getShippriority() { return this.f3; } + } + + public static class ShippingPriorityItem extends Tuple4<Integer, Double, String, Integer> { + + public ShippingPriorityItem() { } + + public ShippingPriorityItem(Integer o_orderkey, Double revenue, + String o_orderdate, Integer o_shippriority) { + this.f0 = o_orderkey; + this.f1 = revenue; + this.f2 = o_orderdate; + this.f3 = o_shippriority; + } + + public Integer getOrderkey() { return this.f0; } + public void setOrderkey(Integer orderkey) { this.f0 = orderkey; } + public Double getRevenue() { return this.f1; } + public void setRevenue(Double revenue) { this.f1 = revenue; } + + public String getOrderdate() { return this.f2; } + public Integer getShippriority() { return this.f3; } + } + + // ************************************************************************* + // UTIL METHODS + // ************************************************************************* + + private static String lineitemPath; + private static String customerPath; + private static String ordersPath; + private static String outputPath; + + private static boolean parseParameters(String[] programArguments) { + + if(programArguments.length > 0) { + if(programArguments.length == 4) { + lineitemPath = programArguments[0]; + customerPath = programArguments[1]; + ordersPath = programArguments[2]; + outputPath = programArguments[3]; + } else { + System.err.println("Usage: TPCHQuery3 <lineitem-csv path> <customer-csv path> <orders-csv path> <result path>"); + return false; + } + } else { + System.err.println("This program expects data from the TPC-H benchmark as input data.\n" + + " Due to legal restrictions, we can not ship generated data.\n" + + " You can find the TPC-H data generator at http://www.tpc.org/tpch/.\n" + + " Usage: 
TPCHQuery3 <lineitem-csv path> <customer-csv path> <orders-csv path> <result path>"); + return false; + } + return true; + } + + private static DataSet<Lineitem> getLineitemDataSet(ExecutionEnvironment env) { + return env.readCsvFile(lineitemPath) + .fieldDelimiter('|') + .includeFields("1000011000100000") + .tupleType(Lineitem.class); + } + + private static DataSet<Customer> getCustomerDataSet(ExecutionEnvironment env) { + return env.readCsvFile(customerPath) + .fieldDelimiter('|') + .includeFields("10000010") + .tupleType(Customer.class); + } + + private static DataSet<Order> getOrdersDataSet(ExecutionEnvironment env) { + return env.readCsvFile(ordersPath) + .fieldDelimiter('|') + .includeFields("110010010") + .tupleType(Order.class); + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TransitiveClosureNaiveStep.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TransitiveClosureNaiveStep.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TransitiveClosureNaiveStep.java new file mode 100644 index 0000000..b014c3e --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/TransitiveClosureNaiveStep.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.examples; + +import org.apache.flink.api.common.ProgramDescription; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.examples.java.graph.util.ConnectedComponentsData; +import org.apache.flink.util.Collector; + + +/* + * NOTE: + * This program is currently supposed to throw a Compiler Exception due to TEZ-1190 + */ + +public class TransitiveClosureNaiveStep implements ProgramDescription { + + + public static void main (String... 
args) throws Exception{ + + if (!parseParameters(args)) { + return; + } + + // set up execution environment + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> edges = getEdgeDataSet(env); + + DataSet<Tuple2<Long,Long>> nextPaths = edges + .join(edges) + .where(1) + .equalTo(0) + .with(new JoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>() { + @Override + /** + left: Path (z,x) - x is reachable by z + right: Edge (x,y) - edge x-->y exists + out: Path (z,y) - y is reachable by z + */ + public Tuple2<Long, Long> join(Tuple2<Long, Long> left, Tuple2<Long, Long> right) throws Exception { + return new Tuple2<Long, Long>( + new Long(left.f0), + new Long(right.f1)); + } + }) + .union(edges) + .groupBy(0, 1) + .reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>() { + @Override + public void reduce(Iterable<Tuple2<Long, Long>> values, Collector<Tuple2<Long, Long>> out) throws Exception { + out.collect(values.iterator().next()); + } + }); + + // emit result + if (fileOutput) { + nextPaths.writeAsCsv(outputPath, "\n", " "); + } else { + nextPaths.print(); + } + + // execute program + env.execute("Transitive Closure Example"); + + } + + @Override + public String getDescription() { + return "Parameters: <edges-path> <result-path> <max-number-of-iterations>"; + } + + // ************************************************************************* + // UTIL METHODS + // ************************************************************************* + + private static boolean fileOutput = false; + private static String edgesPath = null; + private static String outputPath = null; + private static int maxIterations = 10; + + private static boolean parseParameters(String[] programArguments) { + + if (programArguments.length > 0) { + // parse input arguments + fileOutput = true; + if (programArguments.length == 3) { + edgesPath = programArguments[0]; + outputPath = programArguments[1]; + 
maxIterations = Integer.parseInt(programArguments[2]); + } else { + System.err.println("Usage: TransitiveClosure <edges path> <result path> <max number of iterations>"); + return false; + } + } else { + System.out.println("Executing TransitiveClosure example with default parameters and built-in default data."); + System.out.println(" Provide parameters to read input data from files."); + System.out.println(" See the documentation for the correct format of input files."); + System.out.println(" Usage: TransitiveClosure <edges path> <result path> <max number of iterations>"); + } + return true; + } + + + private static DataSet<Tuple2<Long, Long>> getEdgeDataSet(ExecutionEnvironment env) { + + if(fileOutput) { + return env.readCsvFile(edgesPath).fieldDelimiter(' ').types(Long.class, Long.class); + } else { + return ConnectedComponentsData.getDefaultEdgeDataSet(env); + } + } + +} + http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/WordCount.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/WordCount.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/WordCount.java new file mode 100644 index 0000000..e758156 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/examples/WordCount.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.examples; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.examples.java.wordcount.util.WordCountData; +import org.apache.flink.tez.client.RemoteTezEnvironment; +import org.apache.flink.util.Collector; + +public class WordCount { + + // ************************************************************************* + // PROGRAM + // ************************************************************************* + + public static void main(String[] args) throws Exception { + + if(!parseParameters(args)) { + return; + } + + // set up the execution environment + final RemoteTezEnvironment env = RemoteTezEnvironment.create(); + env.setParallelism(8); + + // get input data + DataSet<String> text = getTextDataSet(env); + + DataSet<Tuple2<String, Integer>> counts = + // split up the lines in pairs (2-tuples) containing: (word,1) + text.flatMap(new Tokenizer()) + // group by the tuple field "0" and sum up tuple field "1" + .groupBy(0) + .sum(1); + + // emit result + if(fileOutput) { + counts.writeAsCsv(outputPath, "\n", " "); + } else { + counts.print(); + } + + // execute program + env.registerMainClass(WordCount.class); + env.execute("WordCount Example"); + } + + // ************************************************************************* + // USER FUNCTIONS + // ************************************************************************* + + /** + * 
Implements the string tokenizer that splits sentences into words as a user-defined + * FlatMapFunction. The function takes a line (String) and splits it into + * multiple pairs in the form of "(word,1)" (Tuple2<String, Integer>). + */ + public static final class Tokenizer implements FlatMapFunction<String, Tuple2<String, Integer>> { + + @Override + public void flatMap(String value, Collector<Tuple2<String, Integer>> out) { + // normalize and split the line + String[] tokens = value.toLowerCase().split("\\W+"); + + // emit the pairs + for (String token : tokens) { + if (token.length() > 0) { + out.collect(new Tuple2<String, Integer>(token, 1)); + } + } + } + } + + // ************************************************************************* + // UTIL METHODS + // ************************************************************************* + + private static boolean fileOutput = false; + private static String textPath; + private static String outputPath; + + private static boolean parseParameters(String[] args) { + + if(args.length > 0) { + // parse input arguments + fileOutput = true; + if(args.length == 2) { + textPath = args[0]; + outputPath = args[1]; + } else { + System.err.println("Usage: WordCount <text path> <result path>"); + return false; + } + } else { + System.out.println("Executing WordCount example with built-in default data."); + System.out.println(" Provide parameters to read input data from a file."); + System.out.println(" Usage: WordCount <text path> <result path>"); + } + return true; + } + + private static DataSet<String> getTextDataSet(ExecutionEnvironment env) { + if(fileOutput) { + // read the text file from given input path + return env.readTextFile(textPath); + } else { + // get default test text data + return WordCountData.getDefaultTextLineDataSet(env); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSinkProcessor.java 
---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSinkProcessor.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSinkProcessor.java new file mode 100644 index 0000000..8011d21 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSinkProcessor.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.tez.runtime; + + +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.io.OutputFormat; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.operators.sort.UnilateralSortMerger; +import org.apache.flink.runtime.operators.util.CloseableInputProvider; +import org.apache.flink.tez.runtime.input.TezReaderIterator; +import org.apache.flink.tez.util.DummyInvokable; +import org.apache.flink.tez.util.EncodingUtils; +import org.apache.flink.util.MutableObjectIterator; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; +import org.apache.tez.runtime.api.Event; +import org.apache.tez.runtime.api.LogicalInput; +import org.apache.tez.runtime.api.LogicalOutput; +import org.apache.tez.runtime.api.ProcessorContext; +import org.apache.tez.runtime.library.api.KeyValueReader; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class DataSinkProcessor<IT> extends AbstractLogicalIOProcessor { + + // Tez stuff + private TezTaskConfig config; + protected Map<String, LogicalInput> inputs; + private List<KeyValueReader> readers; + private int numInputs; + private TezRuntimeEnvironment runtimeEnvironment; + AbstractInvokable invokable = new DummyInvokable(); + + // Flink stuff + private OutputFormat<IT> format; + private ClassLoader userCodeClassLoader = this.getClass().getClassLoader(); + private CloseableInputProvider<IT> localStrategy; + // input reader + private MutableObjectIterator<IT> reader; + // input iterator + private MutableObjectIterator<IT> input; + private 
TypeSerializerFactory<IT> inputTypeSerializerFactory; + + + + + public DataSinkProcessor(ProcessorContext context) { + super(context); + } + + @Override + public void initialize() throws Exception { + UserPayload payload = getContext().getUserPayload(); + Configuration conf = TezUtils.createConfFromUserPayload(payload); + + this.config = (TezTaskConfig) EncodingUtils.decodeObjectFromString(conf.get(TezTaskConfig.TEZ_TASK_CONFIG), getClass().getClassLoader()); + config.setTaskName(getContext().getTaskVertexName()); + + this.runtimeEnvironment = new TezRuntimeEnvironment((long) (0.7 * this.getContext().getTotalMemoryAvailableToTask())); + + this.inputTypeSerializerFactory = this.config.getInputSerializer(0, this.userCodeClassLoader); + + initOutputFormat(); + } + + @Override + public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs) throws Exception { + + Preconditions.checkArgument((outputs == null) || (outputs.size() == 0)); + Preconditions.checkArgument(inputs.size() == 1); + + this.inputs = inputs; + this.numInputs = inputs.size(); + this.readers = new ArrayList<KeyValueReader>(numInputs); + if (this.inputs != null) { + for (LogicalInput input: this.inputs.values()) { + //if (input instanceof AbstractLogicalInput) { + // ((AbstractLogicalInput) input).initialize(); + //} + input.start(); + readers.add((KeyValueReader) input.getReader()); + } + } + + this.reader = new TezReaderIterator<IT>(readers.get(0)); + + this.invoke(); + } + + @Override + public void handleEvents(List<Event> processorEvents) { + + } + + @Override + public void close() throws Exception { + this.runtimeEnvironment.getIOManager().shutdown(); + } + + private void invoke () { + try { + // initialize local strategies + switch (this.config.getInputLocalStrategy(0)) { + case NONE: + // nothing to do + localStrategy = null; + input = reader; + break; + case SORT: + // initialize sort local strategy + try { + // get type comparator + TypeComparatorFactory<IT> compFact = 
this.config.getInputComparator(0, this.userCodeClassLoader); + if (compFact == null) { + throw new Exception("Missing comparator factory for local strategy on input " + 0); + } + + // initialize sorter + UnilateralSortMerger<IT> sorter = new UnilateralSortMerger<IT>( + this.runtimeEnvironment.getMemoryManager(), + this.runtimeEnvironment.getIOManager(), + this.reader, this.invokable, this.inputTypeSerializerFactory, compFact.createComparator(), + this.config.getRelativeMemoryInput(0), this.config.getFilehandlesInput(0), + this.config.getSpillingThresholdInput(0), false); + + this.localStrategy = sorter; + this.input = sorter.getIterator(); + } catch (Exception e) { + throw new RuntimeException("Initializing the input processing failed" + + e.getMessage() == null ? "." : ": " + e.getMessage(), e); + } + break; + default: + throw new RuntimeException("Invalid local strategy for DataSinkTask"); + } + + final TypeSerializer<IT> serializer = this.inputTypeSerializerFactory.getSerializer(); + final MutableObjectIterator<IT> input = this.input; + final OutputFormat<IT> format = this.format; + + + IT record = serializer.createInstance(); + format.open (this.getContext().getTaskIndex(), this.getContext().getVertexParallelism()); + + // work! + while (((record = input.next(record)) != null)) { + format.writeRecord(record); + } + + this.format.close(); + this.format = null; + } + catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException(); + } + finally { + if (this.format != null) { + // close format, if it has not been closed, yet. + // This should only be the case if we had a previous error, or were canceled. 
+ try { + this.format.close(); + } + catch (Throwable t) { + //TODO log warning message + } + } + // close local strategy if necessary + if (localStrategy != null) { + try { + this.localStrategy.close(); + } catch (Throwable t) { + //TODO log warning message + } + } + } + } + + private void initOutputFormat () { + try { + this.format = this.config.<OutputFormat<IT>>getStubWrapper(this.userCodeClassLoader).getUserCodeObject(OutputFormat.class, this.userCodeClassLoader); + + // check if the class is a subclass, if the check is required + if (!OutputFormat.class.isAssignableFrom(this.format.getClass())) { + throw new RuntimeException("The class '" + this.format.getClass().getName() + "' is not a subclass of '" + + OutputFormat.class.getName() + "' as is required."); + } + } + catch (ClassCastException ccex) { + throw new RuntimeException("The stub class is not a proper subclass of " + OutputFormat.class.getName(), ccex); + } + + // configure the stub. catch exceptions here extra, to report them as originating from the user code + try { + this.format.configure(this.config.getStubParameters()); + } + catch (Throwable t) { + throw new RuntimeException("The user defined 'configure()' method in the Output Format caused an error: " + + t.getMessage(), t); + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSourceProcessor.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSourceProcessor.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSourceProcessor.java new file mode 100644 index 0000000..dd3f843 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/DataSourceProcessor.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.runtime; + +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.distributions.DataDistribution; +import org.apache.flink.api.common.io.InputFormat; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.apache.flink.tez.runtime.input.FlinkInput; +import org.apache.flink.tez.runtime.output.TezChannelSelector; +import org.apache.flink.tez.runtime.output.TezOutputCollector; +import org.apache.flink.tez.runtime.output.TezOutputEmitter; +import org.apache.flink.tez.util.EncodingUtils; +import org.apache.flink.util.Collector; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; +import org.apache.tez.runtime.api.Event; +import org.apache.tez.runtime.api.LogicalInput; +import org.apache.tez.runtime.api.LogicalOutput; +import 
org.apache.tez.runtime.api.ProcessorContext; +import org.apache.tez.runtime.library.api.KeyValueWriter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + + +public class DataSourceProcessor<OT> extends AbstractLogicalIOProcessor { + + private TezTaskConfig config; + protected Map<String, LogicalOutput> outputs; + private List<KeyValueWriter> writers; + private int numOutputs; + private Collector<OT> collector; + + private InputFormat<OT, InputSplit> format; + private TypeSerializerFactory<OT> serializerFactory; + private FlinkInput input; + private ClassLoader userCodeClassLoader = getClass().getClassLoader(); + + + public DataSourceProcessor(ProcessorContext context) { + super(context); + } + + @Override + public void initialize() throws Exception { + UserPayload payload = getContext().getUserPayload(); + Configuration conf = TezUtils.createConfFromUserPayload(payload); + + this.config = (TezTaskConfig) EncodingUtils.decodeObjectFromString(conf.get(TezTaskConfig.TEZ_TASK_CONFIG), getClass().getClassLoader()); + config.setTaskName(getContext().getTaskVertexName()); + + this.serializerFactory = config.getOutputSerializer(this.userCodeClassLoader); + + initInputFormat(); + } + + @Override + public void handleEvents(List<Event> processorEvents) { + int i = 0; + } + + @Override + public void close() throws Exception { + + } + + @Override + public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs) throws Exception { + + Preconditions.checkArgument(inputs.size() == 1); + LogicalInput logicalInput = inputs.values().iterator().next(); + if (!(logicalInput instanceof FlinkInput)) { + throw new RuntimeException("Input to Flink Data Source Processor should be of type FlinkInput"); + } + this.input = (FlinkInput) logicalInput; + //this.reader = (KeyValueReader) input.getReader(); + + // Initialize inputs, get readers and writers + this.outputs = outputs; + this.numOutputs = outputs.size(); + this.writers = new 
ArrayList<KeyValueWriter>(numOutputs); + if (this.outputs != null) { + for (LogicalOutput output : this.outputs.values()) { + output.start(); + writers.add((KeyValueWriter) output.getWriter()); + } + } + this.invoke(); + } + + + private void invoke () { + final TypeSerializer<OT> serializer = this.serializerFactory.getSerializer(); + try { + InputSplit split = input.getSplit(); + + OT record = serializer.createInstance(); + final InputFormat<OT, InputSplit> format = this.format; + format.open(split); + + int numOutputs = outputs.size(); + ArrayList<TezChannelSelector<OT>> channelSelectors = new ArrayList<TezChannelSelector<OT>>(numOutputs); + ArrayList<Integer> numStreamsInOutputs = this.config.getNumberSubtasksInOutput(); + for (int i = 0; i < numOutputs; i++) { + final ShipStrategyType strategy = config.getOutputShipStrategy(i); + final TypeComparatorFactory<OT> compFactory = config.getOutputComparator(i, this.userCodeClassLoader); + final DataDistribution dataDist = config.getOutputDataDistribution(i, this.userCodeClassLoader); + if (compFactory == null) { + channelSelectors.add(i, new TezOutputEmitter<OT>(strategy)); + } else if (dataDist == null){ + final TypeComparator<OT> comparator = compFactory.createComparator(); + channelSelectors.add(i, new TezOutputEmitter<OT>(strategy, comparator)); + } else { + final TypeComparator<OT> comparator = compFactory.createComparator(); + channelSelectors.add(i,new TezOutputEmitter<OT>(strategy, comparator, dataDist)); + } + } + collector = new TezOutputCollector<OT>(writers, channelSelectors, serializerFactory.getSerializer(), numStreamsInOutputs); + + while (!format.reachedEnd()) { + // build next pair and ship pair if it is valid + if ((record = format.nextRecord(record)) != null) { + collector.collect(record); + } + } + format.close(); + + collector.close(); + + } + catch (Exception ex) { + // close the input, but do not report any exceptions, since we already have another root cause + try { + this.format.close(); + } 
catch (Throwable t) {} + } + } + + + private void initInputFormat() { + try { + this.format = config.<InputFormat<OT, InputSplit>>getStubWrapper(this.userCodeClassLoader) + .getUserCodeObject(InputFormat.class, this.userCodeClassLoader); + + // check if the class is a subclass, if the check is required + if (!InputFormat.class.isAssignableFrom(this.format.getClass())) { + throw new RuntimeException("The class '" + this.format.getClass().getName() + "' is not a subclass of '" + + InputFormat.class.getName() + "' as is required."); + } + } + catch (ClassCastException ccex) { + throw new RuntimeException("The stub class is not a proper subclass of " + InputFormat.class.getName(), + ccex); + } + // configure the stub. catch exceptions here extra, to report them as originating from the user code + try { + this.format.configure(this.config.getStubParameters()); + } + catch (Throwable t) { + throw new RuntimeException("The user defined 'configure()' method caused an error: " + t.getMessage(), t); + } + } + + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/RegularProcessor.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/RegularProcessor.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/RegularProcessor.java new file mode 100644 index 0000000..14d9cde --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/RegularProcessor.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.runtime; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.TaskInfo; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.operators.Driver; +import org.apache.flink.api.common.functions.util.RuntimeUDFContext; +import org.apache.flink.tez.util.EncodingUtils; +import org.apache.flink.util.InstantiationUtil; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; +import org.apache.tez.runtime.api.Event; +import org.apache.tez.runtime.api.LogicalInput; +import org.apache.tez.runtime.api.LogicalOutput; +import org.apache.tez.runtime.api.ProcessorContext; +import org.apache.tez.runtime.library.api.KeyValueReader; +import org.apache.tez.runtime.library.api.KeyValueWriter; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; + + +public class RegularProcessor<S extends Function, OT> extends AbstractLogicalIOProcessor { + + private TezTask<S,OT> task; + protected Map<String, LogicalInput> inputs; + protected Map<String, LogicalOutput> outputs; + private List<KeyValueReader> readers; + private List<KeyValueWriter> writers; + private int numInputs; + private int numOutputs; + + + public RegularProcessor(ProcessorContext context) 
{ + super(context); + } + + @Override + public void initialize() throws Exception { + UserPayload payload = getContext().getUserPayload(); + Configuration conf = TezUtils.createConfFromUserPayload(payload); + + TezTaskConfig taskConfig = (TezTaskConfig) EncodingUtils.decodeObjectFromString(conf.get(TezTaskConfig.TEZ_TASK_CONFIG), getClass().getClassLoader()); + taskConfig.setTaskName(getContext().getTaskVertexName()); + + RuntimeUDFContext runtimeUdfContext = new RuntimeUDFContext( + new TaskInfo( + getContext().getTaskVertexName(), + getContext().getTaskIndex(), + getContext().getVertexParallelism(), + getContext().getTaskAttemptNumber() + ), + getClass().getClassLoader(), + new ExecutionConfig(), + new HashMap<String, Future<Path>>(), + new HashMap<String, Accumulator<?, ?>>()); + + this.task = new TezTask<S, OT>(taskConfig, runtimeUdfContext, this.getContext().getTotalMemoryAvailableToTask()); + } + + @Override + public void handleEvents(List<Event> processorEvents) { + + } + + @Override + public void close() throws Exception { + task.getIOManager().shutdown(); + } + + @Override + public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs) throws Exception { + + this.inputs = inputs; + this.outputs = outputs; + final Class<? 
extends Driver<S, OT>> driverClass = this.task.getTaskConfig().getDriver(); + Driver<S,OT> driver = InstantiationUtil.instantiate(driverClass, Driver.class); + this.numInputs = driver.getNumberOfInputs(); + this.numOutputs = outputs.size(); + + + this.readers = new ArrayList<KeyValueReader>(numInputs); + //Ensure size of list is = numInputs + for (int i = 0; i < numInputs; i++) + this.readers.add(null); + HashMap<String, ArrayList<Integer>> inputPositions = ((TezTaskConfig) this.task.getTaskConfig()).getInputPositions(); + if (this.inputs != null) { + for (String name : this.inputs.keySet()) { + LogicalInput input = this.inputs.get(name); + //if (input instanceof AbstractLogicalInput) { + // ((AbstractLogicalInput) input).initialize(); + //} + input.start(); + ArrayList<Integer> positions = inputPositions.get(name); + for (Integer pos : positions) { + //int pos = inputPositions.get(name); + readers.set(pos, (KeyValueReader) input.getReader()); + } + } + } + + this.writers = new ArrayList<KeyValueWriter>(numOutputs); + if (this.outputs != null) { + for (LogicalOutput output : this.outputs.values()) { + output.start(); + writers.add((KeyValueWriter) output.getWriter()); + } + } + + // Do the work + task.invoke (readers, writers); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/TezRuntimeEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/TezRuntimeEnvironment.java b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/TezRuntimeEnvironment.java new file mode 100644 index 0000000..b61a9b6 --- /dev/null +++ b/flink-contrib/flink-tez/src/main/java/org/apache/flink/tez/runtime/TezRuntimeEnvironment.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.tez.runtime; + +import org.apache.flink.core.memory.MemoryType; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; +import org.apache.flink.runtime.memory.MemoryManager; + +public class TezRuntimeEnvironment { + + private final IOManager ioManager; + + private final MemoryManager memoryManager; + + public TezRuntimeEnvironment(long totalMemory) { + this.memoryManager = new MemoryManager(totalMemory, 1, MemoryManager.DEFAULT_PAGE_SIZE, MemoryType.HEAP, true); + this.ioManager = new IOManagerAsync(); + } + + public IOManager getIOManager() { + return ioManager; + } + + public MemoryManager getMemoryManager() { + return memoryManager; + } +}