You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ch...@apache.org on 2022/04/03 04:19:53 UTC

[beam] branch master updated: [BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)

This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d5ca41  [BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)
8d5ca41 is described below

commit 8d5ca41992b3f4fda75e678fa5d517e5333bbb8a
Author: Chamikara Jayalath <ch...@gmail.com>
AuthorDate: Sat Apr 2 21:18:28 2022 -0700

    [BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)
    
    * Simplifies the ExternalPythonTransform API
    
    * Fix checkerframework errors
---
 .../extensions/python/ExternalPythonTransform.java | 213 +++++++++++++++++++-
 .../python/ExternalPythonTransformTest.java        | 224 ++++++++++++++++++++-
 2 files changed, 425 insertions(+), 12 deletions(-)

diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
index 163f873..b381dcb 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
@@ -17,14 +17,24 @@
  */
 package org.apache.beam.sdk.extensions.python;
 
+import java.util.Arrays;
+import java.util.Map;
 import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
 import java.util.UUID;
 import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
 import org.apache.beam.runners.core.construction.External;
 import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
@@ -33,25 +43,210 @@ import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** Wrapper for invoking external Python transforms. */
 public class ExternalPythonTransform<InputT extends PInput, OutputT extends POutput>
     extends PTransform<InputT, OutputT> {
-  private final String fullyQualifiedName;
-  private final Row args;
-  private final Row kwargs;
 
-  public ExternalPythonTransform(String fullyQualifiedName, Row args, Row kwargs) {
+  private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();
+  private String fullyQualifiedName;
+
+  // We preserve the order here since Schemas care about the order of fields, but the order
+  // will not matter when applying kwargs at the Python side.
+  private SortedMap<String, Object> kwargsMap;
+
+  private @Nullable Object @NonNull [] argsArray;
+  private @Nullable Row providedKwargsRow;
+
+  private ExternalPythonTransform(String fullyQualifiedName) {
     this.fullyQualifiedName = fullyQualifiedName;
-    this.args = args;
-    this.kwargs = kwargs;
+    this.kwargsMap = new TreeMap<>();
+    argsArray = new Object[] {};
+  }
+
+  /**
+   * Instantiates a cross-language wrapper for a Python transform with a given transform name.
+   *
+   * @param tranformName fully qualified transform name.
+   * @param <InputT> Input {@link PCollection} type
+   * @param <OutputT> Output {@link PCollection} type
+   * @return A {@link ExternalPythonTransform} for the given transform name.
+   */
+  public static <InputT extends PInput, OutputT extends POutput>
+      ExternalPythonTransform<InputT, OutputT> from(String tranformName) {
+    return new ExternalPythonTransform<InputT, OutputT>(tranformName);
+  }
+
+  /**
+   * Positional arguments for the Python cross-language transform. If invoked more than once, new
+   * arguments will be appended to the previously specified arguments.
+   *
+   * @param args list of arguments.
+   * @return updated wrapper for the cross-language transform.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withArgs(@NonNull Object... args) {
+    @Nullable
+    Object @NonNull [] result = Arrays.copyOf(this.argsArray, this.argsArray.length + args.length);
+    System.arraycopy(args, 0, result, this.argsArray.length, args.length);
+    this.argsArray = result;
+    return this;
+  }
+
+  /**
+   * Specifies a single keyword argument for the Python cross-language transform. This may be
+   * invoked multiple times to add more than one keyword argument.
+   *
+   * @param name argument name.
+   * @param value argument value
+   * @return updated wrapper for the cross-language transform.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withKwarg(String name, Object value) {
+    if (providedKwargsRow != null) {
+      throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+    }
+    kwargsMap.put(name, value);
+    return this;
+  }
+
+  /**
+   * Specifies keyword arguments for the Python cross-language transform. If invoked more than once,
+   * new keyword arguments map will be added to the previously provided keyword arguments.
+   *
+   * @return updated wrapper for the cross-language transform.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withKwargs(Map<String, Object> kwargs) {
+    if (providedKwargsRow != null) {
+      throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+    }
+    kwargsMap.putAll(kwargs);
+    return this;
+  }
+
+  /**
+   * Specifies keyword arguments as a Row object.
+   *
+   * @param kwargs keyword arguments as a {@link Row} object. An empty Row represents zero keyword
+   *     arguments.
+   * @return updated wrapper for the cross-language transform.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withKwargs(Row kwargs) {
+    if (this.kwargsMap.size() > 0) {
+      throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+    }
+    this.providedKwargsRow = kwargs;
+    return this;
+  }
+
+  @VisibleForTesting
+  Row buildOrGetKwargsRow() {
+    if (providedKwargsRow != null) {
+      return providedKwargsRow;
+    } else {
+      Schema schema =
+          generateSchemaFromFieldValues(
+              kwargsMap.values().toArray(), kwargsMap.keySet().toArray(new String[] {}));
+      return Row.withSchema(schema)
+          .addValues(convertComplexTypesToRows(kwargsMap.values().toArray()))
+          .build();
+    }
+  }
+
+  // Types that are not one of the following are considered custom types.
+  // * Java primitives and their wrapper types
+  // * Type String
+  // * Type Row
+  private static boolean isCustomType(java.lang.Class<?> type) {
+    boolean val =
+        !(ClassUtils.isPrimitiveOrWrapper(type)
+            || type == String.class
+            || Row.class.isAssignableFrom(type));
+    return val;
+  }
+
+  // If the custom type has a registered schema, we use that. Otherwise we try to register it using
+  // 'JavaFieldSchema'.
+  private Row convertCustomValue(Object value) {
+    SerializableFunction<Object, Row> toRowFunc;
+    try {
+      toRowFunc =
+          (SerializableFunction<Object, Row>) SCHEMA_REGISTRY.getToRowFunction(value.getClass());
+    } catch (NoSuchSchemaException e) {
+      SCHEMA_REGISTRY.registerSchemaProvider(value.getClass(), new JavaFieldSchema());
+      try {
+        toRowFunc =
+            (SerializableFunction<Object, Row>) SCHEMA_REGISTRY.getToRowFunction(value.getClass());
+      } catch (NoSuchSchemaException e1) {
+        throw new RuntimeException(e1);
+      }
+    }
+    return toRowFunc.apply(value);
+  }
+
+  private Object[] convertComplexTypesToRows(@Nullable Object @NonNull [] values) {
+    Object[] converted = new Object[values.length];
+    for (int i = 0; i < values.length; i++) {
+      Object value = values[i];
+      if (value != null) {
+        converted[i] = isCustomType(value.getClass()) ? convertCustomValue(value) : value;
+      } else {
+        throw new RuntimeException("Null values are not supported");
+      }
+    }
+    return converted;
+  }
+
+  @VisibleForTesting
+  Row buildOrGetArgsRow() {
+    Schema schema = generateSchemaFromFieldValues(argsArray, null);
+    Object[] convertedValues = convertComplexTypesToRows(argsArray);
+    return Row.withSchema(schema).addValues(convertedValues).build();
+  }
+
+  private Schema generateSchemaDirectly(
+      @Nullable Object @NonNull [] fieldValues, @NonNull String @Nullable [] fieldNames) {
+    Schema.Builder builder = Schema.builder();
+    int counter = 0;
+    for (Object field : fieldValues) {
+      if (field == null) {
+        throw new RuntimeException("Null field values are not supported");
+      }
+      String fieldName = (fieldNames != null) ? fieldNames[counter] : "field" + counter;
+      if (field instanceof Row) {
+        // Rows are used as is but other types are converted to proper field types.
+        builder.addRowField(fieldName, ((Row) field).getSchema());
+      } else {
+        builder.addField(
+            fieldName,
+            StaticSchemaInference.fieldFromType(
+                TypeDescriptor.of(field.getClass()),
+                JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE));
+      }
+
+      counter++;
+    }
+
+    Schema schema = builder.build();
+    return schema;
+  }
+
+  // We generate the Schema from the provided field names and values. If field names are
+  // not provided, we generate them.
+  private Schema generateSchemaFromFieldValues(
+      @Nullable Object @NonNull [] fieldValues, @NonNull String @Nullable [] fieldNames) {
+    return generateSchemaDirectly(fieldValues, fieldNames);
   }
 
   @Override
   public OutputT expand(InputT input) {
     int port;
+    Row argsRow = buildOrGetArgsRow();
+    Row kwargsRow = buildOrGetKwargsRow();
     try {
       port = PythonService.findAvailablePort();
       PythonService service =
@@ -64,11 +259,11 @@ public class ExternalPythonTransform<InputT extends PInput, OutputT extends POut
       Schema payloadSchema =
           Schema.of(
               Schema.Field.of("constructor", Schema.FieldType.STRING),
-              Schema.Field.of("args", Schema.FieldType.row(args.getSchema())),
-              Schema.Field.of("kwargs", Schema.FieldType.row(kwargs.getSchema())));
+              Schema.Field.of("args", Schema.FieldType.row(argsRow.getSchema())),
+              Schema.Field.of("kwargs", Schema.FieldType.row(kwargsRow.getSchema())));
       payloadSchema.setUUID(UUID.randomUUID());
       Row payloadRow =
-          Row.withSchema(payloadSchema).addValues(fullyQualifiedName, args, kwargs).build();
+          Row.withSchema(payloadSchema).addValues(fullyQualifiedName, argsRow, kwargsRow).build();
       ExternalTransforms.ExternalConfigurationPayload payload =
           ExternalTransforms.ExternalConfigurationPayload.newBuilder()
               .setSchema(SchemaTranslation.schemaToProto(payloadSchema, true))
diff --git a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
index f2d8bee..b9502c8 100644
--- a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
+++ b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.sdk.extensions.python;
 
+import static org.junit.Assert.assertEquals;
+
 import java.io.Serializable;
+import java.util.Map;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.PAssert;
@@ -27,6 +30,7 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -41,11 +45,225 @@ public class ExternalPythonTransformTest implements Serializable {
     PCollection<String> output =
         p.apply(Create.of(KV.of("A", "x"), KV.of("A", "y"), KV.of("B", "z")))
             .apply(
-                new ExternalPythonTransform<
-                    PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>(
-                    "apache_beam.GroupByKey", Row.nullRow(Schema.of()), Row.nullRow(Schema.of())))
+                ExternalPythonTransform
+                    .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>
+                        from("apache_beam.GroupByKey"))
             .apply(MapElements.into(TypeDescriptors.strings()).via(kv -> kv.getKey()));
     PAssert.that(output).containsInAnyOrder("A", "B");
     // TODO: Run this on a multi-language supporting runner.
   }
+
+  @Test
+  public void generateArgsEmpty() {
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform");
+
+    Row receivedRow = transform.buildOrGetArgsRow();
+    assertEquals(0, receivedRow.getFieldCount());
+  }
+
+  @Test
+  public void generateArgsWithPrimitives() {
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs("aaa", "bbb", 11, 12L, 15.6, true);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addStringField("field0")
+            .addStringField("field1")
+            .addInt32Field("field2")
+            .addInt64Field("field3")
+            .addDoubleField("field4")
+            .addBooleanField("field5")
+            .build();
+    Row expectedRow =
+        Row.withSchema(expectedSchema).addValues("aaa", "bbb", 11, 12L, 15.6, true).build();
+
+    Row receivedRow = transform.buildOrGetArgsRow();
+    assertEquals(expectedRow, receivedRow);
+  }
+
+  @Test
+  public void generateArgsWithRow() {
+    Schema subRowSchema1 =
+        Schema.builder().addStringField("field0").addInt32Field("field1").build();
+    Row rowField1 = Row.withSchema(subRowSchema1).addValues("xxx", 123).build();
+    Schema subRowSchema2 =
+        Schema.builder()
+            .addDoubleField("field0")
+            .addBooleanField("field1")
+            .addStringField("field2")
+            .build();
+    Row rowField2 = Row.withSchema(subRowSchema2).addValues(12.5, true, "yyy").build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs(rowField1, rowField2);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addRowField("field0", subRowSchema1)
+            .addRowField("field1", subRowSchema2)
+            .build();
+    Row expectedRow = Row.withSchema(expectedSchema).addValues(rowField1, rowField2).build();
+
+    Row receivedRow = transform.buildOrGetArgsRow();
+    assertEquals(expectedRow, receivedRow);
+  }
+
+  static class CustomType {
+    int intField;
+    String strField;
+  }
+
+  @Test
+  public void generateArgsWithCustomType() {
+    CustomType customType1 = new CustomType();
+    customType1.strField = "xxx";
+    customType1.intField = 123;
+
+    CustomType customType2 = new CustomType();
+    customType2.strField = "yyy";
+    customType2.intField = 456;
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs(customType1, customType2);
+
+    Row receivedRow = transform.buildOrGetArgsRow();
+
+    assertEquals("xxx", receivedRow.getRow("field0").getString("strField"));
+    assertEquals(123, (int) receivedRow.getRow("field0").getInt32("intField"));
+
+    assertEquals("yyy", receivedRow.getRow("field1").getString("strField"));
+    assertEquals(456, (int) receivedRow.getRow("field1").getInt32("intField"));
+  }
+
+  @Test
+  public void generateKwargsEmpty() {
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform");
+
+    Row receivedRow = transform.buildOrGetKwargsRow();
+    assertEquals(0, receivedRow.getFieldCount());
+  }
+
+  @Test
+  public void generateKwargsWithPrimitives() {
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("stringField1", "aaa")
+            .withKwarg("stringField2", "bbb")
+            .withKwarg("intField", 11)
+            .withKwarg("longField", 12L)
+            .withKwarg("doubleField", 15.6)
+            .withKwarg("boolField", true);
+
+    Row receivedRow = transform.buildOrGetKwargsRow();
+    assertEquals("aaa", receivedRow.getString("stringField1"));
+    assertEquals("bbb", receivedRow.getString("stringField2"));
+    assertEquals(11, (int) receivedRow.getInt32("intField"));
+    assertEquals(12L, (long) receivedRow.getInt64("longField"));
+    assertEquals(15.6, (double) receivedRow.getDouble("doubleField"), 0);
+    assertEquals(true, receivedRow.getBoolean("boolField"));
+  }
+
+  @Test
+  public void generateKwargsRow() {
+    Schema subRowSchema1 =
+        Schema.builder().addStringField("field0").addInt32Field("field1").build();
+    Row rowField1 = Row.withSchema(subRowSchema1).addValues("xxx", 123).build();
+    Schema subRowSchema2 =
+        Schema.builder()
+            .addDoubleField("field0")
+            .addBooleanField("field1")
+            .addStringField("field2")
+            .build();
+    Row rowField2 = Row.withSchema(subRowSchema2).addValues(12.5, true, "yyy").build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("customField0", rowField1)
+            .withKwarg("customField1", rowField2);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addRowField("customField0", subRowSchema1)
+            .addRowField("customField1", subRowSchema2)
+            .build();
+    Row expectedRow = Row.withSchema(expectedSchema).addValues(rowField1, rowField2).build();
+
+    Row receivedRow = transform.buildOrGetKwargsRow();
+    assertEquals(expectedRow, receivedRow);
+  }
+
+  @Test
+  public void generateKwargsWithCustomType() {
+    CustomType customType1 = new CustomType();
+    customType1.strField = "xxx";
+    customType1.intField = 123;
+
+    CustomType customType2 = new CustomType();
+    customType2.strField = "yyy";
+    customType2.intField = 456;
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("customField0", customType1)
+            .withKwarg("customField1", customType2);
+
+    Row receivedRow = transform.buildOrGetKwargsRow();
+
+    assertEquals("xxx", receivedRow.getRow("customField0").getString("strField"));
+    assertEquals(123, (int) receivedRow.getRow("customField0").getInt32("intField"));
+
+    assertEquals("yyy", receivedRow.getRow("customField1").getString("strField"));
+    assertEquals(456, (int) receivedRow.getRow("customField1").getInt32("intField"));
+  }
+
+  @Test
+  public void generateKwargsFromMap() {
+    Map<String, Object> kwargsMap =
+        ImmutableMap.of(
+            "stringField1",
+            "aaa",
+            "stringField2",
+            "bbb",
+            "intField",
+            Integer.valueOf(11),
+            "longField",
+            Long.valueOf(12L),
+            "doubleField",
+            Double.valueOf(15.6));
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwargs(kwargsMap);
+
+    Row receivedRow = transform.buildOrGetKwargsRow();
+    assertEquals("aaa", receivedRow.getString("stringField1"));
+    assertEquals("bbb", receivedRow.getString("stringField2"));
+    assertEquals(11, (int) receivedRow.getInt32("intField"));
+    assertEquals(12L, (long) receivedRow.getInt64("longField"));
+    assertEquals(15.6, (double) receivedRow.getDouble("doubleField"), 0);
+  }
 }