You are viewing a plain text version of this content. The canonical (HTML) version is linked from the original mailing-list archive page.
Posted to commits@beam.apache.org by he...@apache.org on 2022/04/11 22:23:51 UTC

[beam] branch master updated: [BEAM-14233] Merge requirements from expanded response for Java External transform

This is an automated email from the ASF dual-hosted git repository.

heejong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 61af8f42a9d [BEAM-14233] Merge requirements from expanded response for Java External transform
     new 2452e2eccc3 Merge pull request #17248 from ihji/BEAM-14233
61af8f42a9d is described below

commit 61af8f42a9d4d38585d40873939373e40d1b35c0
Author: Heejong Lee <he...@gmail.com>
AuthorDate: Fri Apr 1 10:46:51 2022 -0700

    [BEAM-14233] Merge requirements from expanded response for Java External transform
---
 .../beam/runners/core/construction/External.java   |  38 ++++++--
 .../core/construction/ExternalTranslation.java     |   6 ++
 .../core/construction/ExternalTranslationTest.java | 105 +++++++++++++++++++++
 3 files changed, 143 insertions(+), 6 deletions(-)

diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
index 8d264b129f5..ec9f5fcb8c9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java
@@ -54,6 +54,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannel;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.ManagedChannelBuilder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -91,7 +92,17 @@ public class External {
           String urn, byte[] payload, String endpoint) {
     Endpoints.ApiServiceDescriptor apiDesc =
         Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
-    return new SingleOutputExpandableTransform<>(urn, payload, apiDesc, getFreshNamespaceIndex());
+    return new SingleOutputExpandableTransform<>(
+        urn, payload, apiDesc, DEFAULT, getFreshNamespaceIndex());
+  }
+
+  @VisibleForTesting
+  static <InputT extends PInput, OutputT> SingleOutputExpandableTransform<InputT, OutputT> of(
+      String urn, byte[] payload, String endpoint, ExpansionServiceClientFactory clientFactory) {
+    Endpoints.ApiServiceDescriptor apiDesc =
+        Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
+    return new SingleOutputExpandableTransform<>(
+        urn, payload, apiDesc, clientFactory, getFreshNamespaceIndex());
   }
 
   /** Expandable transform for output type of PCollection. */
@@ -101,8 +112,9 @@ public class External {
         String urn,
         byte[] payload,
         Endpoints.ApiServiceDescriptor endpoint,
+        ExpansionServiceClientFactory clientFactory,
         Integer namespaceIndex) {
-      super(urn, payload, endpoint, namespaceIndex);
+      super(urn, payload, endpoint, clientFactory, namespaceIndex);
     }
 
     @Override
@@ -113,12 +125,12 @@ public class External {
 
     public MultiOutputExpandableTransform<InputT> withMultiOutputs() {
       return new MultiOutputExpandableTransform<>(
-          getUrn(), getPayload(), getEndpoint(), getNamespaceIndex());
+          getUrn(), getPayload(), getEndpoint(), getClientFactory(), getNamespaceIndex());
     }
 
     public <T> SingleOutputExpandableTransform<InputT, T> withOutputType() {
       return new SingleOutputExpandableTransform<>(
-          getUrn(), getPayload(), getEndpoint(), getNamespaceIndex());
+          getUrn(), getPayload(), getEndpoint(), getClientFactory(), getNamespaceIndex());
     }
   }
 
@@ -129,8 +141,9 @@ public class External {
         String urn,
         byte[] payload,
         Endpoints.ApiServiceDescriptor endpoint,
+        ExpansionServiceClientFactory clientFactory,
         Integer namespaceIndex) {
-      super(urn, payload, endpoint, namespaceIndex);
+      super(urn, payload, endpoint, clientFactory, namespaceIndex);
     }
 
     @Override
@@ -151,10 +164,12 @@ public class External {
     private final String urn;
     private final byte[] payload;
     private final Endpoints.ApiServiceDescriptor endpoint;
+    private final ExpansionServiceClientFactory clientFactory;
     private final Integer namespaceIndex;
 
     private transient RunnerApi.@Nullable Components expandedComponents;
     private transient RunnerApi.@Nullable PTransform expandedTransform;
+    private transient @Nullable List<String> expandedRequirements;
     private transient @Nullable Map<PCollection, String> externalPCollectionIdMap;
     private transient @Nullable Map<Coder, String> externalCoderIdMap;
 
@@ -162,10 +177,12 @@ public class External {
         String urn,
         byte[] payload,
         Endpoints.ApiServiceDescriptor endpoint,
+        ExpansionServiceClientFactory clientFactory,
         Integer namespaceIndex) {
       this.urn = urn;
       this.payload = payload;
       this.endpoint = endpoint;
+      this.clientFactory = clientFactory;
       this.namespaceIndex = namespaceIndex;
     }
 
@@ -215,7 +232,7 @@ public class External {
               .build();
 
       ExpansionApi.ExpansionResponse response =
-          DEFAULT.getExpansionServiceClient(endpoint).expand(request);
+          clientFactory.getExpansionServiceClient(endpoint).expand(request);
 
       if (!Strings.isNullOrEmpty(response.getError())) {
         throw new RuntimeException(
@@ -224,6 +241,7 @@ public class External {
 
       expandedComponents = resolveArtifacts(response.getComponents());
       expandedTransform = response.getTransform();
+      expandedRequirements = response.getRequirementsList();
 
       RehydratedComponents rehydratedComponents =
           RehydratedComponents.forComponents(expandedComponents).withPipeline(p);
@@ -369,6 +387,10 @@ public class External {
       return expandedComponents;
     }
 
+    List<String> getExpandedRequirements() {
+      return expandedRequirements;
+    }
+
     Map<PCollection, String> getExternalPCollectionIdMap() {
       return externalPCollectionIdMap;
     }
@@ -389,6 +411,10 @@ public class External {
       return endpoint;
     }
 
+    ExpansionServiceClientFactory getClientFactory() {
+      return clientFactory;
+    }
+
     Integer getNamespaceIndex() {
       return namespaceIndex;
     }
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
index 32258c1bde0..3c51515392c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ExternalTranslation.java
@@ -70,6 +70,12 @@ public class ExternalTranslation {
       String impulsePrefix = expandableTransform.getImpulsePrefix();
       RunnerApi.PTransform expandedTransform = expandableTransform.getExpandedTransform();
       RunnerApi.Components expandedComponents = expandableTransform.getExpandedComponents();
+      List<String> expandedRequirements = expandableTransform.getExpandedRequirements();
+
+      for (String requirement : expandedRequirements) {
+        components.addRequirement(requirement);
+      }
+
       Map<PCollection, String> externalPCollectionIdMap =
           expandableTransform.getExternalPCollectionIdMap();
       Map<Coder, String> externalCoderIdMap = expandableTransform.getExternalCoderIdMap();
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTranslationTest.java
new file mode 100644
index 00000000000..24f7e1cdbe3
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ExternalTranslationTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItems;
+
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link org.apache.beam.runners.core.construction.ExternalTranslation}. */
+@RunWith(JUnit4.class)
+public class ExternalTranslationTest {
+  @Test
+  public void testTranslation() {
+    Pipeline p = TestPipeline.create();
+    TestExpansionServiceClientFactory clientFactory = new TestExpansionServiceClientFactory();
+    p.apply(External.of("", new byte[] {}, "", clientFactory));
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    assertThat(
+        pipelineProto.getRequirementsList(), equalTo(clientFactory.response.getRequirementsList()));
+    assertThat(
+        pipelineProto.getComponents().getPcollectionsMap().keySet(),
+        equalTo(clientFactory.response.getComponents().getPcollectionsMap().keySet()));
+    assertThat(
+        pipelineProto.getComponents().getTransformsMap().keySet(),
+        hasItems(
+            clientFactory
+                .response
+                .getComponents()
+                .getTransformsMap()
+                .keySet()
+                .toArray(new String[0])));
+  }
+
+  static class TestExpansionServiceClientFactory implements ExpansionServiceClientFactory {
+    ExpansionApi.ExpansionResponse response;
+
+    @Override
+    public ExpansionServiceClient getExpansionServiceClient(
+        Endpoints.ApiServiceDescriptor endpoint) {
+      return new ExpansionServiceClient() {
+        @Override
+        public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request) {
+          Pipeline p = TestPipeline.create();
+          p.apply(Create.of(1, 2, 3));
+          SdkComponents sdkComponents =
+              SdkComponents.create(p.getOptions()).withNewIdPrefix(request.getNamespace());
+          RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p, sdkComponents);
+          String transformId = Iterables.getOnlyElement(pipelineProto.getRootTransformIdsList());
+          RunnerApi.Components components = pipelineProto.getComponents();
+          ImmutableList.Builder<String> requirementsBuilder = ImmutableList.builder();
+          requirementsBuilder.addAll(pipelineProto.getRequirementsList());
+          requirementsBuilder.add("ExternalTranslationTest_Requirement_URN");
+          response =
+              ExpansionApi.ExpansionResponse.newBuilder()
+                  .setComponents(components)
+                  .setTransform(
+                      components
+                          .getTransformsOrThrow(transformId)
+                          .toBuilder()
+                          .setUniqueName(transformId))
+                  .addAllRequirements(requirementsBuilder.build())
+                  .build();
+          return response;
+        }
+
+        @Override
+        public void close() throws Exception {
+          // do nothing
+        }
+      };
+    }
+
+    @Override
+    public void close() throws Exception {
+      // do nothing
+    }
+  }
+}