You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ch...@apache.org on 2022/04/03 04:19:53 UTC
[beam] branch master updated: [BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)
This is an automated email from the ASF dual-hosted git repository.
chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8d5ca41 [BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)
8d5ca41 is described below
commit 8d5ca41992b3f4fda75e678fa5d517e5333bbb8a
Author: Chamikara Jayalath <ch...@gmail.com>
AuthorDate: Sat Apr 2 21:18:28 2022 -0700
[BEAM-14143] Simplifies the ExternalPythonTransform API (#17101)
* Simplifies the ExternalPythonTransform API
* Fix checkerframework errors
---
.../extensions/python/ExternalPythonTransform.java | 213 +++++++++++++++++++-
.../python/ExternalPythonTransformTest.java | 224 ++++++++++++++++++++-
2 files changed, 425 insertions(+), 12 deletions(-)
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
index 163f873..b381dcb 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransform.java
@@ -17,14 +17,24 @@
*/
package org.apache.beam.sdk.extensions.python;
+import java.util.Arrays;
+import java.util.Map;
import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
import java.util.UUID;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
import org.apache.beam.runners.core.construction.External;
import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.schemas.utils.StaticSchemaInference;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -33,25 +43,210 @@ import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
/** Wrapper for invoking external Python transforms. */
public class ExternalPythonTransform<InputT extends PInput, OutputT extends POutput>
extends PTransform<InputT, OutputT> {
- private final String fullyQualifiedName;
- private final Row args;
- private final Row kwargs;
- public ExternalPythonTransform(String fullyQualifiedName, Row args, Row kwargs) {
+  // Shared by all instances. convertCustomValue() registers a JavaFieldSchema provider here for any
+  // argument class that has no schema yet.
+  // NOTE(review): this registry is mutated without synchronization; confirm that is safe if
+  // transforms can be expanded from multiple threads.
+  private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();
+
+  /** Fully qualified name of the Python transform to invoke, e.g. "apache_beam.GroupByKey". */
+  private final String fullyQualifiedName;
+
+  // Keyword arguments keyed by name. A TreeMap keeps the entries sorted by name (it does not keep
+  // insertion order), which gives the generated kwargs Schema a deterministic field order; the order
+  // itself does not matter on the Python side because kwargs are applied by name.
+  private final SortedMap<String, Object> kwargsMap = new TreeMap<>();
+
+  // Positional arguments, in the order they were supplied through withArgs().
+  private @Nullable Object @NonNull [] argsArray = new Object[0];
+
+  // Keyword arguments supplied directly as a Row; mutually exclusive with kwargsMap.
+  private @Nullable Row providedKwargsRow;
+
+  private ExternalPythonTransform(String fullyQualifiedName) {
this.fullyQualifiedName = fullyQualifiedName;
- this.args = args;
- this.kwargs = kwargs;
+  }
+
+  /**
+   * Instantiates a cross-language wrapper for a Python transform with a given transform name.
+   *
+   * @param transformName fully qualified name of the Python transform, e.g. {@code
+   *     "apache_beam.GroupByKey"}.
+   * @param <InputT> Input {@link PInput} type
+   * @param <OutputT> Output {@link POutput} type
+   * @return A new {@link ExternalPythonTransform} for the given transform name, with no positional
+   *     or keyword arguments yet.
+   */
+  public static <InputT extends PInput, OutputT extends POutput>
+      ExternalPythonTransform<InputT, OutputT> from(String transformName) {
+    return new ExternalPythonTransform<>(transformName);
+  }
+
+  /**
+   * Appends positional arguments for the Python cross-language transform. Arguments passed in
+   * repeated calls accumulate, in call order, after those already given.
+   *
+   * @param args positional argument values, in order.
+   * @return this wrapper, updated with the additional arguments.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withArgs(@NonNull Object... args) {
+    int existingCount = argsArray.length;
+    @Nullable Object @NonNull [] combined = new Object[existingCount + args.length];
+    // Keep the earlier arguments first, then place the new ones after them.
+    System.arraycopy(argsArray, 0, combined, 0, existingCount);
+    System.arraycopy(args, 0, combined, existingCount, args.length);
+    argsArray = combined;
+    return this;
+  }
+
+ /**
+ * Specifies a single keyword argument for the Python cross-language transform. This may be
+ * invoked multiple times to add more than one keyword argument.
+ *
+ * @param name argument name.
+ * @param value argument value
+ * @return updated wrapper for the cross-language transform.
+ */
+ public ExternalPythonTransform<InputT, OutputT> withKwarg(String name, Object value) {
+ if (providedKwargsRow != null) {
+ throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+ }
+ kwargsMap.put(name, value);
+ return this;
+ }
+
+  /**
+   * Adds keyword arguments for the Python cross-language transform from a map. Repeated calls merge
+   * into the previously provided keyword arguments; a name given again replaces its earlier value.
+   *
+   * @param kwargs keyword argument values keyed by argument name.
+   * @return this wrapper, updated with the keyword arguments.
+   * @throws IllegalArgumentException if keyword arguments were already supplied as a {@link Row}.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withKwargs(Map<String, Object> kwargs) {
+    if (providedKwargsRow == null) {
+      kwargsMap.putAll(kwargs);
+      return this;
+    }
+    throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+  }
+
+  /**
+   * Supplies all keyword arguments at once as a single {@link Row}, which is sent as the
+   * transform's kwargs instead of building one from individually specified keyword arguments.
+   *
+   * @param kwargs keyword arguments as a {@link Row}. An empty Row represents zero keyword
+   *     arguments.
+   * @return this wrapper, updated with the keyword arguments.
+   * @throws IllegalArgumentException if keyword arguments were already specified individually.
+   */
+  public ExternalPythonTransform<InputT, OutputT> withKwargs(Row kwargs) {
+    if (!kwargsMap.isEmpty()) {
+      throw new IllegalArgumentException("Kwargs were specified both directly and as a Row object");
+    }
+    providedKwargsRow = kwargs;
+    return this;
+  }
+
+ @VisibleForTesting
+ Row buildOrGetKwargsRow() {
+ if (providedKwargsRow != null) {
+ return providedKwargsRow;
+ } else {
+ Schema schema =
+ generateSchemaFromFieldValues(
+ kwargsMap.values().toArray(), kwargsMap.keySet().toArray(new String[] {}));
+ return Row.withSchema(schema)
+ .addValues(convertComplexTypesToRows(kwargsMap.values().toArray()))
+ .build();
+ }
+ }
+
+ // Types that are not one of following are considered custom types.
+ // * Java primitives
+ // * Type String
+ // * Type Row
+ private static boolean isCustomType(java.lang.Class<?> type) {
+ boolean val =
+ !(ClassUtils.isPrimitiveOrWrapper(type)
+ || type == String.class
+ || Row.class.isAssignableFrom(type));
+ return val;
+ }
+
+  // Converts a custom-typed argument value to a Row. If SCHEMA_REGISTRY already has a schema for the
+  // value's class, that schema is used. Otherwise a JavaFieldSchema provider is registered for the
+  // class (inferring the schema from its fields) and then used.
+  // The casts are unchecked because getToRowFunction is typed by the value's runtime class while the
+  // value is statically an Object; the returned function is only ever applied to that same value.
+  @SuppressWarnings("unchecked")
+  private Row convertCustomValue(Object value) {
+    Class<?> valueClass = value.getClass();
+    SerializableFunction<Object, Row> toRowFunc;
+    try {
+      toRowFunc =
+          (SerializableFunction<Object, Row>) SCHEMA_REGISTRY.getToRowFunction(valueClass);
+    } catch (NoSuchSchemaException e) {
+      SCHEMA_REGISTRY.registerSchemaProvider(valueClass, new JavaFieldSchema());
+      try {
+        toRowFunc =
+            (SerializableFunction<Object, Row>) SCHEMA_REGISTRY.getToRowFunction(valueClass);
+      } catch (NoSuchSchemaException e1) {
+        throw new RuntimeException(
+            "Could not infer a schema for argument of type " + valueClass.getName(), e1);
+      }
+    }
+    return toRowFunc.apply(value);
+  }
+
+  // Returns a new array in which each custom-typed value (per isCustomType) is replaced by its Row
+  // conversion. Primitives, wrappers, Strings and Rows are copied through unchanged.
+  // Throws RuntimeException if any value is null.
+  private Object[] convertComplexTypesToRows(@Nullable Object @NonNull [] values) {
+    Object[] result = new Object[values.length];
+    int index = 0;
+    for (@Nullable Object value : values) {
+      if (value == null) {
+        throw new RuntimeException("Null values are not supported");
+      }
+      result[index++] = isCustomType(value.getClass()) ? convertCustomValue(value) : value;
+    }
+    return result;
+  }
+
+  /**
+   * Builds a {@link Row} holding the positional arguments, in argument order, with generated field
+   * names ({@code field0}, {@code field1}, ...). Custom-typed values become nested Rows.
+   */
+  @VisibleForTesting
+  Row buildOrGetArgsRow() {
+    // Schema first, then value conversion: this order determines which null-value error is raised.
+    Schema argsSchema = generateSchemaFromFieldValues(argsArray, null);
+    Object[] rowValues = convertComplexTypesToRows(argsArray);
+    Row.Builder rowBuilder = Row.withSchema(argsSchema);
+    return rowBuilder.addValues(rowValues).build();
+  }
+
+  // Builds a Schema with one field per value, in order. Field i is named fieldNames[i] when names are
+  // given, otherwise "field" + i. A Row value becomes a nested row field with the Row's own schema;
+  // for any other value, the field type is inferred from its runtime class via StaticSchemaInference
+  // with JavaFieldSchema's type supplier. Throws RuntimeException on a null value.
+  private Schema generateSchemaDirectly(
+      @Nullable Object @NonNull [] fieldValues, @NonNull String @Nullable [] fieldNames) {
+    Schema.Builder schemaBuilder = Schema.builder();
+    for (int i = 0; i < fieldValues.length; i++) {
+      Object value = fieldValues[i];
+      if (value == null) {
+        throw new RuntimeException("Null field values are not supported");
+      }
+      String name = (fieldNames == null) ? "field" + i : fieldNames[i];
+      if (value instanceof Row) {
+        schemaBuilder.addRowField(name, ((Row) value).getSchema());
+        continue;
+      }
+      schemaBuilder.addField(
+          name,
+          StaticSchemaInference.fieldFromType(
+              TypeDescriptor.of(value.getClass()), JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE));
+    }
+    return schemaBuilder.build();
+  }
+
+  // Produces the Schema for the given field values and optional field names. When fieldNames is
+  // null, names are generated ("field0", "field1", ...). Currently a thin delegate to
+  // generateSchemaDirectly.
+  private Schema generateSchemaFromFieldValues(
+      @Nullable Object @NonNull [] fieldValues, @NonNull String @Nullable [] fieldNames) {
+    Schema generated = generateSchemaDirectly(fieldValues, fieldNames);
+    return generated;
}
@Override
public OutputT expand(InputT input) {
int port;
+ Row argsRow = buildOrGetArgsRow();
+ Row kwargsRow = buildOrGetKwargsRow();
try {
port = PythonService.findAvailablePort();
PythonService service =
@@ -64,11 +259,11 @@ public class ExternalPythonTransform<InputT extends PInput, OutputT extends POut
Schema payloadSchema =
Schema.of(
Schema.Field.of("constructor", Schema.FieldType.STRING),
- Schema.Field.of("args", Schema.FieldType.row(args.getSchema())),
- Schema.Field.of("kwargs", Schema.FieldType.row(kwargs.getSchema())));
+ Schema.Field.of("args", Schema.FieldType.row(argsRow.getSchema())),
+ Schema.Field.of("kwargs", Schema.FieldType.row(kwargsRow.getSchema())));
payloadSchema.setUUID(UUID.randomUUID());
Row payloadRow =
- Row.withSchema(payloadSchema).addValues(fullyQualifiedName, args, kwargs).build();
+ Row.withSchema(payloadSchema).addValues(fullyQualifiedName, argsRow, kwargsRow).build();
ExternalTransforms.ExternalConfigurationPayload payload =
ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.setSchema(SchemaTranslation.schemaToProto(payloadSchema, true))
diff --git a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
index f2d8bee..b9502c8 100644
--- a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
+++ b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/ExternalPythonTransformTest.java
@@ -17,7 +17,10 @@
*/
package org.apache.beam.sdk.extensions.python;
+import static org.junit.Assert.assertEquals;
+
import java.io.Serializable;
+import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
@@ -27,6 +30,7 @@ import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -41,11 +45,225 @@ public class ExternalPythonTransformTest implements Serializable {
PCollection<String> output =
p.apply(Create.of(KV.of("A", "x"), KV.of("A", "y"), KV.of("B", "z")))
.apply(
- new ExternalPythonTransform<
- PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>(
- "apache_beam.GroupByKey", Row.nullRow(Schema.of()), Row.nullRow(Schema.of())))
+ ExternalPythonTransform
+ .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>
+ from("apache_beam.GroupByKey"))
.apply(MapElements.into(TypeDescriptors.strings()).via(kv -> kv.getKey()));
PAssert.that(output).containsInAnyOrder("A", "B");
// TODO: Run this on a multi-language supporting runner.
}
+
+  @Test
+  public void generateArgsEmpty() {
+    ExternalPythonTransform<?, ?> noArgsTransform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform");
+
+    // Without any withArgs() call the args row has no fields.
+    assertEquals(0, noArgsTransform.buildOrGetArgsRow().getFieldCount());
+  }
+
+  @Test
+  public void generateArgsWithPrimitives() {
+    // Positional args get generated names field0..field5 and types inferred from their Java class.
+    Schema expectedSchema =
+        Schema.builder()
+            .addStringField("field0")
+            .addStringField("field1")
+            .addInt32Field("field2")
+            .addInt64Field("field3")
+            .addDoubleField("field4")
+            .addBooleanField("field5")
+            .build();
+    Row expectedRow =
+        Row.withSchema(expectedSchema).addValues("aaa", "bbb", 11, 12L, 15.6, true).build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs("aaa", "bbb", 11, 12L, 15.6, true);
+
+    assertEquals(expectedRow, transform.buildOrGetArgsRow());
+  }
+
+  @Test
+  public void generateArgsWithRow() {
+    Schema firstSchema = Schema.builder().addStringField("field0").addInt32Field("field1").build();
+    Schema secondSchema =
+        Schema.builder()
+            .addDoubleField("field0")
+            .addBooleanField("field1")
+            .addStringField("field2")
+            .build();
+    Row firstArg = Row.withSchema(firstSchema).addValues("xxx", 123).build();
+    Row secondArg = Row.withSchema(secondSchema).addValues(12.5, true, "yyy").build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs(firstArg, secondArg);
+
+    // Row arguments are embedded unchanged as nested row fields.
+    Schema outerSchema =
+        Schema.builder()
+            .addRowField("field0", firstSchema)
+            .addRowField("field1", secondSchema)
+            .build();
+    Row expectedRow = Row.withSchema(outerSchema).addValues(firstArg, secondArg).build();
+    assertEquals(expectedRow, transform.buildOrGetArgsRow());
+  }
+
+ // Plain argument type with no schema annotation, so the transform has to infer its schema
+ // (JavaFieldSchema fallback) and convert instances to nested Rows.
+ // Field declaration order is left untouched since it may affect the inferred schema's field order.
+ static class CustomType {
+ int intField;
+ String strField;
+ }
+
+  @Test
+  public void generateArgsWithCustomType() {
+    CustomType first = new CustomType();
+    first.intField = 123;
+    first.strField = "xxx";
+    CustomType second = new CustomType();
+    second.intField = 456;
+    second.strField = "yyy";
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withArgs(first, second);
+
+    // Each custom-typed argument becomes a nested Row carrying the object's field values.
+    Row argsRow = transform.buildOrGetArgsRow();
+    Row firstRow = argsRow.getRow("field0");
+    Row secondRow = argsRow.getRow("field1");
+
+    assertEquals("xxx", firstRow.getString("strField"));
+    assertEquals(123, (int) firstRow.getInt32("intField"));
+
+    assertEquals("yyy", secondRow.getString("strField"));
+    assertEquals(456, (int) secondRow.getInt32("intField"));
+  }
+
+  @Test
+  public void generateKwargsEmpty() {
+    ExternalPythonTransform<?, ?> noKwargsTransform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform");
+
+    // With no keyword arguments specified, the kwargs row has no fields.
+    assertEquals(0, noKwargsTransform.buildOrGetKwargsRow().getFieldCount());
+  }
+
+  @Test
+  public void generateKwargsWithPrimitives() {
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("stringField1", "aaa")
+            .withKwarg("stringField2", "bbb")
+            .withKwarg("intField", 11)
+            .withKwarg("longField", 12L)
+            .withKwarg("doubleField", 15.6)
+            .withKwarg("boolField", true);
+
+    // Each keyword argument becomes a field named after it, typed from its Java class.
+    Row kwargsRow = transform.buildOrGetKwargsRow();
+
+    assertEquals("aaa", kwargsRow.getString("stringField1"));
+    assertEquals("bbb", kwargsRow.getString("stringField2"));
+    assertEquals(11, (int) kwargsRow.getInt32("intField"));
+    assertEquals(12L, (long) kwargsRow.getInt64("longField"));
+    assertEquals(15.6, (double) kwargsRow.getDouble("doubleField"), 0);
+    assertEquals(true, kwargsRow.getBoolean("boolField"));
+  }
+
+  @Test
+  public void generateKwargsRow() {
+    Schema firstSchema = Schema.builder().addStringField("field0").addInt32Field("field1").build();
+    Schema secondSchema =
+        Schema.builder()
+            .addDoubleField("field0")
+            .addBooleanField("field1")
+            .addStringField("field2")
+            .build();
+    Row firstValue = Row.withSchema(firstSchema).addValues("xxx", 123).build();
+    Row secondValue = Row.withSchema(secondSchema).addValues(12.5, true, "yyy").build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("customField0", firstValue)
+            .withKwarg("customField1", secondValue);
+
+    // Row-valued keyword arguments are embedded unchanged as nested row fields.
+    Schema outerSchema =
+        Schema.builder()
+            .addRowField("customField0", firstSchema)
+            .addRowField("customField1", secondSchema)
+            .build();
+    Row expectedRow = Row.withSchema(outerSchema).addValues(firstValue, secondValue).build();
+    assertEquals(expectedRow, transform.buildOrGetKwargsRow());
+  }
+
+  @Test
+  public void generateKwargsWithCustomType() {
+    CustomType first = new CustomType();
+    first.intField = 123;
+    first.strField = "xxx";
+    CustomType second = new CustomType();
+    second.intField = 456;
+    second.strField = "yyy";
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwarg("customField0", first)
+            .withKwarg("customField1", second);
+
+    // Custom-typed keyword arguments become nested Rows carrying the objects' field values.
+    Row kwargsRow = transform.buildOrGetKwargsRow();
+    Row firstRow = kwargsRow.getRow("customField0");
+    Row secondRow = kwargsRow.getRow("customField1");
+
+    assertEquals("xxx", firstRow.getString("strField"));
+    assertEquals(123, (int) firstRow.getInt32("intField"));
+
+    assertEquals("yyy", secondRow.getString("strField"));
+    assertEquals(456, (int) secondRow.getInt32("intField"));
+  }
+
+  @Test
+  public void generateKwargsFromMap() {
+    Map<String, Object> kwargs =
+        ImmutableMap.<String, Object>builder()
+            .put("stringField1", "aaa")
+            .put("stringField2", "bbb")
+            .put("intField", Integer.valueOf(11))
+            .put("longField", Long.valueOf(12L))
+            .put("doubleField", Double.valueOf(15.6))
+            .build();
+
+    ExternalPythonTransform<?, ?> transform =
+        ExternalPythonTransform
+            .<PCollection<KV<String, String>>, PCollection<KV<String, Iterable<String>>>>from(
+                "DummyTransform")
+            .withKwargs(kwargs);
+
+    // Every map entry becomes a field named after its key.
+    Row kwargsRow = transform.buildOrGetKwargsRow();
+
+    assertEquals("aaa", kwargsRow.getString("stringField1"));
+    assertEquals("bbb", kwargsRow.getString("stringField2"));
+    assertEquals(11, (int) kwargsRow.getInt32("intField"));
+    assertEquals(12L, (long) kwargsRow.getInt64("longField"));
+    assertEquals(15.6, (double) kwargsRow.getDouble("doubleField"), 0);
+  }
}