You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that this plain-text rendering does not preserve.
Posted to commits@beam.apache.org by ec...@apache.org on 2019/09/05 14:48:03 UTC

[beam] 10/24: Lazy init coder because coder instance cannot be interpolated by catalyst

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit a758985ad24183bbdba3c11ea51b7020bc0865cf
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Mon Sep 2 17:55:24 2019 +0200

    Lazy init coder because coder instance cannot be interpolated by catalyst
---
 .../translation/helpers/EncoderHelpers.java        | 61 +++++++++++++++-------
 .../structuredstreaming/utils/EncodersTest.java    |  3 +-
 2 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index cc862cd..f7706cc 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -45,7 +45,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode;
 import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.ObjectType;
-import scala.Function1;
 import scala.StringContext;
 import scala.Tuple2;
 import scala.collection.JavaConversions;
@@ -94,17 +93,17 @@ public class EncoderHelpers {
   */
 
   /** A way to construct encoders using generic serializers. */
-  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder/*, Class<T> claz*/){
+  public static <T> Encoder<T> fromBeamCoder(Class<? extends Coder<T>> coderClass/*, Class<T> claz*/){
 
     List<Expression> serialiserList = new ArrayList<>();
     Class<T> claz = (Class<T>) Object.class;
-    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz), true), coder));
+    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new ObjectType(claz), true), (Class<Coder<T>>)coderClass));
     ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz);
     return new ExpressionEncoder<>(
         SchemaHelpers.binarySchema(),
         false,
         JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
-        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, coder),
+        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType), classTag, (Class<Coder<T>>)coderClass),
         classTag);
 
 /*
@@ -127,11 +126,11 @@ public class EncoderHelpers {
   public static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression {
 
     private Expression child;
-    private Coder<T> beamCoder;
+    private Class<Coder<T>> coderClass;
 
-    public EncodeUsingBeamCoder(Expression child, Coder<T> beamCoder) {
+    public EncodeUsingBeamCoder(Expression child, Class<Coder<T>> coderClass) {
       this.child = child;
-      this.beamCoder = beamCoder;
+      this.coderClass = coderClass;
     }
 
     @Override public Expression child() {
@@ -140,6 +139,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to serialize.
+      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
       ExprCode input = child.genCode(ctx);
 
       /*
@@ -170,6 +170,7 @@ public class EncoderHelpers {
           new VariableValue("output", Array.class));
     }
 
+
     @Override public DataType dataType() {
       return BinaryType;
     }
@@ -179,7 +180,7 @@ public class EncoderHelpers {
         case 0:
           return child;
         case 1:
-          return beamCoder;
+          return coderClass;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
       }
@@ -201,11 +202,11 @@ public class EncoderHelpers {
         return false;
       }
       EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o;
-      return beamCoder.equals(that.beamCoder);
+      return coderClass.equals(that.coderClass);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), beamCoder);
+      return Objects.hash(super.hashCode(), coderClass);
     }
   }
 
@@ -237,12 +238,12 @@ public class EncoderHelpers {
 
     private Expression child;
     private ClassTag<T> classTag;
-    private Coder<T> beamCoder;
+    private Class<Coder<T>> coderClass;
 
-    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Coder<T> beamCoder) {
+    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, Class<Coder<T>> coderClass) {
       this.child = child;
       this.classTag = classTag;
-      this.beamCoder = beamCoder;
+      this.coderClass = coderClass;
     }
 
     @Override public Expression child() {
@@ -251,6 +252,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to deserialize.
+      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
       ExprCode input = child.genCode(ctx);
       String javaType = CodeGenerator.javaType(dataType());
 
@@ -291,9 +293,10 @@ public class EncoderHelpers {
 
     @Override public Object nullSafeEval(Object input) {
       try {
+        Coder<T> beamCoder = coderClass.newInstance();
         return beamCoder.decode(new ByteArrayInputStream((byte[]) input));
-      } catch (IOException e) {
-        throw new IllegalStateException("Error decoding bytes for coder: " + beamCoder, e);
+      } catch (Exception e) {
+        throw new IllegalStateException("Error decoding bytes for coder: " + coderClass, e);
       }
     }
 
@@ -308,7 +311,7 @@ public class EncoderHelpers {
         case 1:
           return classTag;
         case 2:
-          return beamCoder;
+          return coderClass;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of bounds");
       }
@@ -330,11 +333,11 @@ public class EncoderHelpers {
         return false;
       }
       DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder);
+      return classTag.equals(that.classTag) && coderClass.equals(that.coderClass);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), classTag, beamCoder);
+      return Objects.hash(super.hashCode(), classTag, coderClass);
     }
   }
 /*
@@ -365,4 +368,26 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
   }
 */
 
+  private static <T> String lazyInitBeamCoder(CodegenContext ctx, Class<Coder<T>> coderClass) { // Registers codegen state typed as coderClass and returns its variable name.
+    String beamCoderInstance = "beamCoder"; // NOTE(review): fixed name for every coder class; confirm a second, different coder in the same ctx is not silently skipped.
+    ctx.addImmutableStateIfNotExists(coderClass.getName(), beamCoderInstance, v -> {
+      /*
+       * Builds the init statement "<v> = (<FQCN>) <FQCN>.newInstance();" where FQCN is coderClass.getName().
+       * NOTE(review): "<FQCN>.newInstance()" invokes a static newInstance() on the coder class, not Class#newInstance(); confirm it compiles.
+       */
+      List<String> parts = new ArrayList<>();
+      parts.add("");
+      parts.add(" = (");
+      parts.add(") ");
+      parts.add(".newInstance();");
+      StringContext sc = new StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq()); // Scala s-interpolator: 4 literal parts interleaved with 3 args.
+      List<Object> args = new ArrayList<>();
+      args.add(v);
+      args.add(coderClass.getName());
+      args.add(coderClass.getName());
+      return sc.s(JavaConversions.collectionAsScalaIterable(args).toSeq());
+    });
+    return beamCoderInstance; // Name that doGenCode callers splice into generated Java.
+  }
+
 }
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
index 7078b0c..0e38fe1 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
@@ -3,6 +3,7 @@ package org.apache.beam.runners.spark.structuredstreaming.utils;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.spark.sql.SparkSession;
 import org.junit.Test;
@@ -23,7 +24,7 @@ public class EncodersTest {
     data.add(1);
     data.add(2);
     data.add(3);
-    sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
+    sparkSession.createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.class));
 //    sparkSession.createDataset(data, EncoderHelpers.genericEncoder());
   }
 }