http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java new file mode 100644 index 0000000..031150b --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.DataConverter; + +public class SparkDataPartitioner implements Serializable { + + private static final long serialVersionUID = 6841548626711057448L; + private DataPartitionSparkScheme _scheme; + + protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) { + switch (scheme) { + case DISJOINT_CONTIGUOUS: + _scheme = new DCSparkScheme(); + // Create the worker id indicator + createDCIndicator(sec, numWorkers, numEntries); + break; + case DISJOINT_ROUND_ROBIN: + _scheme = new DRRSparkScheme(); + // Create the worker id indicator + createDRIndicator(sec, numWorkers, numEntries); + break; + case DISJOINT_RANDOM: + _scheme = new DRSparkScheme(); + // Create the global permutation + createGlobalPermutations(sec, numEntries, 1); + // Create the worker id indicator + createDCIndicator(sec, numWorkers, numEntries); + break; + case OVERLAP_RESHUFFLE: + _scheme = new ORSparkScheme(); + // Create the global permutation seperately for each worker + createGlobalPermutations(sec, numEntries, numWorkers); + break; + } + } + + private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) { + double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray(); + MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); + _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB))); + } + + private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) { + double[] vector = new double[numEntries]; + int batchSize = (int) Math.ceil((double) numEntries / numWorkers); + for (int i = 1; i < numWorkers; i++) { + int begin = batchSize * i; + int end = Math.min(begin + batchSize, numEntries); + Arrays.fill(vector, begin, end, i); + } + MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); + _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB))); + } + + private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) { + List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> { + MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i); + // Create the source-target id vector from the permutation ranging from 1 to number of entries + double[] vector = new double[numEntries]; + for (int j = 0; j < perm.getDenseBlockValues().length; j++) { + vector[(int) perm.getDenseBlockValues()[j] - 1] = j; + } + MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); + return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)); + }).collect(Collectors.toList()); + _scheme.setGlobalPermutation(perms); + } + + public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels, + long rowID) { + // Set the rowID in order to get the according permutation + return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels); + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java new file mode 100644 index 0000000..8b0540b --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.rpc; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.util.ByteBufferDataInput; + +public class PSRpcCall extends PSRpcObject { + + private int _method; + private int _workerID; + private ListObject _data; + + public PSRpcCall(int method, int workerID, ListObject data) { + _method = method; + _workerID = workerID; + _data = data; + } + + public PSRpcCall(ByteBuffer buffer) throws IOException { + deserialize(buffer); + } + + public int getMethod() { + return _method; + } + + public int getWorkerID() { + return _workerID; + } + + public ListObject getData() { + return _data; + } + + public void deserialize(ByteBuffer buffer) throws IOException { + ByteBufferDataInput dis = new ByteBufferDataInput(buffer); + _method = dis.readInt(); + validateMethod(_method); + _workerID = dis.readInt(); + if (dis.available() > 1) + _data = readAndDeserialize(dis); + } + + public ByteBuffer serialize() throws IOException { + int len = 8 + getExactSerializedSize(_data); + CacheDataOutput dos = new CacheDataOutput(len); + dos.writeInt(_method); + dos.writeInt(_workerID); + if (_data != null) + serializeAndWriteListObject(_data, dos); + return ByteBuffer.wrap(dos.getBytes()); + } + + private void validateMethod(int method) { + switch (method) { + case PUSH: + case PULL: + break; + default: + throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'"); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java new file mode 100644 index 0000000..a7db756 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.rpc; + +import java.io.IOException; +import java.util.Collections; + +import org.apache.spark.SparkConf; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.netty.SparkTransportConf; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.util.LongAccumulator; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSProxy; + +public class PSRpcFactory { + + private static final String MODULE_NAME = "ps"; + + private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) { + TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0); + PSRpcHandler handler = new PSRpcHandler(ps); + return new TransportContext(tc, handler); + } + + /** + * Create and start the server + * @return server + */ + public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) { + TransportContext context = createTransportContext(conf, ps); + return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port + } + + public static SparkPSProxy createSparkPSProxy(SparkConf conf, int port, LongAccumulator aRPC) throws IOException { + long rpcTimeout = conf.contains("spark.rpc.askTimeout") ? + conf.getTimeAsMs("spark.rpc.askTimeout") : + conf.getTimeAsMs("spark.network.timeout", "120s"); + String host = conf.get("spark.driver.host"); + TransportContext context = createTransportContext(conf, new LocalParamServer()); + return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java new file mode 100644 index 0000000..cf8de6d --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.rpc; + +import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PULL; +import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PUSH; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.commons.lang.exception.ExceptionUtils; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse.Type; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public final class PSRpcHandler extends RpcHandler { + + private LocalParamServer _server; + + protected PSRpcHandler(LocalParamServer server) { + _server = server; + } + + @Override + public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) { + PSRpcCall call; + try { + call = new PSRpcCall(buffer); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e); + } + PSRpcResponse response = null; + switch (call.getMethod()) { + case PUSH: + try { + _server.push(call.getWorkerID(), call.getData()); + response = new PSRpcResponse(Type.SUCCESS_EMPTY); + } catch (DMLRuntimeException exception) { + response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); + } finally { + try { + callback.onSuccess(response.serialize()); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); + } + } + break; + case PULL: + ListObject data; + try { + data = _server.pull(call.getWorkerID()); + response = new PSRpcResponse(Type.SUCCESS, data); + } catch (DMLRuntimeException exception) { + response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); + } finally { + try { + callback.onSuccess(response.serialize()); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); + } + } + break; + default: + throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod())); + } + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java new file mode 100644 index 0000000..38d80a2 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.rpc; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.io.IOUtilFunctions; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public abstract class PSRpcObject { + + public static final int PUSH = 1; + public static final int PULL = 2; + + public abstract void deserialize(ByteBuffer buffer) throws IOException; + + public abstract ByteBuffer serialize() throws IOException; + + /** + * Deep serialize and write of a list object (currently only support list containing matrices) + * @param lo a list object containing only matrices + * @param output output data to write to + */ + protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException { + validateListObject(lo); + output.writeInt(lo.getLength()); //write list length + output.writeBoolean(lo.isNamedList()); //write list named + for (int i = 0; i < lo.getLength(); i++) { + if (lo.isNamedList()) + output.writeUTF(lo.getName(i)); //write name + ((MatrixObject) lo.getData().get(i)) + .acquireReadAndRelease().write(output); //write matrix + } + // Cleanup the list object + // because it is transferred to remote worker in binary format + ParamservUtils.cleanupListObject(lo); + } + + protected ListObject readAndDeserialize(DataInput input) throws IOException { + int listLen = input.readInt(); + List<Data> data = new ArrayList<>(); + List<String> names = input.readBoolean() ? + new ArrayList<>() : null; + for(int i=0; i<listLen; i++) { + if( names != null ) + names.add(input.readUTF()); + MatrixBlock mb = new MatrixBlock(); + mb.readFields(input); + data.add(ParamservUtils.newMatrixObject(mb, false)); + } + return new ListObject(data, names); + } + + /** + * Get serialization size of a list object + * (scheme: size|name|size|matrix) + * @param lo list object + * @return serialization size + */ + protected int getExactSerializedSize(ListObject lo) { + if( lo == null ) return 0; + long result = 4 + 1; // list length and of named + if (lo.isNamedList()) //size for names incl length + result += lo.getNames().stream().mapToLong(s -> IOUtilFunctions.getUTFSize(s)).sum(); + result += lo.getData().stream().mapToLong(d -> + ((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum(); + if( result > Integer.MAX_VALUE ) + throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE."); + return (int) result; + } + + private void validateListObject(ListObject lo) { + for (Data d : lo.getData()) { + if (!(d instanceof MatrixObject)) { + throw new DMLRuntimeException(String.format("Paramserv func:" + + " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName())); + } + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java new file mode 100644 index 0000000..68e1dd1 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.rpc; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.sysml.runtime.util.ByteBufferDataInput; +import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.io.IOUtilFunctions; + +public class PSRpcResponse extends PSRpcObject { + public enum Type { + SUCCESS, + SUCCESS_EMPTY, + ERROR, + } + + private Type _status; + private Object _data; // Could be list object or exception + + public PSRpcResponse(ByteBuffer buffer) throws IOException { + deserialize(buffer); + } + + public PSRpcResponse(Type status) { + this(status, null); + } + + public PSRpcResponse(Type status, Object data) { + _status = status; + _data = data; + if( _status == Type.SUCCESS && data == null ) + _status = Type.SUCCESS_EMPTY; + } + + public boolean isSuccessful() { + return _status != Type.ERROR; + } + + public String getErrorMessage() { + return (String) _data; + } + + public ListObject getResultModel() { + return (ListObject) _data; + } + + @Override + public void deserialize(ByteBuffer buffer) throws IOException { + ByteBufferDataInput dis = new ByteBufferDataInput(buffer); + _status = Type.values()[dis.readInt()]; + switch (_status) { + case SUCCESS: + _data = readAndDeserialize(dis); + break; + case SUCCESS_EMPTY: + break; + case ERROR: + _data = dis.readUTF(); + break; + } + } + + @Override + public ByteBuffer serialize() throws IOException { + int len = 4 + (_status==Type.SUCCESS ? getExactSerializedSize((ListObject)_data) : + _status==Type.SUCCESS_EMPTY ? 0 : IOUtilFunctions.getUTFSize((String)_data)); + CacheDataOutput dos = new CacheDataOutput(len); + dos.writeInt(_status.ordinal()); + switch (_status) { + case SUCCESS: + serializeAndWriteListObject((ListObject) _data, dos); + break; + case SUCCESS_EMPTY: + break; + case ERROR: + dos.writeUTF(_data.toString()); + break; + } + return ByteBuffer.wrap(dos.getBytes()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java deleted file mode 100644 index 666b891..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.util.List; - -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -/** - * Spark Disjoint_Contiguous data partitioner: - * <p> - * For each row, find out the shifted place according to the workerID indicator - */ -public class DCSparkScheme extends DataPartitionSparkScheme { - - private static final long serialVersionUID = -2786906947020788787L; - - protected DCSparkScheme() { - // No-args constructor used for deserialization - } - - @Override - public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features); - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java deleted file mode 100644 index 7683251..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.util.List; - -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -/** - * Spark Disjoint_Round_Robin data partitioner: - */ -public class DRRSparkScheme extends DataPartitionSparkScheme { - - private static final long serialVersionUID = -3130831851505549672L; - - protected DRRSparkScheme() { - // No-args constructor used for deserialization - } - - @Override - public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features); - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java deleted file mode 100644 index 51cc523..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -/** - * Spark data partitioner Disjoint_Random: - * - * For the current row block, find all the shifted place for each row (WorkerID => (row block ID, matrix) - */ -public class DRSparkScheme extends DataPartitionSparkScheme { - - private static final long serialVersionUID = -7655310624144544544L; - - protected DRSparkScheme() { - // No-args constructor used for deserialization - } - - @Override - public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features); - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels); - return new Result(pfs, pls); - } - - private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) { - MatrixBlock partialPerm = _globalPerms.get(0).getBlock(rblkID, 1); - - // For each row, find out the shifted place - return IntStream.range(0, mb.getNumRows()).mapToObj(r -> { - MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); - long shiftedPosition = (long) partialPerm.getValue(r, 0); - - // Get the shifted block and position - int shiftedBlkID = (int) (shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE + 1); - - MatrixBlock indicator = _workerIndicator.getBlock(shiftedBlkID, 1); - int workerID = (int) indicator.getValue((int) shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE, 0); - return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); - }).collect(Collectors.toList()); - } - -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java deleted file mode 100644 index 9875dd2..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.io.Serializable; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.LongStream; - -import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -public abstract class DataPartitionSparkScheme implements Serializable { - - protected final class Result { - protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures; // WorkerID => (rowID, matrix) - protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels; - - protected Result(List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures, List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels) { - this.pFeatures = pFeatures; - this.pLabels = pLabels; - } - } - - private static final long serialVersionUID = -3462829818083371171L; - - protected List<PartitionedBroadcast<MatrixBlock>> _globalPerms; // a list of global permutations - protected PartitionedBroadcast<MatrixBlock> _workerIndicator; // a matrix indicating to which worker the given row belongs - - protected void setGlobalPermutation(List<PartitionedBroadcast<MatrixBlock>> gps) { - _globalPerms = gps; - } - - protected void setWorkerIndicator(PartitionedBroadcast<MatrixBlock> wi) { - _workerIndicator = wi; - } - - /** - * Do non-reshuffled data partitioning according to worker indicator - * @param rblkID row block ID - * @param mb Matrix - * @return list of tuple (workerID, (row block ID, matrix row)) - */ - protected List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> nonShuffledPartition(int rblkID, MatrixBlock mb) { - MatrixBlock indicator = _workerIndicator.getBlock(rblkID, 1); - return LongStream.range(0, mb.getNumRows()).mapToObj(r -> { - int workerID = (int) indicator.getValue((int) r, 0); - MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); - long shiftedPosition = r + (rblkID - 1) * OptimizerUtils.DEFAULT_BLOCKSIZE; - return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); - }).collect(Collectors.toList()); - } - - public abstract Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels); -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java deleted file mode 100644 index 39b8adf..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.io.Serializable; -import java.util.LinkedList; - -import org.apache.spark.api.java.function.PairFunction; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -public class DataPartitionerSparkAggregator implements PairFunction<Tuple2<Integer,LinkedList<Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>>, Integer, Tuple2<MatrixBlock, MatrixBlock>>, Serializable { - - private static final long serialVersionUID = -1245300852709085117L; - private long _fcol; - private long _lcol; - - public DataPartitionerSparkAggregator() { - - } - - public DataPartitionerSparkAggregator(long fcol, long lcol) { - _fcol = fcol; - _lcol = lcol; - } - - /** - * Row-wise combine the matrix - * @param input workerID => ordered list [(rowBlockID, (features, labels))] - * @return workerID => [(features, labels)] - * @throws Exception Some exception - */ - @Override - public Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<Integer, LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> input) throws Exception { - MatrixBlock fmb = new MatrixBlock(input._2.size(), (int) _fcol, false); - MatrixBlock lmb = new MatrixBlock(input._2.size(), (int) _lcol, false); - - for (int i = 0; i < input._2.size(); i++) { - MatrixBlock tmpFMB = input._2.get(i)._2._1; - MatrixBlock tmpLMB = input._2.get(i)._2._2; - // Row-wise aggregation - fmb = fmb.leftIndexingOperations(tmpFMB, i, i, 0, (int) _fcol - 1, fmb, MatrixObject.UpdateType.INPLACE_PINNED); - lmb = lmb.leftIndexingOperations(tmpLMB, i, i, 0, (int) _lcol - 1, lmb, MatrixObject.UpdateType.INPLACE_PINNED); - } - return new Tuple2<>(input._1, new Tuple2<>(fmb, lmb)); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java deleted file mode 100644 index 2a69986..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.io.Serializable; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; - -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -public class DataPartitionerSparkMapper implements PairFlatMapFunction<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>, Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable { - - private static final long serialVersionUID = 1710721606050403296L; - private int _workersNum; - - private SparkDataPartitioner _dp; - - protected DataPartitionerSparkMapper() { - // No-args constructor used for deserialization - } - - public DataPartitionerSparkMapper(Statement.PSScheme scheme, int workersNum, SparkExecutionContext sec, int numEntries) { - _workersNum = workersNum; - _dp = new SparkDataPartitioner(scheme, sec, numEntries, workersNum); - } - - /** - * Do data partitioning - * @param input RowBlockID => (features, labels) - * @return WorkerID => (rowBlockID, (single row features, single row labels)) - * @throws Exception Some exception - */ - @Override - public Iterator<Tuple2<Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> call(Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>> input) - throws Exception { - List<Tuple2<Integer, Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>> partitions = new LinkedList<>(); - MatrixBlock features = input._2._1; - MatrixBlock labels = input._2._2; - DataPartitionSparkScheme.Result result = _dp.doPartitioning(_workersNum, features, labels, input._1); - for (int i = 0; i < result.pFeatures.size(); i++) { - Tuple2<Integer, Tuple2<Long, MatrixBlock>> ft = result.pFeatures.get(i); - Tuple2<Integer, Tuple2<Long, MatrixBlock>> lt = result.pLabels.get(i); - partitions.add(new Tuple2<>(ft._1, new Tuple2<>(ft._2._1, new Tuple2<>(ft._2._2, lt._2._2)))); - } - return partitions.iterator(); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java deleted file mode 100644 index 16ce516..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -import scala.Tuple2; - -/** - * Spark data partitioner Overlap_Reshuffle: - * - */ -public class ORSparkScheme extends DataPartitionSparkScheme { - - private static final long serialVersionUID = 6867567406403580311L; - - protected ORSparkScheme() { - // No-args constructor used for deserialization - } - - @Override - public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(numWorkers, rblkID, features); - List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(numWorkers, rblkID, labels); - return new Result(pfs, pls); - } - - private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) { - return IntStream.range(0, numWorkers).mapToObj(i -> i).flatMap(workerID -> { - MatrixBlock partialPerm = _globalPerms.get(workerID).getBlock(rblkID, 1); - return IntStream.range(0, mb.getNumRows()).mapToObj(r -> { - MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); - long shiftedPosition = (long) partialPerm.getValue(r, 0); - return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); - }); - }).collect(Collectors.toList()); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java deleted file mode 100644 index 6883d0f..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.util.DataConverter; - -public class SparkDataPartitioner implements Serializable { - - private static final long serialVersionUID = 6841548626711057448L; - private DataPartitionSparkScheme _scheme; - - protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) { - switch (scheme) { - case DISJOINT_CONTIGUOUS: - _scheme = new DCSparkScheme(); - // Create the worker id indicator - createDCIndicator(sec, numWorkers, numEntries); - break; - case DISJOINT_ROUND_ROBIN: - _scheme = new DRRSparkScheme(); - // Create the worker id indicator - createDRIndicator(sec, numWorkers, numEntries); - break; - case DISJOINT_RANDOM: - _scheme = new DRSparkScheme(); - // Create the global permutation - createGlobalPermutations(sec, numEntries, 1); - // Create the worker id indicator - createDCIndicator(sec, numWorkers, numEntries); - break; - case OVERLAP_RESHUFFLE: - _scheme = new ORSparkScheme(); - // Create the global permutation seperately for each worker - createGlobalPermutations(sec, numEntries, numWorkers); - break; - } - } - - private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) { - double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray(); - MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); - _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB))); - } - - private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) { - double[] vector = new double[numEntries]; - int batchSize = (int) Math.ceil((double) numEntries / numWorkers); - for (int i = 1; i < numWorkers; i++) { - int begin = batchSize * i; - int end = Math.min(begin + batchSize, numEntries); - Arrays.fill(vector, begin, end, i); - } - MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); - _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB))); - } - - private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) { - List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> { - MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i); - // Create the source-target id vector from the permutation ranging from 1 to number of entries - double[] vector = new double[numEntries]; - for (int j = 0; j < perm.getDenseBlockValues().length; j++) { - vector[(int) perm.getDenseBlockValues()[j] - 1] = j; - } - MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true); - return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)); - }).collect(Collectors.toList()); - _scheme.setGlobalPermutation(perms); - } - - public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels, - long rowID) { - // Set the rowID in order to get the according permutation - return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java deleted file mode 100644 index 9354025..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; - -/** - * Wrapper class containing all needed for launching spark remote worker - */ -public class SparkPSBody { - - private ExecutionContext _ec; - - public SparkPSBody() {} - - public SparkPSBody(ExecutionContext ec) { - _ec = ec; - } - - public ExecutionContext getEc() { - return _ec; - } - - public void setEc(ExecutionContext ec) { - this._ec = ec; - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java deleted file mode 100644 index 48a4883..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL; -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH; - -import java.io.IOException; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.util.LongAccumulator; -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse; -import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; -import org.apache.sysml.runtime.instructions.cp.ListObject; - -public class SparkPSProxy extends ParamServer { - - private final TransportClient _client; - private final long _rpcTimeout; - private final LongAccumulator _aRPC; - - public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) { - super(); - _client = client; - _rpcTimeout = rpcTimeout; - _aRPC = aRPC; - } - - private void accRpcRequestTime(Timing tRpc) { - if (DMLScript.STATISTICS) - _aRPC.add((long) tRpc.stop()); - } - - @Override - public void push(int workerID, ListObject value) { - Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; - PSRpcResponse response; - try { - response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout)); - } catch (IOException e) { - throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e); - } - accRpcRequestTime(tRpc); - if (!response.isSuccessful()) { - throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage())); - } - } - - @Override - public ListObject pull(int workerID) { - Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; - PSRpcResponse response; - try { - response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout)); - } catch (IOException e) { - throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e); - } - accRpcRequestTime(tRpc); - if (!response.isSuccessful()) { - throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage())); - } - return response.getResultModel(); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java deleted file mode 100644 index cb3e729..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.VoidFunction; -import org.apache.spark.util.LongAccumulator; -import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.codegen.CodegenUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory; -import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils; -import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.util.ProgramConverter; - -import scala.Tuple2; - -public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> { - - private static final long serialVersionUID = -8674739573419648732L; - - private final String _program; - private final HashMap<String, byte[]> _clsMap; - private final SparkConf _conf; - private final int _port; // rpc port - private final String _aggFunc; - private final LongAccumulator _aSetup; // accumulator for setup time - private final LongAccumulator _aWorker; // accumulator for worker number - private final LongAccumulator _aUpdate; // accumulator for model update - private final LongAccumulator _aIndex; // accumulator for batch indexing - private final LongAccumulator _aGrad; // accumulator for gradients computing - private final LongAccumulator _aRPC; // accumulator for rpc request - private final LongAccumulator _nBatches; //number of executed batches - private final LongAccumulator _nEpochs; //number of executed epoches - - public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) { - _updFunc = updFunc; - _aggFunc = aggFunc; - _freq = freq; - _epochs = epochs; - _batchSize = batchSize; - _program = program; - _clsMap = clsMap; - _conf = conf; - _port = port; - _aSetup = aSetup; - _aWorker = aWorker; - _aUpdate = aUpdate; - _aIndex = aIndex; - _aGrad = aGrad; - _aRPC = aRPC; - _nBatches = aBatches; - _nEpochs = aEpochs; - } - - @Override - public String getWorkerName() { - return String.format("Spark worker_%d", _workerID); - } - - @Override - public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { - Timing tSetup = new Timing(true); - configureWorker(input); - accSetupTime(tSetup); - - call(); // Launch the worker - } - - private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException { - _workerID = input._1; - - // Initialize codegen class cache (before program parsing) - for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) { - CodegenUtils.getClassSync(e.getKey(), e.getValue()); - } - - // Deserialize the body to initialize the execution context - SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID); - _ec = body.getEc(); - - // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end - RemoteParForUtils.setupBufferPool(_workerID); - - // Get some configurations - long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ? - _conf.getTimeAsMs("spark.rpc.askTimeout") : - _conf.getTimeAsMs("spark.network.timeout", "120s"); - String host = _conf.get("spark.driver.host"); - - // Create the ps proxy - _ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC); - - // Initialize the update function - setupUpdateFunction(_updFunc, _ec); - - // Initialize the agg function - _ps.setupAggFunc(_ec, _aggFunc); - - // Lazy initialize the matrix of features and labels - setFeatures(ParamservUtils.newMatrixObject(input._2._1)); - setLabels(ParamservUtils.newMatrixObject(input._2._2)); - _features.enableCleanup(false); - _labels.enableCleanup(false); - } - - - @Override - protected void incWorkerNumber() { - _aWorker.add(1); - } - - @Override - protected void accLocalModelUpdateTime(Timing time) { - if( time != null ) - _aUpdate.add((long) time.stop()); - } - - @Override - protected void accBatchIndexingTime(Timing time) { - if( time != null ) - _aIndex.add((long) time.stop()); - } - - @Override - protected void accGradientComputeTime(Timing time) { - if( time != null ) - _aGrad.add((long) time.stop()); - } - - @Override - protected void accNumEpochs(int n) { - _nEpochs.add(n); - } - - @Override - protected void accNumBatches(int n) { - _nBatches.add(n); - } - - private void accSetupTime(Timing time) { - if( time != null ) - _aSetup.add((long) time.stop()); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java deleted file mode 100644 index a33fda2..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput; -import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.util.ByteBufferDataInput; - -public class PSRpcCall extends PSRpcObject { - - private int _method; - private int _workerID; - private ListObject _data; - - public PSRpcCall(int method, int workerID, ListObject data) { - _method = method; - _workerID = workerID; - _data = data; - } - - public PSRpcCall(ByteBuffer buffer) throws IOException { - deserialize(buffer); - } - - public int getMethod() { - return _method; - } - - public int getWorkerID() { - return _workerID; - } - - public ListObject getData() { - return _data; - } - - public void deserialize(ByteBuffer buffer) throws IOException { - ByteBufferDataInput dis = new ByteBufferDataInput(buffer); - _method = dis.readInt(); - validateMethod(_method); - _workerID = dis.readInt(); - if (dis.available() > 1) - _data = readAndDeserialize(dis); - } - - public ByteBuffer serialize() throws IOException { - int len = 8 + getExactSerializedSize(_data); - CacheDataOutput dos = new CacheDataOutput(len); - dos.writeInt(_method); - dos.writeInt(_workerID); - if (_data != null) - serializeAndWriteListObject(_data, dos); - return ByteBuffer.wrap(dos.getBytes()); - } - - private void validateMethod(int method) { - switch (method) { - case PUSH: - case PULL: - break; - default: - throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'"); - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java deleted file mode 100644 index 5e76d23..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; - -import java.io.IOException; -import java.util.Collections; - -import org.apache.spark.SparkConf; -import org.apache.spark.network.TransportContext; -import org.apache.spark.network.netty.SparkTransportConf; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.TransportConf; -import org.apache.spark.util.LongAccumulator; -import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy; - -public class PSRpcFactory { - - private static final String MODULE_NAME = "ps"; - - private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) { - TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0); - PSRpcHandler handler = new PSRpcHandler(ps); - return new TransportContext(tc, handler); - } - - /** - * Create and start the server - * @return server - */ - public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) { - TransportContext context = createTransportContext(conf, ps); - return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port - } - - public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException { - TransportContext context = createTransportContext(conf, new LocalParamServer()); - return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java deleted file mode 100644 index a2c311e..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; - -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL; -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.apache.commons.lang.exception.ExceptionUtils; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type; -import org.apache.sysml.runtime.instructions.cp.ListObject; - -public final class PSRpcHandler extends RpcHandler { - - private LocalParamServer _server; - - protected PSRpcHandler(LocalParamServer server) { - _server = server; - } - - @Override - public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) { - PSRpcCall call; - try { - call = new PSRpcCall(buffer); - } catch (IOException e) { - throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e); - } - PSRpcResponse response = null; - switch (call.getMethod()) { - case PUSH: - try { - _server.push(call.getWorkerID(), call.getData()); - response = new PSRpcResponse(Type.SUCCESS_EMPTY); - } catch (DMLRuntimeException exception) { - response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); - } finally { - try { - callback.onSuccess(response.serialize()); - } catch (IOException e) { - throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); - } - } - break; - case PULL: - ListObject data; - try { - data = _server.pull(call.getWorkerID()); - response = new PSRpcResponse(Type.SUCCESS, data); - } catch (DMLRuntimeException exception) { - response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); - } finally { - try { - callback.onSuccess(response.serialize()); - } catch (IOException e) { - throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); - } - } - break; - default: - throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod())); - } - } - - @Override - public StreamManager getStreamManager() { - return new OneForOneStreamManager(); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java deleted file mode 100644 index 816cefd..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; - -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.instructions.cp.Data; -import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.io.IOUtilFunctions; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -public abstract class PSRpcObject { - - public static final int PUSH = 1; - public static final int PULL = 2; - - public abstract void deserialize(ByteBuffer buffer) throws IOException; - - public abstract ByteBuffer serialize() throws IOException; - - /** - * Deep serialize and write of a list object (currently only support list containing matrices) - * @param lo a list object containing only matrices - * @param output output data to write to - */ - protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException { - validateListObject(lo); - output.writeInt(lo.getLength()); //write list length - output.writeBoolean(lo.isNamedList()); //write list named - for (int i = 0; i < lo.getLength(); i++) { - if (lo.isNamedList()) - output.writeUTF(lo.getName(i)); //write name - ((MatrixObject) lo.getData().get(i)) - .acquireReadAndRelease().write(output); //write matrix - } - // Cleanup the list object - // because it is transferred to remote worker in binary format - ParamservUtils.cleanupListObject(lo); - } - - protected ListObject readAndDeserialize(DataInput input) throws IOException { - int listLen = input.readInt(); - List<Data> data = new ArrayList<>(); - List<String> names = input.readBoolean() ? - new ArrayList<>() : null; - for(int i=0; i<listLen; i++) { - if( names != null ) - names.add(input.readUTF()); - MatrixBlock mb = new MatrixBlock(); - mb.readFields(input); - data.add(ParamservUtils.newMatrixObject(mb, false)); - } - return new ListObject(data, names); - } - - /** - * Get serialization size of a list object - * (scheme: size|name|size|matrix) - * @param lo list object - * @return serialization size - */ - protected int getExactSerializedSize(ListObject lo) { - if( lo == null ) return 0; - long result = 4 + 1; // list length and of named - if (lo.isNamedList()) //size for names incl length - result += lo.getNames().stream().mapToLong(s -> IOUtilFunctions.getUTFSize(s)).sum(); - result += lo.getData().stream().mapToLong(d -> - ((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum(); - if( result > Integer.MAX_VALUE ) - throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE."); - return (int) result; - } - - private void validateListObject(ListObject lo) { - for (Data d : lo.getData()) { - if (!(d instanceof MatrixObject)) { - throw new DMLRuntimeException(String.format("Paramserv func:" - + " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName())); - } - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java deleted file mode 100644 index 010481e..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import org.apache.sysml.runtime.util.ByteBufferDataInput; -import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput; -import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.io.IOUtilFunctions; - -public class PSRpcResponse extends PSRpcObject { - public enum Type { - SUCCESS, - SUCCESS_EMPTY, - ERROR, - } - - private Type _status; - private Object _data; // Could be list object or exception - - public PSRpcResponse(ByteBuffer buffer) throws IOException { - deserialize(buffer); - } - - public PSRpcResponse(Type status) { - this(status, null); - } - - public PSRpcResponse(Type status, Object data) { - _status = status; - _data = data; - if( _status == Type.SUCCESS && data == null ) - _status = Type.SUCCESS_EMPTY; - } - - public boolean isSuccessful() { - return _status != Type.ERROR; - } - - public String getErrorMessage() { - return (String) _data; - } - - public ListObject getResultModel() { - return (ListObject) _data; - } - - @Override - public void deserialize(ByteBuffer buffer) throws IOException { - ByteBufferDataInput dis = new ByteBufferDataInput(buffer); - _status = Type.values()[dis.readInt()]; - switch (_status) { - case SUCCESS: - _data = readAndDeserialize(dis); - break; - case SUCCESS_EMPTY: - break; - case ERROR: - _data = dis.readUTF(); - break; - } - } - - @Override - public ByteBuffer serialize() throws IOException { - int len = 4 + (_status==Type.SUCCESS ? getExactSerializedSize((ListObject)_data) : - _status==Type.SUCCESS_EMPTY ? 0 : IOUtilFunctions.getUTFSize((String)_data)); - CacheDataOutput dos = new CacheDataOutput(len); - dos.writeInt(_status.ordinal()); - switch (_status) { - case SUCCESS: - serializeAndWriteListObject((ListObject) _data, dos); - break; - case SUCCESS_EMPTY: - break; - case ERROR: - dos.writeUTF(_data.toString()); - break; - } - return ByteBuffer.wrap(dos.getBytes()); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 6220bb6..83ec3f7 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -65,15 +65,15 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme; -import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory; +import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody; +import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.matrix.operators.Operator; @@ -350,7 +350,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc switch (mode) { case LOCAL: case REMOTE_SPARK: - return new LocalParamServer(model, aggFunc, updateType, ec, workerNum); + return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum); default: throw new DMLRuntimeException("Unsupported parameter server: "+mode.name()); } @@ -379,9 +379,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private void partitionLocally(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) { MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS)); - DataPartitionScheme.Result result = new DataPartitioner(scheme).doPartitioning(workers.size(), features.acquireRead(), labels.acquireRead()); - features.release(); - labels.release(); + DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workers.size(), features.acquireReadAndRelease(), labels.acquireReadAndRelease()); List<MatrixObject> pfs = result.pFeatures; List<MatrixObject> pls = result.pLabels; if (pfs.size() < workers.size()) { http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java index fc9d9b4..21e6bd3 100644 --- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java +++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java @@ -69,7 +69,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody; +import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody; import org.apache.sysml.runtime.controlprogram.parfor.ParForBody; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.instructions.CPInstructionParser; http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java index 0092aed..2f39c91 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java @@ -22,8 +22,8 @@ package org.apache.sysml.test.integration.functions.paramserv; import java.util.stream.IntStream; import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme; -import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.DataConverter; @@ -54,26 +54,26 @@ public abstract class BaseDataPartitionerTest { return IntStream.range(from, to).mapToDouble(i -> (double) i).toArray(); } - protected DataPartitionScheme.Result launchLocalDataPartitionerDC() { - DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS); + protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDC() { + LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS); MatrixBlock[] mbs = generateData(); return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]); } - protected DataPartitionScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) { + protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) { ParamservUtils.SEED = System.nanoTime(); - DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_RANDOM); + LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_RANDOM); return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]); } - protected DataPartitionScheme.Result launchLocalDataPartitionerDRR() { - DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN); + protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDRR() { + LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN); MatrixBlock[] mbs = generateData(); return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]); } - protected DataPartitionScheme.Result launchLocalDataPartitionerOR() { - DataPartitioner dp = new DataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE); + protected DataPartitionLocalScheme.Result launchLocalDataPartitionerOR() { + LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE); MatrixBlock[] mbs = generateData(); return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]); }
