You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ce...@apache.org on 2023/02/28 16:23:28 UTC

[kafka] branch trunk updated: KAFKA-14671: Refactor PredicatedTransformation to not implement Transformation (#13184)

This is an automated email from the ASF dual-hosted git repository.

cegerton pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new f586fa59d3f KAFKA-14671: Refactor PredicatedTransformation to not implement Transformation (#13184)
f586fa59d3f is described below

commit f586fa59d3f938e04bda4e8143ddb1c4310eaf78
Author: Greg Harris <gr...@aiven.io>
AuthorDate: Tue Feb 28 08:23:19 2023 -0800

    KAFKA-14671: Refactor PredicatedTransformation to not implement Transformation (#13184)
    
    Reviewers: Christo Lolov <ch...@gmail.com>, Yash Mayya <ya...@gmail.com>, Chris Egerton <ch...@aiven.io>
---
 .../kafka/connect/runtime/ConnectorConfig.java     | 28 +++++-----
 .../kafka/connect/runtime/TransformationChain.java | 27 +++++-----
 ...ransformation.java => TransformationStage.java} | 49 ++++++++----------
 .../org/apache/kafka/connect/runtime/Worker.java   |  6 +--
 .../rest/resources/ConnectorPluginsResource.java   |  9 +---
 .../kafka/connect/runtime/ConnectorConfigTest.java | 60 +++++++++++-----------
 .../connect/runtime/ErrorHandlingTaskTest.java     |  4 +-
 ...ationTest.java => TransformationStageTest.java} | 29 +++++++----
 .../resources/ConnectorPluginsResourceTest.java    |  3 +-
 .../kafka/connect/util/TopicCreationTest.java      | 16 +++---
 10 files changed, 113 insertions(+), 118 deletions(-)

diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java
index 485dda9b98b..40b7c0a1462 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java
@@ -268,12 +268,14 @@ public class ConnectorConfig extends AbstractConfig {
     }
 
     /**
-     * Returns the initialized list of {@link Transformation} which are specified in {@link #TRANSFORMS_CONFIG}.
+     * Returns the initialized list of {@link TransformationStage} which apply the
+     * {@link Transformation transformations} and {@link Predicate predicates}
+     * as they are specified in the {@link #TRANSFORMS_CONFIG} and {@link #PREDICATES_CONFIG}.
      */
-    public <R extends ConnectRecord<R>> List<Transformation<R>> transformations() {
+    public <R extends ConnectRecord<R>> List<TransformationStage<R>> transformationStages() {
         final List<String> transformAliases = getList(TRANSFORMS_CONFIG);
 
-        final List<Transformation<R>> transformations = new ArrayList<>(transformAliases.size());
+        final List<TransformationStage<R>> transformations = new ArrayList<>(transformAliases.size());
         for (String alias : transformAliases) {
             final String prefix = TRANSFORMS_CONFIG + "." + alias + ".";
 
@@ -281,17 +283,17 @@ public class ConnectorConfig extends AbstractConfig {
                 @SuppressWarnings("unchecked")
                 final Transformation<R> transformation = Utils.newInstance(getClass(prefix + "type"), Transformation.class);
                 Map<String, Object> configs = originalsWithPrefix(prefix);
-                Object predicateAlias = configs.remove(PredicatedTransformation.PREDICATE_CONFIG);
-                Object negate = configs.remove(PredicatedTransformation.NEGATE_CONFIG);
+                Object predicateAlias = configs.remove(TransformationStage.PREDICATE_CONFIG);
+                Object negate = configs.remove(TransformationStage.NEGATE_CONFIG);
                 transformation.configure(configs);
                 if (predicateAlias != null) {
                     String predicatePrefix = PREDICATES_PREFIX + predicateAlias + ".";
                     @SuppressWarnings("unchecked")
                     Predicate<R> predicate = Utils.newInstance(getClass(predicatePrefix + "type"), Predicate.class);
                     predicate.configure(originalsWithPrefix(predicatePrefix));
-                    transformations.add(new PredicatedTransformation<>(predicate, negate == null ? false : Boolean.parseBoolean(negate.toString()), transformation));
+                    transformations.add(new TransformationStage<>(predicate, negate == null ? false : Boolean.parseBoolean(negate.toString()), transformation));
                 } else {
-                    transformations.add(transformation);
+                    transformations.add(new TransformationStage<>(transformation));
                 }
             } catch (Exception e) {
                 throw new ConnectException(e);
@@ -321,9 +323,9 @@ public class ConnectorConfig extends AbstractConfig {
             protected ConfigDef initialConfigDef() {
                 // All Transformations get these config parameters implicitly
                 return super.initialConfigDef()
-                        .define(PredicatedTransformation.PREDICATE_CONFIG, Type.STRING, "", Importance.MEDIUM,
+                        .define(TransformationStage.PREDICATE_CONFIG, Type.STRING, null, Importance.MEDIUM,
                                 "The alias of a predicate used to determine whether to apply this transformation.")
-                        .define(PredicatedTransformation.NEGATE_CONFIG, Type.BOOLEAN, false, Importance.MEDIUM,
+                        .define(TransformationStage.NEGATE_CONFIG, Type.BOOLEAN, false, Importance.MEDIUM,
                                 "Whether the configured predicate should be negated.");
             }
 
@@ -332,8 +334,8 @@ public class ConnectorConfig extends AbstractConfig {
                 return super.configDefsForClass(typeConfig)
                     .filter(entry -> {
                         // The implicit parameters mask any from the transformer with the same name
-                        if (PredicatedTransformation.PREDICATE_CONFIG.equals(entry.getKey())
-                                || PredicatedTransformation.NEGATE_CONFIG.equals(entry.getKey())) {
+                        if (TransformationStage.PREDICATE_CONFIG.equals(entry.getKey())
+                                || TransformationStage.NEGATE_CONFIG.equals(entry.getKey())) {
                             log.warn("Transformer config {} is masked by implicit config of that name",
                                     entry.getKey());
                             return false;
@@ -350,8 +352,8 @@ public class ConnectorConfig extends AbstractConfig {
 
             @Override
             protected void validateProps(String prefix) {
-                String prefixedNegate = prefix + PredicatedTransformation.NEGATE_CONFIG;
-                String prefixedPredicate = prefix + PredicatedTransformation.PREDICATE_CONFIG;
+                String prefixedNegate = prefix + TransformationStage.NEGATE_CONFIG;
+                String prefixedPredicate = prefix + TransformationStage.PREDICATE_CONFIG;
                 if (props.containsKey(prefixedNegate) &&
                         !props.containsKey(prefixedPredicate)) {
                     throw new ConfigException("Config '" + prefixedNegate + "' was provided " +
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java
index 984c1422572..b130a226f5f 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java
@@ -19,7 +19,6 @@ package org.apache.kafka.connect.runtime;
 import org.apache.kafka.connect.connector.ConnectRecord;
 import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator;
 import org.apache.kafka.connect.runtime.errors.Stage;
-import org.apache.kafka.connect.transforms.Transformation;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -34,24 +33,24 @@ import java.util.StringJoiner;
 public class TransformationChain<R extends ConnectRecord<R>> implements AutoCloseable {
     private static final Logger log = LoggerFactory.getLogger(TransformationChain.class);
 
-    private final List<Transformation<R>> transformations;
+    private final List<TransformationStage<R>> transformationStages;
     private final RetryWithToleranceOperator retryWithToleranceOperator;
 
-    public TransformationChain(List<Transformation<R>> transformations, RetryWithToleranceOperator retryWithToleranceOperator) {
-        this.transformations = transformations;
+    public TransformationChain(List<TransformationStage<R>> transformationStages, RetryWithToleranceOperator retryWithToleranceOperator) {
+        this.transformationStages = transformationStages;
         this.retryWithToleranceOperator = retryWithToleranceOperator;
     }
 
     public R apply(R record) {
-        if (transformations.isEmpty()) return record;
+        if (transformationStages.isEmpty()) return record;
 
-        for (final Transformation<R> transformation : transformations) {
+        for (final TransformationStage<R> transformationStage : transformationStages) {
             final R current = record;
 
             log.trace("Applying transformation {} to {}",
-                transformation.getClass().getName(), record);
+                transformationStage.transformClass().getName(), record);
             // execute the operation
-            record = retryWithToleranceOperator.execute(() -> transformation.apply(current), Stage.TRANSFORMATION, transformation.getClass());
+            record = retryWithToleranceOperator.execute(() -> transformationStage.apply(current), Stage.TRANSFORMATION, transformationStage.transformClass());
 
             if (record == null) break;
         }
@@ -61,8 +60,8 @@ public class TransformationChain<R extends ConnectRecord<R>> implements AutoClos
 
     @Override
     public void close() {
-        for (Transformation<R> transformation : transformations) {
-            transformation.close();
+        for (TransformationStage<R> transformationStage : transformationStages) {
+            transformationStage.close();
         }
     }
 
@@ -71,18 +70,18 @@ public class TransformationChain<R extends ConnectRecord<R>> implements AutoClos
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         TransformationChain<?> that = (TransformationChain<?>) o;
-        return Objects.equals(transformations, that.transformations);
+        return Objects.equals(transformationStages, that.transformationStages);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(transformations);
+        return Objects.hash(transformationStages);
     }
 
     public String toString() {
         StringJoiner chain = new StringJoiner(", ", getClass().getName() + "{", "}");
-        for (Transformation<R> transformation : transformations) {
-            chain.add(transformation.getClass().getName());
+        for (TransformationStage<R> transformationStage : transformationStages) {
+            chain.add(transformationStage.transformClass().getName());
         }
         return chain.toString();
     }
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationStage.java
similarity index 56%
rename from connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java
rename to connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationStage.java
index 446db5b2f32..3831730ad8f 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationStage.java
@@ -16,65 +16,60 @@
  */
 package org.apache.kafka.connect.runtime;
 
-import java.util.Map;
 
-import org.apache.kafka.common.config.ConfigDef;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.connect.connector.ConnectRecord;
-import org.apache.kafka.connect.errors.ConnectException;
 import org.apache.kafka.connect.transforms.Transformation;
 import org.apache.kafka.connect.transforms.predicates.Predicate;
 
 /**
- * Decorator for a {@link Transformation} which applies the delegate only when a
- * {@link Predicate} is true (or false, according to {@code negate}).
+ * Wrapper for a {@link Transformation} and corresponding optional {@link Predicate}
+ * which applies the transformation when the {@link Predicate} is true (or false, according to {@code negate}).
+ * If no {@link Predicate} is provided, the transformation will be unconditionally applied.
  * @param <R> The type of record (must be an implementation of {@link ConnectRecord})
  */
-public class PredicatedTransformation<R extends ConnectRecord<R>> implements Transformation<R> {
+public class TransformationStage<R extends ConnectRecord<R>> implements AutoCloseable {
 
     static final String PREDICATE_CONFIG = "predicate";
     static final String NEGATE_CONFIG = "negate";
-    final Predicate<R> predicate;
-    final Transformation<R> delegate;
-    final boolean negate;
+    private final Predicate<R> predicate;
+    private final Transformation<R> transformation;
+    private final boolean negate;
 
-    PredicatedTransformation(Predicate<R> predicate, boolean negate, Transformation<R> delegate) {
+    TransformationStage(Transformation<R> transformation) {
+        this(null, false, transformation);
+    }
+
+    TransformationStage(Predicate<R> predicate, boolean negate, Transformation<R> transformation) {
         this.predicate = predicate;
         this.negate = negate;
-        this.delegate = delegate;
+        this.transformation = transformation;
     }
 
-    @Override
-    public void configure(Map<String, ?> configs) {
-        throw new ConnectException(PredicatedTransformation.class.getName() + ".configure() " +
-                "should never be called directly.");
+    public Class<? extends Transformation<R>> transformClass() {
+        @SuppressWarnings("unchecked")
+        Class<? extends Transformation<R>> transformClass = (Class<? extends Transformation<R>>) transformation.getClass();
+        return transformClass;
     }
 
-    @Override
     public R apply(R record) {
-        if (negate ^ predicate.test(record)) {
-            return delegate.apply(record);
+        if (predicate == null || negate ^ predicate.test(record)) {
+            return transformation.apply(record);
         }
         return record;
     }
 
-    @Override
-    public ConfigDef config() {
-        throw new ConnectException(PredicatedTransformation.class.getName() + ".config() " +
-                "should never be called directly.");
-    }
-
     @Override
     public void close() {
-        Utils.closeQuietly(delegate, "predicated transformation");
+        Utils.closeQuietly(transformation, "transformation");
         Utils.closeQuietly(predicate, "predicate");
     }
 
     @Override
     public String toString() {
-        return "PredicatedTransformation{" +
+        return "TransformationStage{" +
                 "predicate=" + predicate +
-                ", delegate=" + delegate +
+                ", transformation=" + transformation +
                 ", negate=" + negate +
                 '}';
     }
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
index f966ea4cb55..8f9f727e25e 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
@@ -1253,7 +1253,7 @@ public class Worker {
                            Class<? extends Connector> connectorClass,
                            RetryWithToleranceOperator retryWithToleranceOperator) {
 
-            TransformationChain<SinkRecord> transformationChain = new TransformationChain<>(connectorConfig.<SinkRecord>transformations(), retryWithToleranceOperator);
+            TransformationChain<SinkRecord> transformationChain = new TransformationChain<>(connectorConfig.<SinkRecord>transformationStages(), retryWithToleranceOperator);
             log.info("Initializing: {}", transformationChain);
             SinkConnectorConfig sinkConfig = new SinkConnectorConfig(plugins, connectorConfig.originalsStrings());
             retryWithToleranceOperator.reporters(sinkTaskReporters(id, sinkConfig, errorHandlingMetrics, connectorClass));
@@ -1297,7 +1297,7 @@ public class Worker {
             SourceConnectorConfig sourceConfig = new SourceConnectorConfig(plugins,
                     connectorConfig.originalsStrings(), config.topicCreationEnable());
             retryWithToleranceOperator.reporters(sourceTaskReporters(id, sourceConfig, errorHandlingMetrics));
-            TransformationChain<SourceRecord> transformationChain = new TransformationChain<>(sourceConfig.<SourceRecord>transformations(), retryWithToleranceOperator);
+            TransformationChain<SourceRecord> transformationChain = new TransformationChain<>(sourceConfig.<SourceRecord>transformationStages(), retryWithToleranceOperator);
             log.info("Initializing: {}", transformationChain);
 
             Map<String, Object> producerProps = baseProducerConfigs(id.connector(), "connector-producer-" + id, config, sourceConfig, connectorClass,
@@ -1365,7 +1365,7 @@ public class Worker {
             SourceConnectorConfig sourceConfig = new SourceConnectorConfig(plugins,
                     connectorConfig.originalsStrings(), config.topicCreationEnable());
             retryWithToleranceOperator.reporters(sourceTaskReporters(id, sourceConfig, errorHandlingMetrics));
-            TransformationChain<SourceRecord> transformationChain = new TransformationChain<>(sourceConfig.<SourceRecord>transformations(), retryWithToleranceOperator);
+            TransformationChain<SourceRecord> transformationChain = new TransformationChain<>(sourceConfig.<SourceRecord>transformationStages(), retryWithToleranceOperator);
             log.info("Initializing: {}", transformationChain);
 
             Map<String, Object> producerProps = exactlyOnceSourceTaskProducerConfigs(
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java
index 05b8375183c..ad8ff00cb96 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java
@@ -20,7 +20,6 @@ import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import org.apache.kafka.connect.runtime.ConnectorConfig;
 import org.apache.kafka.connect.runtime.Herder;
-import org.apache.kafka.connect.runtime.PredicatedTransformation;
 import org.apache.kafka.connect.runtime.isolation.PluginDesc;
 import org.apache.kafka.connect.runtime.isolation.PluginType;
 import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos;
@@ -34,7 +33,6 @@ import org.apache.kafka.connect.tools.MockSourceConnector;
 import org.apache.kafka.connect.tools.SchemaSourceConnector;
 import org.apache.kafka.connect.tools.VerifiableSinkConnector;
 import org.apache.kafka.connect.tools.VerifiableSourceConnector;
-import org.apache.kafka.connect.transforms.Transformation;
 import org.apache.kafka.connect.util.FutureCallback;
 
 import javax.ws.rs.BadRequestException;
@@ -79,11 +77,6 @@ public class ConnectorPluginsResource implements ConnectResource {
             SchemaSourceConnector.class
     );
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    static final List<Class<? extends Transformation<?>>> TRANSFORM_EXCLUDES = Collections.singletonList(
-            (Class) PredicatedTransformation.class
-    );
-
     public ConnectorPluginsResource(Herder herder) {
         this.herder = herder;
         this.connectorPlugins = new ArrayList<>();
@@ -92,7 +85,7 @@ public class ConnectorPluginsResource implements ConnectResource {
         // TODO: improve once plugins are allowed to be added/removed during runtime.
         addConnectorPlugins(herder.plugins().sinkConnectors(), SINK_CONNECTOR_EXCLUDES);
         addConnectorPlugins(herder.plugins().sourceConnectors(), SOURCE_CONNECTOR_EXCLUDES);
-        addConnectorPlugins(herder.plugins().transformations(), TRANSFORM_EXCLUDES);
+        addConnectorPlugins(herder.plugins().transformations(), Collections.emptySet());
         addConnectorPlugins(herder.plugins().predicates(), Collections.emptySet());
         addConnectorPlugins(herder.plugins().converters(), Collections.emptySet());
         addConnectorPlugins(herder.plugins().headerConverters(), Collections.emptySet());
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java
index 4abdbeaa4e5..d8c071e6c2f 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java
@@ -22,6 +22,7 @@ import org.apache.kafka.connect.connector.ConnectRecord;
 import org.apache.kafka.connect.connector.Connector;
 import org.apache.kafka.connect.runtime.isolation.PluginDesc;
 import org.apache.kafka.connect.runtime.isolation.Plugins;
+import org.apache.kafka.connect.sink.SinkRecord;
 import org.apache.kafka.connect.transforms.Transformation;
 import org.apache.kafka.connect.transforms.predicates.Predicate;
 import org.junit.Test;
@@ -48,6 +49,8 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         }
     };
 
+    private static final SinkRecord DUMMY_RECORD = new SinkRecord(null, 0, null, null, null, null, 0L);
+
     public static abstract class TestConnector extends Connector {
     }
 
@@ -62,7 +65,7 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
 
         @Override
         public R apply(R record) {
-            return null;
+            return record.newRecord(null, magicNumber, null, null, null, null, 0L);
         }
 
         @Override
@@ -147,10 +150,11 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         props.put("transforms.a.type", SimpleTransformation.class.getName());
         props.put("transforms.a.magic.number", "42");
         final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props);
-        final List<Transformation<R>> transformations = config.transformations();
-        assertEquals(1, transformations.size());
-        final SimpleTransformation<R> xform = (SimpleTransformation<R>) transformations.get(0);
-        assertEquals(42, xform.magicNumber);
+        final List<TransformationStage<SinkRecord>> transformationStages = config.transformationStages();
+        assertEquals(1, transformationStages.size());
+        final TransformationStage<SinkRecord> stage = transformationStages.get(0);
+        assertEquals(SimpleTransformation.class, stage.transformClass());
+        assertEquals(42, stage.apply(DUMMY_RECORD).kafkaPartition().intValue());
     }
 
     @Test
@@ -175,10 +179,10 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         props.put("transforms.b.type", SimpleTransformation.class.getName());
         props.put("transforms.b.magic.number", "84");
         final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props);
-        final List<Transformation<R>> transformations = config.transformations();
-        assertEquals(2, transformations.size());
-        assertEquals(42, ((SimpleTransformation<R>) transformations.get(0)).magicNumber);
-        assertEquals(84, ((SimpleTransformation<R>) transformations.get(1)).magicNumber);
+        final List<TransformationStage<SinkRecord>> transformationStages = config.transformationStages();
+        assertEquals(2, transformationStages.size());
+        assertEquals(42, transformationStages.get(0).apply(DUMMY_RECORD).kafkaPartition().intValue());
+        assertEquals(84, transformationStages.get(1).apply(DUMMY_RECORD).kafkaPartition().intValue());
     }
 
     @Test
@@ -246,7 +250,7 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         props.put("predicates", "my-pred");
         props.put("predicates.my-pred.type", TestPredicate.class.getName());
         props.put("predicates.my-pred.int", "84");
-        assertPredicatedTransform(props, true);
+        assertTransformationStageWithPredicate(props, true);
     }
 
     @Test
@@ -261,7 +265,7 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         props.put("predicates", "my-pred");
         props.put("predicates.my-pred.type", TestPredicate.class.getName());
         props.put("predicates.my-pred.int", "84");
-        assertPredicatedTransform(props, false);
+        assertTransformationStageWithPredicate(props, false);
     }
 
     @Test
@@ -280,25 +284,19 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         assertTrue(e.getMessage().contains("Predicate is abstract and cannot be created"));
     }
 
-    private void assertPredicatedTransform(Map<String, String> props, boolean expectedNegated) {
+    private void assertTransformationStageWithPredicate(Map<String, String> props, boolean expectedNegated) {
         final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props);
-        final List<Transformation<R>> transformations = config.transformations();
-        assertEquals(1, transformations.size());
-        assertTrue(transformations.get(0) instanceof PredicatedTransformation);
-        PredicatedTransformation<?> predicated = (PredicatedTransformation<?>) transformations.get(0);
-
-        assertEquals(expectedNegated, predicated.negate);
-
-        assertTrue(predicated.delegate instanceof ConnectorConfigTest.SimpleTransformation);
-        assertEquals(42, ((SimpleTransformation<?>) predicated.delegate).magicNumber);
+        final List<TransformationStage<SinkRecord>> transformationStages = config.transformationStages();
+        assertEquals(1, transformationStages.size());
+        TransformationStage<SinkRecord> stage = transformationStages.get(0);
 
-        assertTrue(predicated.predicate instanceof ConnectorConfigTest.TestPredicate);
-        assertEquals(84, ((TestPredicate<?>) predicated.predicate).param);
+        assertEquals(expectedNegated ? 42 : 0, stage.apply(DUMMY_RECORD).kafkaPartition().intValue());
 
-        predicated.close();
+        SinkRecord matchingRecord = DUMMY_RECORD.newRecord(null, 84, null, null, null, null, 0L);
+        assertEquals(expectedNegated ? 84 : 42, stage.apply(matchingRecord).kafkaPartition().intValue());
+        assertEquals(SimpleTransformation.class, stage.transformClass());
 
-        assertEquals(0, ((SimpleTransformation<?>) predicated.delegate).magicNumber);
-        assertEquals(0, ((TestPredicate<?>) predicated.predicate).param);
+        stage.close();
     }
 
     @Test
@@ -381,7 +379,7 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
 
         @Override
         public boolean test(R record) {
-            return false;
+            return record.kafkaPartition() == param;
         }
 
         @Override
@@ -445,8 +443,8 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         props.put(prefix + "type", HasDuplicateConfigTransformation.class.getName());
         ConfigDef def = ConnectorConfig.enrich(MOCK_PLUGINS, new ConfigDef(), props, false);
         assertEnrichedConfigDef(def, prefix, HasDuplicateConfigTransformation.MUST_EXIST_KEY, ConfigDef.Type.BOOLEAN);
-        assertEnrichedConfigDef(def, prefix, PredicatedTransformation.PREDICATE_CONFIG, ConfigDef.Type.STRING);
-        assertEnrichedConfigDef(def, prefix, PredicatedTransformation.NEGATE_CONFIG, ConfigDef.Type.BOOLEAN);
+        assertEnrichedConfigDef(def, prefix, TransformationStage.PREDICATE_CONFIG, ConfigDef.Type.STRING);
+        assertEnrichedConfigDef(def, prefix, TransformationStage.NEGATE_CONFIG, ConfigDef.Type.BOOLEAN);
     }
 
     private static void assertEnrichedConfigDef(ConfigDef def, String prefix, String keyName, ConfigDef.Type expectedType) {
@@ -460,9 +458,9 @@ public class ConnectorConfigTest<R extends ConnectRecord<R>> {
         private static final String MUST_EXIST_KEY = "must.exist.key";
         private static final ConfigDef CONFIG_DEF = new ConfigDef()
                 // this configDef is duplicate. It should be removed automatically so as to avoid duplicate config error.
-                .define(PredicatedTransformation.PREDICATE_CONFIG, ConfigDef.Type.INT, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.MEDIUM, "fake")
+                .define(TransformationStage.PREDICATE_CONFIG, ConfigDef.Type.INT, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.MEDIUM, "fake")
                 // this configDef is duplicate. It should be removed automatically so as to avoid duplicate config error.
-                .define(PredicatedTransformation.NEGATE_CONFIG, ConfigDef.Type.INT, 123, ConfigDef.Importance.MEDIUM, "fake")
+                .define(TransformationStage.NEGATE_CONFIG, ConfigDef.Type.INT, 123, ConfigDef.Importance.MEDIUM, "fake")
                 // this configDef should appear if above duplicate configDef is removed without any error
                 .define(MUST_EXIST_KEY, ConfigDef.Type.BOOLEAN, true, ConfigDef.Importance.MEDIUM, "this key must exist");
 
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java
index 4de13128dbc..3a0090f2267 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java
@@ -502,7 +502,7 @@ public class ErrorHandlingTaskTest {
         converter.configure(oo);
 
         TransformationChain<SinkRecord> sinkTransforms =
-                new TransformationChain<>(singletonList(new FaultyPassthrough<SinkRecord>()), retryWithToleranceOperator);
+                new TransformationChain<>(singletonList(new TransformationStage<>(new FaultyPassthrough<SinkRecord>())), retryWithToleranceOperator);
 
         workerSinkTask = new WorkerSinkTask(
             taskId, sinkTask, statusListener, initialState, workerConfig,
@@ -532,7 +532,7 @@ public class ErrorHandlingTaskTest {
     }
 
     private void createSourceTask(TargetState initialState, RetryWithToleranceOperator retryWithToleranceOperator, Converter converter) {
-        TransformationChain<SourceRecord> sourceTransforms = new TransformationChain<>(singletonList(new FaultyPassthrough<SourceRecord>()), retryWithToleranceOperator);
+        TransformationChain<SourceRecord> sourceTransforms = new TransformationChain<>(singletonList(new TransformationStage<>(new FaultyPassthrough<SourceRecord>())), retryWithToleranceOperator);
 
         workerSourceTask = spy(new WorkerSourceTask(
             taskId, sourceTask, statusListener, initialState, converter,
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationStageTest.java
similarity index 63%
rename from connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java
rename to connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationStageTest.java
index 8bce328817a..d31e8563f8c 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationStageTest.java
@@ -17,13 +17,18 @@
 package org.apache.kafka.connect.runtime;
 
 import org.apache.kafka.connect.source.SourceRecord;
+import org.apache.kafka.connect.transforms.Transformation;
+import org.apache.kafka.connect.transforms.predicates.Predicate;
 import org.junit.Test;
 
 import static java.util.Collections.singletonMap;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
-public class PredicatedTransformationTest {
+public class TransformationStageTest {
 
     private final SourceRecord initial = new SourceRecord(singletonMap("initial", 1), null, null, null, null);
     private final SourceRecord transformed = new SourceRecord(singletonMap("transformed", 2), null, null, null, null);
@@ -39,17 +44,21 @@ public class PredicatedTransformationTest {
     private void applyAndAssert(boolean predicateResult, boolean negate,
                                 SourceRecord expectedResult) {
 
-        SamplePredicate predicate = new SamplePredicate(predicateResult);
-        SampleTransformation<SourceRecord> predicatedTransform = new SampleTransformation<>(transformed);
-        PredicatedTransformation<SourceRecord> pt = new PredicatedTransformation<>(
+        @SuppressWarnings("unchecked")
+        Predicate<SourceRecord> predicate = mock(Predicate.class);
+        when(predicate.test(any())).thenReturn(predicateResult);
+        @SuppressWarnings("unchecked")
+        Transformation<SourceRecord> transformation = mock(Transformation.class);
+        when(transformation.apply(any())).thenReturn(transformed);
+        TransformationStage<SourceRecord> stage = new TransformationStage<>(
                 predicate,
                 negate,
-                predicatedTransform);
+                transformation);
 
-        assertEquals(expectedResult, pt.apply(initial));
+        assertEquals(expectedResult, stage.apply(initial));
 
-        pt.close();
-        assertTrue(predicate.closed);
-        assertTrue(predicatedTransform.closed);
+        stage.close();
+        verify(predicate).close();
+        verify(transformation).close();
     }
 }
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java
index 59cf83ca9ae..63e7c27c92b 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java
@@ -387,8 +387,7 @@ public class ConnectorPluginsResourceTest {
     public void testListAllPlugins() {
         Set<Class<?>> excludes = Stream.of(
                         ConnectorPluginsResource.SINK_CONNECTOR_EXCLUDES,
-                        ConnectorPluginsResource.SOURCE_CONNECTOR_EXCLUDES,
-                        ConnectorPluginsResource.TRANSFORM_EXCLUDES
+                        ConnectorPluginsResource.SOURCE_CONNECTOR_EXCLUDES
                 ).flatMap(Collection::stream)
                 .collect(Collectors.toSet());
         Set<PluginInfo> expectedConnectorPlugins = Stream.of(
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java
index feb0e5f5ac9..af11782a041 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java
@@ -19,6 +19,7 @@ package org.apache.kafka.connect.util;
 
 import org.apache.kafka.clients.admin.NewTopic;
 import org.apache.kafka.connect.data.Schema;
+import org.apache.kafka.connect.runtime.TransformationStage;
 import org.apache.kafka.connect.runtime.SourceConnectorConfig;
 import org.apache.kafka.connect.runtime.WorkerConfig;
 import org.apache.kafka.connect.runtime.distributed.DistributedConfig;
@@ -26,7 +27,6 @@ import org.apache.kafka.connect.source.SourceRecord;
 import org.apache.kafka.connect.storage.StringConverter;
 import org.apache.kafka.connect.transforms.Cast;
 import org.apache.kafka.connect.transforms.RegexRouter;
-import org.apache.kafka.connect.transforms.Transformation;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -516,9 +516,9 @@ public class TopicCreationTest {
         topicCreation.addTopic(FOO_TOPIC);
         assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC));
 
-        List<Transformation<SourceRecord>> transformations = sourceConfig.transformations();
-        assertEquals(1, transformations.size());
-        Cast<SourceRecord> xform = (Cast<SourceRecord>) transformations.get(0);
+        List<TransformationStage<SourceRecord>> transformationStages = sourceConfig.transformationStages();
+        assertEquals(1, transformationStages.size());
+        TransformationStage<SourceRecord> xform = transformationStages.get(0);
         SourceRecord transformed = xform.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42));
         assertEquals(Schema.Type.INT8, transformed.valueSchema().type());
         assertEquals((byte) 42, transformed.value());
@@ -623,15 +623,15 @@ public class TopicCreationTest {
         assertEquals(barPartitions, barTopicSpec.numPartitions());
         assertThat(barTopicSpec.configs(), is(barTopicProps));
 
-        List<Transformation<SourceRecord>> transformations = sourceConfig.transformations();
-        assertEquals(2, transformations.size());
+        List<TransformationStage<SourceRecord>> transformationStages = sourceConfig.transformationStages();
+        assertEquals(2, transformationStages.size());
 
-        Cast<SourceRecord> castXForm = (Cast<SourceRecord>) transformations.get(0);
+        TransformationStage<SourceRecord> castXForm = transformationStages.get(0);
         SourceRecord transformed = castXForm.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42));
         assertEquals(Schema.Type.INT8, transformed.valueSchema().type());
         assertEquals((byte) 42, transformed.value());
 
-        RegexRouter<SourceRecord> regexRouterXForm = (RegexRouter<SourceRecord>) transformations.get(1);
+        TransformationStage<SourceRecord> regexRouterXForm = transformationStages.get(1);
         transformed = regexRouterXForm.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42));
         assertEquals("prefix-topic", transformed.topic());
     }