This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit d5645ff60aa99608a9ee3b8a5be6c58f9ac3903b Author: Etienne Chauchot <[email protected]> AuthorDate: Mon Sep 2 15:45:24 2019 +0200 Fix code generation in Beam coder wrapper --- .../translation/helpers/EncoderHelpers.java | 93 ++++++++++++---------- .../structuredstreaming/utils/EncodersTest.java | 4 +- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index 0765c78..cc862cd 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -42,15 +42,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block; import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator; import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext; import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode; -import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue; -import org.apache.spark.sql.catalyst.expressions.codegen.SimpleExprValue; import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.ObjectType; +import scala.Function1; import scala.StringContext; import scala.Tuple2; import scala.collection.JavaConversions; -import scala.collection.Seq; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -143,29 +141,33 @@ public class EncoderHelpers { @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { // Code to serialize. ExprCode input = child.genCode(ctx); - String javaType = CodeGenerator.javaType(dataType()); - String outputStream = "ByteArrayOutputStream baos = new ByteArrayOutputStream();"; - - String serialize = outputStream + "$beamCoder.encode(${input.value}, baos); baos.toByteArray();"; - - String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;"; - List<String> instructions = new ArrayList<>(); - instructions.add(outside); - Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq(); + /* + CODE GENERATED + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final bytes[] output; + if ({input.isNull}) + output = null; + else + output = $beamCoder.encode(${input.value}, baos); baos.toByteArray(); + */ + List<String> parts = new ArrayList<>(); + parts.add("ByteArrayOutputStream baos = new ByteArrayOutputStream(); final bytes[] output; if ("); + parts.add(") output = null; else output ="); + parts.add(".encode("); + parts.add(", baos); baos.toByteArray();"); + + StringContext sc = new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq()); - StringContext stringContext = new StringContext(parts); - Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext); List<Object> args = new ArrayList<>(); - args.add(new VariableValue("beamCoder", Coder.class)); - args.add(new SimpleExprValue("input.value", ExprValue.class)); - args.add(new VariableValue("javaType", String.class)); - args.add(new SimpleExprValue("input.isNull", Boolean.class)); - args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class)); - args.add(new VariableValue("serialize", String.class)); - Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); - - return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", Array.class)); + args.add(input.isNull()); + args.add(beamCoder); + args.add(input.value()); + Block code = (new Block.BlockHelper(sc)) + .code(JavaConversions.collectionAsScalaIterable(args).toSeq()); + + return ev.copy(input.code().$plus(code), input.isNull(), + new VariableValue("output", Array.class)); } @Override public DataType dataType() { @@ -252,27 +254,38 @@ public class EncoderHelpers { ExprCode input = child.genCode(ctx); String javaType = CodeGenerator.javaType(dataType()); - String inputStream = "ByteArrayInputStream bais = new ByteArrayInputStream(${input.value});"; - String deserialize = inputStream + "($javaType) $beamCoder.decode(bais);"; +/* + CODE GENERATED: + final $javaType output = + ${input.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : + ($javaType) $beamCoder.decode(new ByteArrayInputStream(${input.value})); +*/ - String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;"; + List<String> parts = new ArrayList<>(); + parts.add("final "); + parts.add(" output ="); + parts.add("?"); + parts.add(":"); + parts.add("("); + parts.add(") "); + parts.add(".decode(new ByteArrayInputStream("); + parts.add("));"); - List<String> instructions = new ArrayList<>(); - instructions.add(outside); - Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq(); + StringContext sc = new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq()); - StringContext stringContext = new StringContext(parts); - Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext); List<Object> args = new ArrayList<>(); - args.add(new SimpleExprValue("input.value", ExprValue.class)); - args.add(new VariableValue("javaType", String.class)); - args.add(new VariableValue("beamCoder", Coder.class)); - args.add(new SimpleExprValue("input.isNull", Boolean.class)); - args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class)); - args.add(new VariableValue("deserialize", String.class)); - Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); - - return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", classTag.runtimeClass())); + args.add(javaType); + args.add(input.isNull()); + args.add(CodeGenerator.defaultValue(dataType(), false)); + args.add(javaType); + args.add(beamCoder); + args.add(input.value()); + Block code = (new Block.BlockHelper(sc)) + .code(JavaConversions.collectionAsScalaIterable(args).toSeq()); + + return ev.copy(input.code().$plus(code), input.isNull(), + new VariableValue("output", classTag.runtimeClass())); } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java index 490e3dc..7078b0c 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java @@ -23,7 +23,7 @@ public class EncodersTest { data.add(1); data.add(2); data.add(3); -// sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of())); - sparkSession.createDataset(data, EncoderHelpers.genericEncoder()); + sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of())); +// sparkSession.createDataset(data, EncoderHelpers.genericEncoder()); } }
