refine transpose_and_dot
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/abbf5492 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/abbf5492 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/abbf5492 Branch: refs/heads/JIRA-22/pr-385 Commit: abbf5492b95dd69e347580c59ac044a78627c547 Parents: a16a3fd Author: amaya <g...@sapphire.in.net> Authored: Wed Sep 21 13:11:00 2016 +0900 Committer: amaya <g...@sapphire.in.net> Committed: Wed Sep 21 13:40:54 2016 +0900 ---------------------------------------------------------------------- .../tools/matrix/TransposeAndDotUDAF.java | 32 +++++++++++--------- 1 file changed, 18 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/abbf5492/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java index 1e54004..9d68f93 100644 --- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java +++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java @@ -127,33 +127,37 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver { @Override public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException { - TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer(); + final TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer(); reset(myAgg); return myAgg; } @Override public void reset(AggregationBuffer agg) throws HiveException { - TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; myAgg.reset(); } @Override public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { - TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + final Object matrix0RowObj = parameters[0]; + final Object matrix1RowObj = parameters[1]; + Preconditions.checkNotNull(matrix0RowObj); + Preconditions.checkNotNull(matrix1RowObj); + + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + + // init if (matrix0Row == null) { - matrix0Row = new double[matrix0RowOI.getListLength(parameters[0])]; + matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)]; } if (matrix1Row == null) { - matrix1Row = new double[matrix1RowOI.getListLength(parameters[1])]; + matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)]; } - HiveUtils.toDoubleArray(parameters[0], matrix0RowOI, matrix0ElOI, matrix0Row, false); - HiveUtils.toDoubleArray(parameters[1], matrix1RowOI, matrix1ElOI, matrix1Row, false); - - Preconditions.checkNotNull(matrix0Row); - Preconditions.checkNotNull(matrix1Row); + HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false); + HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false); if (myAgg.aggMatrix == null) { myAgg.init(matrix0Row.length, matrix1Row.length); @@ -172,9 +176,9 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver { return; } - TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; - List matrix = aggMatrixOI.getList(other); + final List matrix = aggMatrixOI.getList(other); final int n = matrix.size(); final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))]; for (int i = 0; i < n; i++) { @@ -197,9 +201,9 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver { @Override public Object terminate(AggregationBuffer agg) throws HiveException { - TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; - List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>(); + final List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>(); for (double[] row : myAgg.aggMatrix) { result.add(WritableUtils.toWritableList(row)); }