You are viewing a plain text version of this content; the canonical HTML version is available from the mailing-list archive.
Posted to commits@beam.apache.org by ec...@apache.org on 2019/10/24 10:08:27 UTC
[beam] 03/37: Wrap Beam Coders into Spark Encoders using
ExpressionEncoder: serialization part
This is an automated email from the ASF dual-hosted git repository.
echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git
commit a5c7da32d46d74ab4b79ebb34dcad4842f225c62
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Mon Aug 26 14:32:17 2019 +0200
Wrap Beam Coders into Spark Encoders using ExpressionEncoder: serialization part
---
.../translation/helpers/EncoderHelpers.java | 245 +++++++++++++++++++++
1 file changed, 245 insertions(+)
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index d44fe27..b072803 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -17,11 +17,40 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+import static org.apache.spark.sql.types.DataTypes.BinaryType;
+
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import org.apache.beam.runners.spark.structuredstreaming.translation.SchemaHelpers;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Cast;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.NonSQLExpression;
+import org.apache.spark.sql.catalyst.expressions.UnaryExpression;
+import org.apache.spark.sql.catalyst.expressions.codegen.Block;
+import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator;
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext;
+import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode;
+import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue;
+import org.apache.spark.sql.catalyst.expressions.codegen.SimpleExprValue;
+import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.ObjectType;
+import scala.StringContext;
import scala.Tuple2;
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
/** {@link Encoders} utility class. */
public class EncoderHelpers {
@@ -64,4 +93,220 @@ public class EncoderHelpers {
--------- Bridges from Beam Coders to Spark Encoders
*/
+ /** A way to construct encoders using generic serializers. */
+ private <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){
+
+ List<Expression> serialiserList = new ArrayList<>();
+ serialiserList.add(new EncodeUsingBeamCoder<>(claz, coder));
+ ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz);
+ return new ExpressionEncoder<>(
+ SchemaHelpers.binarySchema(),
+ false,
+ JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
+ new DecodeUsingBeamCoder<>(classTag, coder), classTag);
+
+/*
+ ExpressionEncoder[T](
+ schema = new StructType().add("value", BinaryType),
+ flat = true,
+ serializer = Seq(
+ EncodeUsingSerializer(
+ BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
+ deserializer =
+ DecodeUsingSerializer[T](
+ Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
+ classTag[T],
+ kryo = useKryo),
+ clsTag = classTag[T]
+ )
+*/
+ }
+
  /**
   * Catalyst expression that serializes an input object of type {@code T} to a byte array using a
   * Beam {@link Coder}. Serializer half of the Beam-coder-to-Spark-encoder bridge; see
   * {@code DecodeUsingBeamCoder} for the deserializer half.
   */
  private static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression {

    // Runtime class of the value being encoded.
    private Class<T> claz;
    // Beam coder that performs the actual byte serialization.
    private Coder<T> beamCoder;
    // Bound reference to the input object at ordinal 0 of the incoming row.
    private Expression child;

    private EncodeUsingBeamCoder( Class<T> claz, Coder<T> beamCoder) {
      this.claz = claz;
      this.beamCoder = beamCoder;
      // nullable = true: the input object may be null; doGenCode guards on input.isNull.
      this.child = new BoundReference(0, new ObjectType(claz), true);
    }

    @Override public Expression child() {
      return child;
    }

    /**
     * Emits Java source that encodes the child value into a byte array via the Beam coder.
     *
     * <p>NOTE(review): the template strings below use Scala-interpolation-style placeholders
     * ({@code $beamCoder}, {@code ${input.value}}, ...) that plain Java string literals do NOT
     * interpolate; they are handed to Catalyst's {@code BlockHelper.code} together with the
     * {@code args} list, which is expected to substitute them — confirm the generated code is
     * well-formed (work-in-progress codegen).
     */
    @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
      // Code to serialize.
      ExprCode input = child.genCode(ctx);
      String javaType = CodeGenerator.javaType(dataType());
      String outputStream = "ByteArrayOutputStream baos = new ByteArrayOutputStream();";

      // Encode to a fresh ByteArrayOutputStream, then materialize the bytes.
      String serialize = outputStream + "$beamCoder.encode(${input.value}, baos); baos.toByteArray();";

      // Null-safe wrapper: default value when the input is null, serialized bytes otherwise.
      String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;";

      List<String> instructions = new ArrayList<>();
      instructions.add(outside);

      // Build a Catalyst code Block from the template parts plus the substitution arguments.
      Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq();
      StringContext stringContext = new StringContext(parts);
      Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext);
      List<Object> args = new ArrayList<>();
      args.add(new VariableValue("beamCoder", Coder.class));
      args.add(new SimpleExprValue("input.value", ExprValue.class));
      args.add(new VariableValue("javaType", String.class));
      args.add(new SimpleExprValue("input.isNull", Boolean.class));
      args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class));
      args.add(new VariableValue("$serialize", String.class));
      Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq());

      // NOTE(review): the result is exposed as VariableValue("output", Array.class) — verify this
      // matches the `output` variable declared in the template above and the byte[] result type.
      return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", Array.class));
    }

    /** Serialized form is always a byte array column. */
    @Override public DataType dataType() {
      return BinaryType;
    }

    // Scala Product support required by Expression (case-class protocol).
    @Override public Object productElement(int n) {
      if (n == 0) {
        return this;
      } else {
        throw new IndexOutOfBoundsException(String.valueOf(n));
      }
    }

    @Override public int productArity() {
      //TODO test with spark Encoders if the arity of 1 is ok
      return 1;
    }

    @Override public boolean canEqual(Object that) {
      return (that instanceof EncodeUsingBeamCoder);
    }

    // NOTE(review): `child` is derived from `claz` in the constructor, so comparing claz and
    // beamCoder alone is consistent here — but it silently breaks if `child` ever becomes settable.
    @Override public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o;
      return claz.equals(that.claz) && beamCoder.equals(that.beamCoder);
    }

    @Override public int hashCode() {
      return Objects.hash(super.hashCode(), claz, beamCoder);
    }
  }
+
+ /*case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
+ extends UnaryExpression with NonSQLExpression with SerializerSupport {
+
+ override def nullSafeEval(input: Any): Any = {
+ serializerInstance.serialize(input).array()
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val serializer = addImmutableSerializerIfNeeded(ctx)
+ // Code to serialize.
+ val input = child.genCode(ctx)
+ val javaType = CodeGenerator.javaType(dataType)
+ val serialize = s"$serializer.serialize(${input.value}, null).array()"
+
+ val code = input.code + code"""
+ final $javaType ${ev.value} =
+ ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
+ """
+ ev.copy(code = code, isNull = input.isNull)
+ }
+
+ override def dataType: DataType = BinaryType
+ }*/
+
+ private static class DecodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression{
+
+ private ClassTag<T> classTag;
+ private Coder<T> beamCoder;
+
+ private DecodeUsingBeamCoder(ClassTag<T> classTag, Coder<T> beamCoder) {
+ this.classTag = classTag;
+ this.beamCoder = beamCoder;
+ }
+
+ @Override public Expression child() {
+ return new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType);
+ }
+
+ @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
+ return null;
+ }
+
+ @Override public DataType dataType() {
+ return new ObjectType(classTag.runtimeClass());
+ }
+
+ @Override public Object productElement(int n) {
+ if (n == 0) {
+ return this;
+ } else {
+ throw new IndexOutOfBoundsException(String.valueOf(n));
+ }
+ }
+
+ @Override public int productArity() {
+ //TODO test with spark Encoders if the arity of 1 is ok
+ return 1;
+ }
+
+ @Override public boolean canEqual(Object that) {
+ return (that instanceof DecodeUsingBeamCoder);
+ }
+
+ @Override public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
+ return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder);
+ }
+
+ @Override public int hashCode() {
+ return Objects.hash(super.hashCode(), classTag, beamCoder);
+ }
+ }
+/*
+case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
+ extends UnaryExpression with NonSQLExpression with SerializerSupport {
+
+ override def nullSafeEval(input: Any): Any = {
+ val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]])
+ serializerInstance.deserialize(inputBytes)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val serializer = addImmutableSerializerIfNeeded(ctx)
+ // Code to deserialize.
+ val input = child.genCode(ctx)
+ val javaType = CodeGenerator.javaType(dataType)
+ val deserialize =
+ s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
+
+ val code = input.code + code"""
+ final $javaType ${ev.value} =
+ ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
+ """
+ ev.copy(code = code, isNull = input.isNull)
+ }
+
+ override def dataType: DataType = ObjectType(tag.runtimeClass)
+ }
+*/
+
}