This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit 2aaf07a41155f35ab36bda4c3c02a7ffa7bd66db Author: Etienne Chauchot <[email protected]> AuthorDate: Thu Aug 29 15:10:40 2019 +0200 Conform to spark ExpressionEncoders: pass classTags, implement scala Product, pass children from within the ExpressionEncoder, fix visibilities --- .../translation/helpers/EncoderHelpers.java | 64 +++++++++++++--------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index 8a4f1de..0765c78 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -100,13 +100,13 @@ public class EncoderHelpers { List<Expression> serialiserList = new ArrayList<>(); Class<T> claz = (Class<T>) Object.class; - serialiserList.add(new EncodeUsingBeamCoder<>(claz, coder)); + serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz), true), coder)); ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz); return new ExpressionEncoder<>( SchemaHelpers.binarySchema(), false, JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(), - new DecodeUsingBeamCoder<>(claz, coder), + new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, coder), classTag); /* @@ -126,16 +126,14 @@ public class EncoderHelpers { */ } - private static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression { + public static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression { - private Class<T> claz; - private Coder<T> beamCoder; private Expression child; + private Coder<T> beamCoder; - private EncodeUsingBeamCoder( Class<T> claz, Coder<T> beamCoder) { - this.claz = claz; + public EncodeUsingBeamCoder(Expression child, Coder<T> beamCoder) { + this.child = child; this.beamCoder = beamCoder; - this.child = new BoundReference(0, new ObjectType(claz), true); } @Override public Expression child() { @@ -175,11 +173,18 @@ public class EncoderHelpers { } @Override public Object productElement(int n) { - return null; + switch (n) { + case 0: + return child; + case 1: + return beamCoder; + default: + throw new ArrayIndexOutOfBoundsException("productElement out of bounds"); + } } @Override public int productArity() { - return 0; + return 2; } @Override public boolean canEqual(Object that) { @@ -194,11 +199,11 @@ public class EncoderHelpers { return false; } EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o; - return claz.equals(that.claz) && beamCoder.equals(that.beamCoder); + return beamCoder.equals(that.beamCoder); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), claz, beamCoder); + return Objects.hash(super.hashCode(), beamCoder); } } @@ -226,16 +231,16 @@ public class EncoderHelpers { override def dataType: DataType = BinaryType }*/ - private static class DecodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression{ + public static class DecodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression{ - private Class<T> claz; - private Coder<T> beamCoder; private Expression child; + private ClassTag<T> classTag; + private Coder<T> beamCoder; - private DecodeUsingBeamCoder(Class<T> claz, Coder<T> beamCoder) { - this.claz = claz; + public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Coder<T> beamCoder) { + this.child = child; + this.classTag = classTag; this.beamCoder = beamCoder; - this.child = new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType); } @Override public Expression child() { @@ -267,7 +272,7 @@ public class EncoderHelpers { args.add(new VariableValue("deserialize", String.class)); Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); - return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", claz)); + return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", classTag.runtimeClass())); } @@ -280,17 +285,24 @@ public class EncoderHelpers { } @Override public DataType dataType() { -// return new ObjectType(classTag.runtimeClass()); - //TODO does type erasure impose to use classTag.runtimeClass() ? - return new ObjectType(claz); + return new ObjectType(classTag.runtimeClass()); } @Override public Object productElement(int n) { - return null; + switch (n) { + case 0: + return child; + case 1: + return classTag; + case 2: + return beamCoder; + default: + throw new ArrayIndexOutOfBoundsException("productElement out of bounds"); + } } @Override public int productArity() { - return 0; + return 3; } @Override public boolean canEqual(Object that) { @@ -305,11 +317,11 @@ public class EncoderHelpers { return false; } DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o; - return claz.equals(that.claz) && beamCoder.equals(that.beamCoder); + return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), claz, beamCoder); + return Objects.hash(super.hashCode(), classTag, beamCoder); } } /*
