You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this format.
Posted to commits@beam.apache.org by ec...@apache.org on 2019/09/05 14:47:55 UTC

[beam] 02/24: Wrap Beam Coders into Spark Encoders using ExpressionEncoder: deserialization part

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit d613d6bc28d63df6631367270f8474caf6f24059
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Mon Aug 26 15:22:12 2019 +0200

    Wrap Beam Coders into Spark Encoders using ExpressionEncoder: deserialization part
---
 .../translation/helpers/EncoderHelpers.java        | 61 ++++++++++++++++++----
 1 file changed, 50 insertions(+), 11 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index b072803..3159de9b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -19,6 +19,8 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
 
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
 import java.util.List;
@@ -103,7 +105,7 @@ public class EncoderHelpers {
         SchemaHelpers.binarySchema(),
         false,
         JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
-        new DecodeUsingBeamCoder<>(classTag, coder), classTag);
+        new DecodeUsingBeamCoder<>(claz, coder), classTag);
 
 /*
     ExpressionEncoder[T](
@@ -150,8 +152,8 @@ public class EncoderHelpers {
 
       List<String> instructions = new ArrayList<>();
       instructions.add(outside);
-
       Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq();
+
       StringContext stringContext = new StringContext(parts);
       Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext);
       List<Object> args = new ArrayList<>();
@@ -160,7 +162,7 @@ public class EncoderHelpers {
       args.add(new VariableValue("javaType", String.class));
       args.add(new SimpleExprValue("input.isNull", Boolean.class));
       args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class));
-      args.add(new VariableValue("$serialize", String.class));
+      args.add(new VariableValue("serialize", String.class));
       Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq());
 
       return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", Array.class));
@@ -229,24 +231,61 @@ public class EncoderHelpers {
 
   private static class DecodeUsingBeamCoder<T> extends UnaryExpression implements  NonSQLExpression{
 
-    private ClassTag<T> classTag;
+    private Class<T> claz;
     private Coder<T> beamCoder;
+    private Expression child;
 
-    private DecodeUsingBeamCoder(ClassTag<T> classTag, Coder<T> beamCoder) {
-      this.classTag = classTag;
+    private DecodeUsingBeamCoder(Class<T> claz, Coder<T> beamCoder) {
+      this.claz = claz;
       this.beamCoder = beamCoder;
+      this.child = new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType);
     }
 
     @Override public Expression child() {
-      return new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType);
+      return child;
     }
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
-      return null;
+      // Code to deserialize.
+      ExprCode input = child.genCode(ctx);
+      String javaType = CodeGenerator.javaType(dataType());
+
+      String inputStream = "ByteArrayInputStream bais = new ByteArrayInputStream(${input.value});";
+      String deserialize = inputStream + "($javaType) $beamCoder.decode(bais);";
+
+      String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;";
+
+      List<String> instructions = new ArrayList<>();
+      instructions.add(outside);
+      Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq();
+
+      StringContext stringContext = new StringContext(parts);
+      Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext);
+      List<Object> args = new ArrayList<>();
+      args.add(new SimpleExprValue("input.value", ExprValue.class));
+      args.add(new VariableValue("javaType", String.class));
+      args.add(new VariableValue("beamCoder", Coder.class));
+      args.add(new SimpleExprValue("input.isNull", Boolean.class));
+      args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class));
+      args.add(new VariableValue("deserialize", String.class));
+      Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq());
+
+      return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", claz));
+
+    }
+
+    @Override public Object nullSafeEval(Object input) {
+      try {
+        return beamCoder.decode(new ByteArrayInputStream((byte[]) input));
+      } catch (IOException e) {
+        throw new IllegalStateException("Error decoding bytes for coder: " + beamCoder, e);
+      }
     }
 
     @Override public DataType dataType() {
-      return new ObjectType(classTag.runtimeClass());
+//      return new ObjectType(classTag.runtimeClass());
+      //TODO does type erasure impose to use classTag.runtimeClass() ?
+      return new ObjectType(claz);
     }
 
     @Override public Object productElement(int n) {
@@ -274,11 +313,11 @@ public class EncoderHelpers {
         return false;
       }
       DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder);
+      return claz.equals(that.claz) && beamCoder.equals(that.beamCoder);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), classTag, beamCoder);
+      return Objects.hash(super.hashCode(), claz, beamCoder);
     }
   }
 /*