You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by he...@apache.org on 2022/07/26 19:28:58 UTC
[beam] branch master updated: Adds KV support for the Java RunInference transform.
This is an automated email from the ASF dual-hosted git repository.
heejong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 06d6d58102d Adds KV support for the Java RunInference transform.
new deb72620e02 Merge pull request #22442 from chamikaramj/runinference_kv_support
06d6d58102d is described below
commit 06d6d58102d88da4cd86492de8f2f9a7657b9d0e
Author: Chamikara Jayalath <ch...@gmail.com>
AuthorDate: Mon Jul 25 21:52:32 2022 -0700
Adds KV support for the Java RunInference transform.
---
.../python/transforms/DataframeTransform.java | 2 +-
.../extensions/python/transforms/PythonMap.java | 1 +
.../extensions/python/transforms/RunInference.java | 98 ++++++++++++++++++----
.../transforms/RunInferenceTransformTest.java | 54 ++++++++++++
4 files changed, 136 insertions(+), 19 deletions(-)
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
index 1c956695e15..720adbd29a6 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
@@ -23,7 +23,7 @@ import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
-/** Wrapper for invoking external Python DataframeTransform. */
+/** Wrapper for invoking external Python {@code DataframeTransform}. @Experimental */
public class DataframeTransform extends PTransform<PCollection<Row>, PCollection<Row>> {
private final String func;
private final boolean includeIndexes;
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
index e1eb9cbad19..d2e5cb5642a 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
@@ -24,6 +24,7 @@ import org.apache.beam.sdk.util.PythonCallableSource;
import org.apache.beam.sdk.values.PCollection;
import org.checkerframework.checker.nullness.qual.Nullable;
+/** Wrapper for invoking external Python {@code Map} transforms. @Experimental */
public class PythonMap<InputT, OutputT>
extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
index 209c7061d23..ec4191c16e0 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
@@ -18,21 +18,28 @@
package org.apache.beam.sdk.extensions.python.transforms;
import java.util.Map;
+import java.util.Objects;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.PythonCallableSource;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Wrapper for invoking external Python {@code RunInference}. @Experimental */
+public class RunInference<OutputT> extends PTransform<PCollection<?>, PCollection<OutputT>> {
-/** Wrapper for invoking external Python RunInference. */
-public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
private final String modelLoader;
private final Schema schema;
private final Map<String, Object> kwargs;
private final String expansionService;
+ // Key coder for keyed (KV) inputs and outputs; null when the transform works on plain Rows.
+ // expand() picks RowCoder or KvCoder for the output depending on whether this is null.
+ private final @Nullable Coder<?> keyCoder;
/**
* Instantiates a multi-language wrapper for a Python RunInference with a given model loader.
@@ -42,12 +49,39 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
* @param inferenceType A schema field type for the inference column in output rows.
* @return A {@link RunInference} for the given model loader.
*/
- public static RunInference of(
+ public static RunInference<Row> of(
String modelLoader, Schema.FieldType exampleType, Schema.FieldType inferenceType) {
+ // Output rows have an "example" column and an "inference" column of the given types.
Schema schema =
Schema.of(
Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
- return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
+ // A null key coder selects the unkeyed path: expand() encodes outputs with RowCoder.
+ return new RunInference<>(modelLoader, schema, ImmutableMap.of(), null, "");
+ }
+
+ /**
+ * Similar to {@link RunInference#of(String, FieldType, FieldType)} but the input is a {@link
+ * PCollection} of {@link KV}s.
+ *
+ * <p>Also outputs a {@link PCollection} of {@link KV}s of the same key type.
+ *
+ * <p>For example, use this if you are using Python {@code KeyedModelHandler} as the model
+ * handler.
+ *
+ * @param modelLoader A Python callable for a model loader class object.
+ * @param exampleType A schema field type for the example column in output rows.
+ * @param inferenceType A schema field type for the inference column in output rows.
+ * @param keyCoder a {@link Coder} for the input and output key type.
+ * @param <KeyT> input and output key type. Inferred from the provided coder.
+ * @return A {@link RunInference} for the given model loader.
+ */
+ public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(
+ String modelLoader,
+ Schema.FieldType exampleType,
+ Schema.FieldType inferenceType,
+ Coder<KeyT> keyCoder) {
+ // Build the same two-column output schema as of(String, FieldType, FieldType), then reuse
+ // the keyed construction path of ofKVs(String, Schema, Coder) instead of duplicating it.
+ Schema schema =
+ Schema.of(
+ Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
+ return ofKVs(modelLoader, schema, keyCoder);
}
/**
@@ -57,8 +91,23 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
* @param schema A schema for output rows.
* @return A {@link RunInference} for the given model loader.
*/
- public static RunInference of(String modelLoader, Schema schema) {
- return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
+ public static RunInference<Row> of(String modelLoader, Schema schema) {
+ // Unkeyed variant: no key coder, so expand() encodes outputs with RowCoder.
+ return new RunInference<>(modelLoader, schema, ImmutableMap.of(), null, "");
+ }
+
+ /**
+ * Similar to {@link RunInference#of(String, Schema)} but the input is a {@link PCollection} of
+ * {@link KV}s, and the output is a {@link PCollection} of {@link KV}s of the same key type.
+ *
+ * @param modelLoader A Python callable for a model loader class object.
+ * @param schema A schema for output rows.
+ * @param keyCoder a {@link Coder} for the input and output key type; must not be null.
+ * @param <KeyT> input and output key type. Inferred from the provided coder.
+ * @return A {@link RunInference} for the given model loader.
+ * @throws NullPointerException if {@code keyCoder} is null.
+ */
+ public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(
+ String modelLoader, Schema schema, Coder<KeyT> keyCoder) {
+ // A null key coder would make expand() fall back to a plain RowCoder even though the
+ // declared output type is KV<KeyT, Row>; fail fast here instead of at pipeline run time.
+ return new RunInference<>(
+ modelLoader, schema, ImmutableMap.of(), Objects.requireNonNull(keyCoder, "keyCoder"), "");
}
/**
@@ -66,10 +115,10 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
*
* @return A {@link RunInference} with keyword arguments.
*/
- public RunInference withKwarg(String key, Object arg) {
+ public RunInference<OutputT> withKwarg(String key, Object arg) {
+ // NOTE(review): ImmutableMap.Builder rejects duplicate keys at build(), so setting the same
+ // kwarg twice fails here rather than overriding the earlier value — confirm that is intended.
ImmutableMap.Builder<String, Object> builder =
ImmutableMap.<String, Object>builder().putAll(kwargs).put(key, arg);
- return new RunInference(modelLoader, schema, builder.build(), expansionService);
+ // Returns a new instance; keyCoder is carried over so the keyed/unkeyed output type is kept.
+ return new RunInference<>(modelLoader, schema, builder.build(), keyCoder, expansionService);
}
/**
@@ -78,25 +127,38 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
* @param expansionService A URL for a Python expansion service.
* @return A {@link RunInference} for the given expansion service endpoint.
*/
- public RunInference withExpansionService(String expansionService) {
- return new RunInference(modelLoader, schema, kwargs, expansionService);
+ public RunInference<OutputT> withExpansionService(String expansionService) {
+ // Returns a new instance with every other setting, including keyCoder, carried over.
+ return new RunInference<>(modelLoader, schema, kwargs, keyCoder, expansionService);
}
+ /**
+ * Instances are created by the {@code of}/{@code ofKVs} factories and copied by the {@code with*}
+ * methods. {@code keyCoder} is null for the unkeyed (Row output) variants.
+ */
private RunInference(
- String modelLoader, Schema schema, Map<String, Object> kwargs, String expansionService) {
+ String modelLoader,
+ Schema schema,
+ Map<String, Object> kwargs,
+ @Nullable Coder<?> keyCoder,
+ String expansionService) {
this.modelLoader = modelLoader;
this.schema = schema;
this.kwargs = kwargs;
+ this.keyCoder = keyCoder;
this.expansionService = expansionService;
}
@Override
- public PCollection<Row> expand(PCollection<?> input) {
- return input.apply(
- PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(
- "apache_beam.ml.inference.base.RunInference.from_callable", expansionService)
- .withKwarg("model_handler_provider", PythonCallableSource.of(modelLoader))
- .withKwargs(kwargs)
- .withOutputCoder(RowCoder.of(schema)));
+ @SuppressWarnings("unchecked")
+ public PCollection<OutputT> expand(PCollection<?> input) {
+ // The unchecked casts are safe by construction: of(...) pairs a null keyCoder with
+ // OutputT = Row, and ofKVs(...) passes its Coder<KeyT> for OutputT = KV<KeyT, Row>.
+ Coder<OutputT> outputCoder;
+ if (this.keyCoder == null) {
+ outputCoder = (Coder<OutputT>) RowCoder.of(schema);
+ } else {
+ outputCoder = (Coder<OutputT>) KvCoder.of(keyCoder, RowCoder.of(schema));
+ }
+
+ // Type the external transform with the real output type so its result needs no further
+ // (unchecked) cast from PCollection<Row>.
+ return input.apply(
+ PythonExternalTransform.<PCollection<?>, PCollection<OutputT>>from(
+ "apache_beam.ml.inference.base.RunInference.from_callable", expansionService)
+ .withKwarg("model_handler_provider", PythonCallableSource.of(modelLoader))
+ .withOutputCoder(outputCoder)
+ .withKwargs(kwargs));
}
}
diff --git a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
index 2e875de347e..5f7b80f621e 100644
--- a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
+++ b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
@@ -18,15 +18,21 @@
package org.apache.beam.sdk.extensions.python.transforms;
import java.util.Arrays;
+import java.util.List;
import java.util.Optional;
import org.apache.beam.runners.core.construction.BaseExternalTest;
import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.junit.Test;
@@ -65,4 +71,52 @@ public class RunInferenceTransformTest extends BaseExternalTest {
.withExpansionService(expansionAddr));
PAssert.that(col).containsInAnyOrder(row0, row1);
}
+
+ private String getModelLoaderScriptWithKVs() {
+ // Python source defining get_model_handler(model_uri); the sklearn handler is wrapped in a
+ // KeyedModelHandler, the handler meant for keyed (KV) inputs.
+ return new StringBuilder()
+ .append("from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n")
+ .append("from apache_beam.ml.inference.base import KeyedModelHandler\n")
+ .append("def get_model_handler(model_uri):\n")
+ .append(" return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n")
+ .toString();
+ }
+
+ /** Keys each example by its first element so the keyed RunInference path can be exercised. */
+ static class KVFn extends SimpleFunction<Iterable<Long>, KV<Long, Iterable<Long>>> {
+ @Override
+ public KV<Long, Iterable<Long>> apply(Iterable<Long> input) {
+ // Read the first element through the Iterable API rather than a raw (List) cast, which is
+ // unchecked and throws ClassCastException if the runner hands over a non-List Iterable.
+ Long key = input.iterator().next();
+ return KV.of(key, input);
+ }
+ }
+
+ /**
+ * Runs the keyed {@link RunInference#ofKVs} path through the Python expansion service and
+ * checks the inferred rows after dropping the keys.
+ */
+ @Test
+ @Category({ValidatesRunner.class, UsesPythonExpansionService.class})
+ public void testRunInferenceWithKVs() {
+ String stagingLocation =
+ Optional.ofNullable(System.getProperty("semiPersistDir")).orElse("/tmp");
+ Schema schema =
+ Schema.of(
+ Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
+ Schema.Field.of("inference", Schema.FieldType.INT32));
+ Row row0 = Row.withSchema(schema).addArray(0L, 0L).addValue(0).build();
+ Row row1 = Row.withSchema(schema).addArray(1L, 1L).addValue(1).build();
+ PCollection<Row> col =
+ testPipeline
+ .apply(Create.<Iterable<Long>>of(Arrays.asList(0L, 0L), Arrays.asList(1L, 1L)))
+ .apply(MapElements.via(new KVFn()))
+ .setCoder(KvCoder.of(VarLongCoder.of(), IterableCoder.of(VarLongCoder.of())))
+ .apply(
+ RunInference.ofKVs(getModelLoaderScriptWithKVs(), schema, VarLongCoder.of())
+ .withKwarg(
+ // The test expansion service creates the test model and saves it to the
+ // returning external environment as a dependency.
+ // (sdks/python/apache_beam/runners/portability/expansion_service_test.py)
+ // The dependencies for Python SDK harness are supposed to be staged to
+ // $SEMI_PERSIST_DIR/staged directory.
+ "model_uri", String.format("%s/staged/sklearn_model", stagingLocation))
+ .withExpansionService(expansionAddr))
+ .apply(Values.<Row>create());
+
+ // NOTE(review): Values.create() drops the keys, so key propagation itself is not asserted;
+ // consider asserting KV.of(0L, row0) and KV.of(1L, row1) on the keyed output directly.
+ PAssert.that(col).containsInAnyOrder(row0, row1);
+ }
}