This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit e6b68a8f21aba2adcb7543eae806d71e08c0bff3 Author: Etienne Chauchot <[email protected]> AuthorDate: Mon Sep 2 17:55:24 2019 +0200 Lazy init coder because coder instance cannot be interpolated by catalyst --- runners/spark/build.gradle | 1 + .../translation/helpers/EncoderHelpers.java | 63 +++++++++++++++------- .../structuredstreaming/utils/EncodersTest.java | 3 +- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle index 73a710b..a948ef1 100644 --- a/runners/spark/build.gradle +++ b/runners/spark/build.gradle @@ -77,6 +77,7 @@ dependencies { provided "com.esotericsoftware.kryo:kryo:2.21" runtimeOnly library.java.jackson_module_scala runtimeOnly "org.scala-lang:scala-library:2.11.8" + compile "org.scala-lang.modules:scala-java8-compat_2.11:0.9.0" testCompile project(":sdks:java:io:kafka") testCompile project(path: ":sdks:java:core", configuration: "shadowTest") // SparkStateInternalsTest extends abstract StateInternalsTest diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index cc862cd..694bc24 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -18,9 +18,9 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static scala.compat.java8.JFunction.func; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.List; @@ -45,7 +45,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode; import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.ObjectType; -import scala.Function1; import scala.StringContext; import scala.Tuple2; import scala.collection.JavaConversions; @@ -94,17 +93,17 @@ public class EncoderHelpers { */ /** A way to construct encoders using generic serializers. */ - public static <T> Encoder<T> fromBeamCoder(Coder<T> coder/*, Class<T> claz*/){ + public static <T> Encoder<T> fromBeamCoder(Class<? extends Coder<T>> coderClass/*, Class<T> claz*/){ List<Expression> serialiserList = new ArrayList<>(); Class<T> claz = (Class<T>) Object.class; - serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz), true), coder)); + serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz), true), (Class<Coder<T>>)coderClass)); ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz); return new ExpressionEncoder<>( SchemaHelpers.binarySchema(), false, JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(), - new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, coder), + new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, (Class<Coder<T>>)coderClass), classTag); /* @@ -127,11 +126,11 @@ public class EncoderHelpers { public static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression { private Expression child; - private Coder<T> beamCoder; + private Class<Coder<T>> coderClass; - public EncodeUsingBeamCoder(Expression child, Coder<T> beamCoder) { + public EncodeUsingBeamCoder(Expression child, Class<Coder<T>> coderClass) { this.child = child; - this.beamCoder = beamCoder; + this.coderClass = coderClass; } @Override public Expression child() { @@ -140,6 +139,7 @@ public class EncoderHelpers { @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { // Code to serialize. + String beamCoder = lazyInitBeamCoder(ctx, coderClass); ExprCode input = child.genCode(ctx); /* @@ -170,6 +170,7 @@ public class EncoderHelpers { new VariableValue("output", Array.class)); } + @Override public DataType dataType() { return BinaryType; } @@ -179,7 +180,7 @@ public class EncoderHelpers { case 0: return child; case 1: - return beamCoder; + return coderClass; default: throw new ArrayIndexOutOfBoundsException("productElement out of bounds"); } @@ -201,11 +202,11 @@ public class EncoderHelpers { return false; } EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o; - return beamCoder.equals(that.beamCoder); + return coderClass.equals(that.coderClass); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), beamCoder); + return Objects.hash(super.hashCode(), coderClass); } } @@ -237,12 +238,12 @@ public class EncoderHelpers { private Expression child; private ClassTag<T> classTag; - private Coder<T> beamCoder; + private Class<Coder<T>> coderClass; - public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Coder<T> beamCoder) { + public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Class<Coder<T>> coderClass) { this.child = child; this.classTag = classTag; - this.beamCoder = beamCoder; + this.coderClass = coderClass; } @Override public Expression child() { @@ -251,6 +252,7 @@ public class EncoderHelpers { @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { // Code to deserialize. + String beamCoder = lazyInitBeamCoder(ctx, coderClass); ExprCode input = child.genCode(ctx); String javaType = CodeGenerator.javaType(dataType()); @@ -291,9 +293,10 @@ public class EncoderHelpers { @Override public Object nullSafeEval(Object input) { try { + Coder<T> beamCoder = coderClass.newInstance(); return beamCoder.decode(new ByteArrayInputStream((byte[]) input)); - } catch (IOException e) { - throw new IllegalStateException("Error decoding bytes for coder: " + beamCoder, e); + } catch (Exception e) { + throw new IllegalStateException("Error decoding bytes for coder: " + coderClass, e); } } @@ -308,7 +311,7 @@ public class EncoderHelpers { case 1: return classTag; case 2: - return beamCoder; + return coderClass; default: throw new ArrayIndexOutOfBoundsException("productElement out of bounds"); } @@ -330,11 +333,11 @@ public class EncoderHelpers { return false; } DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o; - return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder); + return classTag.equals(that.classTag) && coderClass.equals(that.coderClass); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), classTag, beamCoder); + return Objects.hash(super.hashCode(), classTag, coderClass); } } /* @@ -365,4 +368,26 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B } */ + private static <T> String lazyInitBeamCoder(CodegenContext ctx, Class<Coder<T>> coderClass) { + String beamCoderInstance = "beamCoder"; + ctx.addImmutableStateIfNotExists(coderClass.getName(), beamCoderInstance, func(v1 -> { + /* + CODE GENERATED + v = (coderClass) coderClass.newInstance(); + */ + List<String> parts = new ArrayList<>(); + parts.add(""); + parts.add(" = ("); + parts.add(") "); + parts.add(".newInstance();"); + StringContext sc = new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq()); + List<Object> args = new ArrayList<>(); + args.add(v1); + args.add(coderClass.getName()); + args.add(coderClass.getName()); + return sc.s(JavaConversions.collectionAsScalaIterable(args).toSeq()); + })); + return beamCoderInstance; + } + } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java index 7078b0c..0e38fe1 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java @@ -3,6 +3,7 @@ package org.apache.beam.runners.spark.structuredstreaming.utils; import java.util.ArrayList; import java.util.List; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.spark.sql.SparkSession; import org.junit.Test; @@ -23,7 +24,7 @@ public class EncodersTest { data.add(1); data.add(2); data.add(3); - sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of())); + sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.class)); // sparkSession.createDataset(data, EncoderHelpers.genericEncoder()); } }
