You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by he...@apache.org on 2022/07/26 19:28:58 UTC

[beam] branch master updated: Adds KV support for the Java RunInference transform.

This is an automated email from the ASF dual-hosted git repository.

heejong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 06d6d58102d Adds KV support for the Java RunInference transform.
     new deb72620e02 Merge pull request #22442 from chamikaramj/runinference_kv_support
06d6d58102d is described below

commit 06d6d58102d88da4cd86492de8f2f9a7657b9d0e
Author: Chamikara Jayalath <ch...@gmail.com>
AuthorDate: Mon Jul 25 21:52:32 2022 -0700

    Adds KV support for the Java RunInference transform.
---
 .../python/transforms/DataframeTransform.java      |  2 +-
 .../extensions/python/transforms/PythonMap.java    |  1 +
 .../extensions/python/transforms/RunInference.java | 98 ++++++++++++++++++----
 .../transforms/RunInferenceTransformTest.java      | 54 ++++++++++++
 4 files changed, 136 insertions(+), 19 deletions(-)

diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
index 1c956695e15..720adbd29a6 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/DataframeTransform.java
@@ -23,7 +23,7 @@ import org.apache.beam.sdk.util.PythonCallableSource;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 
-/** Wrapper for invoking external Python DataframeTransform. */
+/** Wrapper for invoking external Python {@code DataframeTransform}. @Experimental */
 public class DataframeTransform extends PTransform<PCollection<Row>, PCollection<Row>> {
   private final String func;
   private final boolean includeIndexes;
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
index e1eb9cbad19..d2e5cb5642a 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/PythonMap.java
@@ -24,6 +24,7 @@ import org.apache.beam.sdk.util.PythonCallableSource;
 import org.apache.beam.sdk.values.PCollection;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
+/** Wrapper for invoking external Python {@code Map} transforms. @Experimental */
 public class PythonMap<InputT, OutputT>
     extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
 
diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
index 209c7061d23..ec4191c16e0 100644
--- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
+++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
@@ -18,21 +18,28 @@
 package org.apache.beam.sdk.extensions.python.transforms;
 
 import java.util.Map;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.PythonCallableSource;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Wrapper for invoking external Python {@code RunInference}. @Experimental */
+public class RunInference<OutputT> extends PTransform<PCollection<?>, PCollection<OutputT>> {
 
-/** Wrapper for invoking external Python RunInference. */
-public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
   private final String modelLoader;
   private final Schema schema;
   private final Map<String, Object> kwargs;
   private final String expansionService;
+  private final @Nullable Coder<?> keyCoder;
 
   /**
    * Instantiates a multi-language wrapper for a Python RunInference with a given model loader.
@@ -42,12 +49,39 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
    * @param inferenceType A schema field type for the inference column in output rows.
    * @return A {@link RunInference} for the given model loader.
    */
-  public static RunInference of(
+  public static RunInference<Row> of(
       String modelLoader, Schema.FieldType exampleType, Schema.FieldType inferenceType) {
     Schema schema =
         Schema.of(
             Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
-    return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
+    return new RunInference<>(modelLoader, schema, ImmutableMap.of(), null, "");
+  }
+
+  /**
+   * Similar to {@link RunInference#of(String, FieldType, FieldType)} but the input is a {@link
+   * PCollection} of {@link KV}s.
+   *
+   * <p>Also outputs a {@link PCollection} of {@link KV}s of the same key type.
+   *
+   * <p>For example, use this if you are using Python {@code KeyedModelHandler} as the model
+   * handler.
+   *
+   * @param modelLoader A Python callable for a model loader class object.
+   * @param exampleType A schema field type for the example column in output rows.
+   * @param inferenceType A schema field type for the inference column in output rows.
+   * @param keyCoder a {@link Coder} for the input and output Key type.
+   * @param <KeyT> input and output Key type. Inferred by the provided coder.
+   * @return A {@link RunInference} for the given model loader.
+   */
+  public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(
+      String modelLoader,
+      Schema.FieldType exampleType,
+      Schema.FieldType inferenceType,
+      Coder<KeyT> keyCoder) {
+    Schema schema =
+        Schema.of(
+            Schema.Field.of("example", exampleType), Schema.Field.of("inference", inferenceType));
+    return new RunInference<>(modelLoader, schema, ImmutableMap.of(), keyCoder, "");
   }
 
   /**
@@ -57,8 +91,23 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
    * @param schema A schema for output rows.
    * @return A {@link RunInference} for the given model loader.
    */
-  public static RunInference of(String modelLoader, Schema schema) {
-    return new RunInference(modelLoader, schema, ImmutableMap.of(), "");
+  public static RunInference<Row> of(String modelLoader, Schema schema) {
+    return new RunInference<>(modelLoader, schema, ImmutableMap.of(), null, "");
+  }
+
+  /**
+   * Similar to {@link RunInference#of(String, Schema)} but the input is a {@link PCollection} of
+   * {@link KV}s.
+   *
+   * @param modelLoader A Python callable for a model loader class object.
+   * @param schema A schema for output rows.
+   * @param keyCoder a {@link Coder} for the input and output Key type.
+   * @param <KeyT> input and output Key type. Inferred by the provided coder.
+   * @return A {@link RunInference} for the given model loader.
+   */
+  public static <KeyT> RunInference<KV<KeyT, Row>> ofKVs(
+      String modelLoader, Schema schema, Coder<KeyT> keyCoder) {
+    return new RunInference<>(modelLoader, schema, ImmutableMap.of(), keyCoder, "");
   }
 
   /**
@@ -66,10 +115,10 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
    *
    * @return A {@link RunInference} with keyword arguments.
    */
-  public RunInference withKwarg(String key, Object arg) {
+  public RunInference<OutputT> withKwarg(String key, Object arg) {
     ImmutableMap.Builder<String, Object> builder =
         ImmutableMap.<String, Object>builder().putAll(kwargs).put(key, arg);
-    return new RunInference(modelLoader, schema, builder.build(), expansionService);
+    return new RunInference<>(modelLoader, schema, builder.build(), keyCoder, expansionService);
   }
 
   /**
@@ -78,25 +127,38 @@ public class RunInference extends PTransform<PCollection<?>, PCollection<Row>> {
    * @param expansionService A URL for a Python expansion service.
    * @return A {@link RunInference} for the given expansion service endpoint.
    */
-  public RunInference withExpansionService(String expansionService) {
-    return new RunInference(modelLoader, schema, kwargs, expansionService);
+  public RunInference<OutputT> withExpansionService(String expansionService) {
+    return new RunInference<>(modelLoader, schema, kwargs, keyCoder, expansionService);
   }
 
   private RunInference(
-      String modelLoader, Schema schema, Map<String, Object> kwargs, String expansionService) {
+      String modelLoader,
+      Schema schema,
+      Map<String, Object> kwargs,
+      @Nullable Coder<?> keyCoder,
+      String expansionService) {
     this.modelLoader = modelLoader;
     this.schema = schema;
     this.kwargs = kwargs;
+    this.keyCoder = keyCoder;
     this.expansionService = expansionService;
   }
 
   @Override
-  public PCollection<Row> expand(PCollection<?> input) {
-    return input.apply(
-        PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(
-                "apache_beam.ml.inference.base.RunInference.from_callable", expansionService)
-            .withKwarg("model_handler_provider", PythonCallableSource.of(modelLoader))
-            .withKwargs(kwargs)
-            .withOutputCoder(RowCoder.of(schema)));
+  public PCollection<OutputT> expand(PCollection<?> input) {
+    Coder<OutputT> outputCoder;
+    if (this.keyCoder == null) {
+      outputCoder = (Coder<OutputT>) RowCoder.of(schema);
+    } else {
+      outputCoder = (Coder<OutputT>) KvCoder.of(keyCoder, RowCoder.of(schema));
+    }
+
+    return (PCollection<OutputT>)
+        input.apply(
+            PythonExternalTransform.<PCollection<?>, PCollection<Row>>from(
+                    "apache_beam.ml.inference.base.RunInference.from_callable", expansionService)
+                .withKwarg("model_handler_provider", PythonCallableSource.of(modelLoader))
+                .withOutputCoder(outputCoder)
+                .withKwargs(kwargs));
   }
 }
diff --git a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
index 2e875de347e..5f7b80f621e 100644
--- a/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
+++ b/sdks/java/extensions/python/src/test/java/org/apache/beam/sdk/extensions/python/transforms/RunInferenceTransformTest.java
@@ -18,15 +18,21 @@
 package org.apache.beam.sdk.extensions.python.transforms;
 
 import java.util.Arrays;
+import java.util.List;
 import java.util.Optional;
 import org.apache.beam.runners.core.construction.BaseExternalTest;
 import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.UsesPythonExpansionService;
 import org.apache.beam.sdk.testing.ValidatesRunner;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.junit.Test;
@@ -65,4 +71,52 @@ public class RunInferenceTransformTest extends BaseExternalTest {
                     .withExpansionService(expansionAddr));
     PAssert.that(col).containsInAnyOrder(row0, row1);
   }
+
+  private String getModelLoaderScriptWithKVs() {
+    String s = "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n";
+    s = s + "from apache_beam.ml.inference.base import KeyedModelHandler\n";
+    s = s + "def get_model_handler(model_uri):\n";
+    s = s + "  return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n";
+
+    return s;
+  }
+
+  static class KVFn extends SimpleFunction<Iterable<Long>, KV<Long, Iterable<Long>>> {
+    @Override
+    public KV<Long, Iterable<Long>> apply(Iterable<Long> input) {
+      Long key = (Long) ((List) input).get(0);
+      return KV.of(key, input);
+    }
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesPythonExpansionService.class})
+  public void testRunInferenceWithKVs() {
+    String stagingLocation =
+        Optional.ofNullable(System.getProperty("semiPersistDir")).orElse("/tmp");
+    Schema schema =
+        Schema.of(
+            Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
+            Schema.Field.of("inference", Schema.FieldType.INT32));
+    Row row0 = Row.withSchema(schema).addArray(0L, 0L).addValue(0).build();
+    Row row1 = Row.withSchema(schema).addArray(1L, 1L).addValue(1).build();
+    PCollection<Row> col =
+        testPipeline
+            .apply(Create.<Iterable<Long>>of(Arrays.asList(0L, 0L), Arrays.asList(1L, 1L)))
+            .apply(MapElements.via(new KVFn()))
+            .setCoder(KvCoder.of(VarLongCoder.of(), IterableCoder.of(VarLongCoder.of())))
+            .apply(
+                RunInference.ofKVs(getModelLoaderScriptWithKVs(), schema, VarLongCoder.of())
+                    .withKwarg(
+                        // The test expansion service creates the test model and saves it to the
+                        // returning external environment as a dependency.
+                        // (sdks/python/apache_beam/runners/portability/expansion_service_test.py)
+                        // The dependencies for Python SDK harness are supposed to be staged to
+                        // $SEMI_PERSIST_DIR/staged directory.
+                        "model_uri", String.format("%s/staged/sklearn_model", stagingLocation))
+                    .withExpansionService(expansionAddr))
+            .apply(Values.<Row>create());
+
+    PAssert.that(col).containsInAnyOrder(row0, row1);
+  }
 }