Posted to commits@pinot.apache.org by ja...@apache.org on 2022/04/21 22:44:50 UTC

[pinot] branch master updated: Added multi column partitioning for offline table (#8255)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new f200ec5156 Added multi column partitioning for offline table (#8255)
f200ec5156 is described below

commit f200ec515630442f95f6fc55f568b00f7cb04f4f
Author: Mohemmad Zaid Khan <za...@gmail.com>
AuthorDate: Fri Apr 22 04:14:44 2022 +0530

    Added multi column partitioning for offline table (#8255)
    
    Adds support for multi-column partitioning for minion tasks and the broker segment pruner.
---
 ...ava => MultiPartitionColumnsSegmentPruner.java} | 140 ++++++++++++--------
 .../segmentpruner/SegmentPrunerFactory.java        |  31 +++--
 ...ava => SinglePartitionColumnSegmentPruner.java} |  10 +-
 .../routing/segmentpruner/SegmentPrunerTest.java   | 143 +++++++++++++++------
 .../tests/BaseClusterIntegrationTest.java          |   7 +-
 .../MergeRollupMinionClusterIntegrationTest.java   |  22 +++-
 ...fflineSegmentsMinionClusterIntegrationTest.java |  44 ++++++-
 .../pinot/plugin/minion/tasks/MergeTaskUtils.java  |  28 ++--
 .../mergerollup/MergeRollupTaskGenerator.java      |  41 +++---
 .../plugin/minion/tasks/MergeTaskUtilsTest.java    |  14 ++
 .../mergerollup/MergeRollupTaskGeneratorTest.java  |  56 ++++++--
 .../indexsegment/mutable/IntermediateSegment.java  |  23 +---
 .../mutable/IntermediateSegmentTest.java           |  32 ++++-
 13 files changed, 410 insertions(+), 181 deletions(-)
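
For context, a minimal sketch (not part of this commit; the table and column names are hypothetical, and a real table config also needs a schema, time column, and so on) of opting a table into multi-column partitioning through the same TableConfigBuilder API the integration tests below use:

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
    import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
    import org.apache.pinot.spi.config.table.TableConfig;
    import org.apache.pinot.spi.config.table.TableType;
    import org.apache.pinot.spi.utils.builder.TableConfigBuilder;

    public class MultiColumnPartitioningSketch {
      public static void main(String[] args) {
        // One ColumnPartitionConfig per partition column; before this commit the
        // broker pruner and the merge/rollup minion tasks required exactly one entry.
        Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
        columnPartitionConfigMap.put("memberId", new ColumnPartitionConfig("Murmur", 4));
        columnPartitionConfigMap.put("memberName", new ColumnPartitionConfig("Modulo", 3));

        TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE)
            .setTableName("testTable")
            .setSegmentPartitionConfig(new SegmentPartitionConfig(columnPartitionConfigMap))
            .build();
      }
    }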

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/MultiPartitionColumnsSegmentPruner.java
similarity index 59%
copy from pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java
copy to pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/MultiPartitionColumnsSegmentPruner.java
index 87755897d1..65b626c322 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/MultiPartitionColumnsSegmentPruner.java
@@ -18,7 +18,10 @@
  */
 package org.apache.pinot.broker.routing.segmentpruner;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -49,23 +52,23 @@ import org.slf4j.LoggerFactory;
 
 
 /**
- * The {@code PartitionSegmentPruner} prunes segments based on the their partition metadata stored in ZK. The pruner
- * supports queries with filter (or nested filter) of EQUALITY and IN predicates.
+ * The {@code MultiPartitionColumnsSegmentPruner} prunes segments based on their partition metadata stored in ZK. The
+ * pruner supports queries with filter (or nested filter) of EQUALITY and IN predicates.
  */
-public class PartitionSegmentPruner implements SegmentPruner {
-  private static final Logger LOGGER = LoggerFactory.getLogger(PartitionSegmentPruner.class);
-  private static final PartitionInfo INVALID_PARTITION_INFO = new PartitionInfo(null, null);
+public class MultiPartitionColumnsSegmentPruner implements SegmentPruner {
+  private static final Logger LOGGER = LoggerFactory.getLogger(MultiPartitionColumnsSegmentPruner.class);
+  private static final Map<String, PartitionInfo> INVALID_COLUMN_PARTITION_INFO_MAP = Collections.emptyMap();
 
   private final String _tableNameWithType;
-  private final String _partitionColumn;
+  private final Set<String> _partitionColumns;
   private final ZkHelixPropertyStore<ZNRecord> _propertyStore;
   private final String _segmentZKMetadataPathPrefix;
-  private final Map<String, PartitionInfo> _partitionInfoMap = new ConcurrentHashMap<>();
+  private final Map<String, Map<String, PartitionInfo>> _segmentColumnPartitionInfoMap = new ConcurrentHashMap<>();
 
-  public PartitionSegmentPruner(String tableNameWithType, String partitionColumn,
+  public MultiPartitionColumnsSegmentPruner(String tableNameWithType, Set<String> partitionColumns,
       ZkHelixPropertyStore<ZNRecord> propertyStore) {
     _tableNameWithType = tableNameWithType;
-    _partitionColumn = partitionColumn;
+    _partitionColumns = partitionColumns;
     _propertyStore = propertyStore;
     _segmentZKMetadataPathPrefix = ZKMetadataProvider.constructPropertyStorePathForResource(tableNameWithType) + "/";
   }
@@ -83,20 +86,22 @@ public class PartitionSegmentPruner implements SegmentPruner {
     List<ZNRecord> znRecords = _propertyStore.get(segmentZKMetadataPaths, null, AccessOption.PERSISTENT, false);
     for (int i = 0; i < numSegments; i++) {
       String segment = segments.get(i);
-      PartitionInfo partitionInfo = extractPartitionInfoFromSegmentZKMetadataZNRecord(segment, znRecords.get(i));
-      if (partitionInfo != null) {
-        _partitionInfoMap.put(segment, partitionInfo);
+      Map<String, PartitionInfo> columnPartitionInfoMap =
+          extractColumnPartitionInfoMapFromSegmentZKMetadataZNRecord(segment, znRecords.get(i));
+      if (columnPartitionInfoMap != null) {
+        _segmentColumnPartitionInfoMap.put(segment, columnPartitionInfoMap);
       }
     }
   }
 
   /**
    * NOTE: Returns {@code null} when the ZNRecord is missing (could be transient Helix issue). Returns
-   *       {@link #INVALID_PARTITION_INFO} when the segment does not have valid partition metadata in its ZK metadata,
-   *       in which case we won't retry later.
+   *       {@link #INVALID_COLUMN_PARTITION_INFO_MAP} when the segment does not have valid partition metadata in its ZK
+   *       metadata, in which case we won't retry later.
    */
   @Nullable
-  private PartitionInfo extractPartitionInfoFromSegmentZKMetadataZNRecord(String segment, @Nullable ZNRecord znRecord) {
+  private Map<String, PartitionInfo> extractColumnPartitionInfoMapFromSegmentZKMetadataZNRecord(String segment,
+      @Nullable ZNRecord znRecord) {
     if (znRecord == null) {
       LOGGER.warn("Failed to find segment ZK metadata for segment: {}, table: {}", segment, _tableNameWithType);
       return null;
@@ -105,7 +110,7 @@ public class PartitionSegmentPruner implements SegmentPruner {
     String partitionMetadataJson = znRecord.getSimpleField(Segment.PARTITION_METADATA);
     if (partitionMetadataJson == null) {
       LOGGER.warn("Failed to find segment partition metadata for segment: {}, table: {}", segment, _tableNameWithType);
-      return INVALID_PARTITION_INFO;
+      return INVALID_COLUMN_PARTITION_INFO_MAP;
     }
 
     SegmentPartitionMetadata segmentPartitionMetadata;
@@ -114,20 +119,29 @@ public class PartitionSegmentPruner implements SegmentPruner {
     } catch (Exception e) {
       LOGGER.warn("Caught exception while extracting segment partition metadata for segment: {}, table: {}", segment,
           _tableNameWithType, e);
-      return INVALID_PARTITION_INFO;
+      return INVALID_COLUMN_PARTITION_INFO_MAP;
     }
 
-    ColumnPartitionMetadata columnPartitionMetadata =
-        segmentPartitionMetadata.getColumnPartitionMap().get(_partitionColumn);
-    if (columnPartitionMetadata == null) {
-      LOGGER.warn("Failed to find column partition metadata for column: {}, segment: {}, table: {}", _partitionColumn,
-          segment, _tableNameWithType);
-      return INVALID_PARTITION_INFO;
+    Map<String, PartitionInfo> columnPartitionInfoMap = new HashMap<>();
+    for (String partitionColumn : _partitionColumns) {
+      ColumnPartitionMetadata columnPartitionMetadata =
+          segmentPartitionMetadata.getColumnPartitionMap().get(partitionColumn);
+      if (columnPartitionMetadata == null) {
+        LOGGER.warn("Failed to find column partition metadata for column: {}, segment: {}, table: {}", partitionColumn,
+            segment, _tableNameWithType);
+        continue;
+      }
+      PartitionInfo partitionInfo = new PartitionInfo(
+          PartitionFunctionFactory.getPartitionFunction(columnPartitionMetadata.getFunctionName(),
+              columnPartitionMetadata.getNumPartitions(), columnPartitionMetadata.getFunctionConfig()),
+          columnPartitionMetadata.getPartitions());
+      columnPartitionInfoMap.put(partitionColumn, partitionInfo);
     }
-
-    return new PartitionInfo(PartitionFunctionFactory.getPartitionFunction(columnPartitionMetadata.getFunctionName(),
-        columnPartitionMetadata.getNumPartitions(), columnPartitionMetadata.getFunctionConfig()),
-        columnPartitionMetadata.getPartitions());
+    if (columnPartitionInfoMap.size() == 1) {
+      String partitionColumn = columnPartitionInfoMap.keySet().iterator().next();
+      return Collections.singletonMap(partitionColumn, columnPartitionInfoMap.get(partitionColumn));
+    }
+    return columnPartitionInfoMap.isEmpty() ? INVALID_COLUMN_PARTITION_INFO_MAP : columnPartitionInfoMap;
   }
 
   @Override
@@ -136,20 +150,21 @@ public class PartitionSegmentPruner implements SegmentPruner {
     // NOTE: We don't update all the segment ZK metadata for every external view change, but only the newly
     //       added/removed ones. Changes from refreshed segment ZK metadata won't be picked up.
     for (String segment : onlineSegments) {
-      _partitionInfoMap.computeIfAbsent(segment, k -> extractPartitionInfoFromSegmentZKMetadataZNRecord(k,
-          _propertyStore.get(_segmentZKMetadataPathPrefix + k, null, AccessOption.PERSISTENT)));
+      _segmentColumnPartitionInfoMap.computeIfAbsent(segment,
+          k -> extractColumnPartitionInfoMapFromSegmentZKMetadataZNRecord(k,
+              _propertyStore.get(_segmentZKMetadataPathPrefix + k, null, AccessOption.PERSISTENT)));
     }
-    _partitionInfoMap.keySet().retainAll(onlineSegments);
+    _segmentColumnPartitionInfoMap.keySet().retainAll(onlineSegments);
   }
 
   @Override
   public synchronized void refreshSegment(String segment) {
-    PartitionInfo partitionInfo = extractPartitionInfoFromSegmentZKMetadataZNRecord(segment,
+    Map<String, PartitionInfo> columnPartitionInfo = extractColumnPartitionInfoMapFromSegmentZKMetadataZNRecord(segment,
         _propertyStore.get(_segmentZKMetadataPathPrefix + segment, null, AccessOption.PERSISTENT));
-    if (partitionInfo != null) {
-      _partitionInfoMap.put(segment, partitionInfo);
+    if (columnPartitionInfo != null) {
+      _segmentColumnPartitionInfoMap.put(segment, columnPartitionInfo);
     } else {
-      _partitionInfoMap.remove(segment);
+      _segmentColumnPartitionInfoMap.remove(segment);
     }
   }
 
@@ -165,9 +180,9 @@ public class PartitionSegmentPruner implements SegmentPruner {
       }
       Set<String> selectedSegments = new HashSet<>();
       for (String segment : segments) {
-        PartitionInfo partitionInfo = _partitionInfoMap.get(segment);
-        if (partitionInfo == null || partitionInfo == INVALID_PARTITION_INFO || isPartitionMatch(filterExpression,
-            partitionInfo)) {
+        Map<String, PartitionInfo> columnPartitionInfoMap = _segmentColumnPartitionInfoMap.get(segment);
+        if (columnPartitionInfoMap == null || columnPartitionInfoMap == INVALID_COLUMN_PARTITION_INFO_MAP
+            || isPartitionMatch(filterExpression, columnPartitionInfoMap)) {
           selectedSegments.add(segment);
         }
       }
@@ -180,9 +195,9 @@ public class PartitionSegmentPruner implements SegmentPruner {
       }
       Set<String> selectedSegments = new HashSet<>();
       for (String segment : segments) {
-        PartitionInfo partitionInfo = _partitionInfoMap.get(segment);
-        if (partitionInfo == null || partitionInfo == INVALID_PARTITION_INFO || isPartitionMatch(filterQueryTree,
-            partitionInfo)) {
+        Map<String, PartitionInfo> columnPartitionInfo = _segmentColumnPartitionInfoMap.get(segment);
+        if (columnPartitionInfo == null || columnPartitionInfo == INVALID_COLUMN_PARTITION_INFO_MAP || isPartitionMatch(
+            filterQueryTree, columnPartitionInfo)) {
           selectedSegments.add(segment);
         }
       }
@@ -190,37 +205,47 @@ public class PartitionSegmentPruner implements SegmentPruner {
     }
   }
 
-  private boolean isPartitionMatch(Expression filterExpression, PartitionInfo partitionInfo) {
+  @VisibleForTesting
+  public Set<String> getPartitionColumns() {
+    return _partitionColumns;
+  }
+
+  private boolean isPartitionMatch(Expression filterExpression, Map<String, PartitionInfo> columnPartitionInfoMap) {
     Function function = filterExpression.getFunctionCall();
     FilterKind filterKind = FilterKind.valueOf(function.getOperator());
     List<Expression> operands = function.getOperands();
     switch (filterKind) {
       case AND:
         for (Expression child : operands) {
-          if (!isPartitionMatch(child, partitionInfo)) {
+          if (!isPartitionMatch(child, columnPartitionInfoMap)) {
             return false;
           }
         }
         return true;
       case OR:
         for (Expression child : operands) {
-          if (isPartitionMatch(child, partitionInfo)) {
+          if (isPartitionMatch(child, columnPartitionInfoMap)) {
             return true;
           }
         }
         return false;
       case EQUALS: {
         Identifier identifier = operands.get(0).getIdentifier();
-        if (identifier != null && identifier.getName().equals(_partitionColumn)) {
-          return partitionInfo._partitions.contains(
-              partitionInfo._partitionFunction.getPartition(operands.get(1).getLiteral().getFieldValue().toString()));
+        if (identifier != null) {
+          PartitionInfo partitionInfo = columnPartitionInfoMap.get(identifier.getName());
+          return partitionInfo == null || partitionInfo._partitions.contains(
+              partitionInfo._partitionFunction.getPartition(operands.get(1).getLiteral().getFieldValue()));
         } else {
           return true;
         }
       }
       case IN: {
         Identifier identifier = operands.get(0).getIdentifier();
-        if (identifier != null && identifier.getName().equals(_partitionColumn)) {
+        if (identifier != null) {
+          PartitionInfo partitionInfo = columnPartitionInfoMap.get(identifier.getName());
+          if (partitionInfo == null) {
+            return true;
+          }
           int numOperands = operands.size();
           for (int i = 1; i < numOperands; i++) {
             if (partitionInfo._partitions.contains(partitionInfo._partitionFunction.getPartition(
@@ -239,33 +264,34 @@ public class PartitionSegmentPruner implements SegmentPruner {
   }
 
   @Deprecated
-  private boolean isPartitionMatch(FilterQueryTree filterQueryTree, PartitionInfo partitionInfo) {
+  private boolean isPartitionMatch(FilterQueryTree filterQueryTree, Map<String, PartitionInfo> columnPartitionInfoMap) {
     switch (filterQueryTree.getOperator()) {
       case AND:
         for (FilterQueryTree child : filterQueryTree.getChildren()) {
-          if (!isPartitionMatch(child, partitionInfo)) {
+          if (!isPartitionMatch(child, columnPartitionInfoMap)) {
             return false;
           }
         }
         return true;
       case OR:
         for (FilterQueryTree child : filterQueryTree.getChildren()) {
-          if (isPartitionMatch(child, partitionInfo)) {
+          if (isPartitionMatch(child, columnPartitionInfoMap)) {
             return true;
           }
         }
         return false;
       case EQUALITY:
       case IN:
-        if (filterQueryTree.getColumn().equals(_partitionColumn)) {
-          for (String value : filterQueryTree.getValue()) {
-            if (partitionInfo._partitions.contains(partitionInfo._partitionFunction.getPartition(value))) {
-              return true;
-            }
+        PartitionInfo partitionInfo = columnPartitionInfoMap.get(filterQueryTree.getColumn());
+        if (partitionInfo == null) {
+          return true;
+        }
+        for (String value : filterQueryTree.getValue()) {
+          if (partitionInfo._partitions.contains(partitionInfo._partitionFunction.getPartition(value))) {
+            return true;
           }
-          return false;
         }
-        return true;
+        return false;
       default:
         return true;
     }
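
To make the matching rule above concrete, here is a standalone sketch (simplified types, not Pinot's classes; a Modulo function stands in for whatever partition function is configured): a segment survives an EQUALS predicate on a column unless that column has partition metadata and the literal hashes to a partition the segment does not hold, and columns without metadata never cause pruning.

    import java.util.Map;
    import java.util.Set;

    public class PruneRuleSketch {
      // Simplified stand-in for Pinot's PartitionInfo: the number of partitions the
      // configured function produces, and the partitions this segment contains.
      static final class PartitionInfo {
        final int _numPartitions;
        final Set<Integer> _partitions;

        PartitionInfo(int numPartitions, Set<Integer> partitions) {
          _numPartitions = numPartitions;
          _partitions = partitions;
        }
      }

      // Modulo partition function, standing in for PartitionFunctionFactory output.
      static int partition(String value, int numPartitions) {
        return Math.floorMod(Integer.parseInt(value), numPartitions);
      }

      // EQUALS predicate on 'column': keep the segment unless the column has
      // partition metadata and the literal falls outside the segment's partitions.
      static boolean matchesEquals(Map<String, PartitionInfo> columnPartitionInfoMap, String column, String literal) {
        PartitionInfo partitionInfo = columnPartitionInfoMap.get(column);
        if (partitionInfo == null) {
          return true; // No metadata for this column: cannot prune.
        }
        return partitionInfo._partitions.contains(partition(literal, partitionInfo._numPartitions));
      }

      public static void main(String[] args) {
        // A segment holding partition 1 (of 4) for memberId, no metadata for memberName.
        Map<String, PartitionInfo> segmentMetadata = Map.of("memberId", new PartitionInfo(4, Set.of(1)));
        System.out.println(matchesEquals(segmentMetadata, "memberId", "5"));     // true: 5 % 4 == 1
        System.out.println(matchesEquals(segmentMetadata, "memberId", "2"));     // false: segment pruned
        System.out.println(matchesEquals(segmentMetadata, "memberName", "xyz")); // true: no metadata
      }
    }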
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerFactory.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerFactory.java
index b5adeba324..1f82d8f636 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerFactory.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerFactory.java
@@ -21,7 +21,9 @@ package org.apache.pinot.broker.routing.segmentpruner;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.commons.collections.MapUtils;
 import org.apache.helix.ZNRecord;
 import org.apache.helix.store.zk.ZkHelixPropertyStore;
 import org.apache.pinot.segment.local.utils.TableConfigUtils;
@@ -60,7 +62,7 @@ public class SegmentPrunerFactory {
         List<SegmentPruner> configuredSegmentPruners = new ArrayList<>(segmentPrunerTypes.size());
         for (String segmentPrunerType : segmentPrunerTypes) {
           if (RoutingConfig.PARTITION_SEGMENT_PRUNER_TYPE.equalsIgnoreCase(segmentPrunerType)) {
-            PartitionSegmentPruner partitionSegmentPruner = getPartitionSegmentPruner(tableConfig, propertyStore);
+            SegmentPruner partitionSegmentPruner = getPartitionSegmentPruner(tableConfig, propertyStore);
             if (partitionSegmentPruner != null) {
               configuredSegmentPruners.add(partitionSegmentPruner);
             }
@@ -83,9 +85,9 @@ public class SegmentPrunerFactory {
         if ((tableType == TableType.OFFLINE && LEGACY_PARTITION_AWARE_OFFLINE_ROUTING.equalsIgnoreCase(
             routingTableBuilderName)) || (tableType == TableType.REALTIME
             && LEGACY_PARTITION_AWARE_REALTIME_ROUTING.equalsIgnoreCase(routingTableBuilderName))) {
-          PartitionSegmentPruner partitionSegmentPruner = getPartitionSegmentPruner(tableConfig, propertyStore);
+          SegmentPruner partitionSegmentPruner = getPartitionSegmentPruner(tableConfig, propertyStore);
           if (partitionSegmentPruner != null) {
-            segmentPruners.add(getPartitionSegmentPruner(tableConfig, propertyStore));
+            segmentPruners.add(partitionSegmentPruner);
           }
         }
       }
@@ -94,7 +96,7 @@ public class SegmentPrunerFactory {
   }
 
   @Nullable
-  private static PartitionSegmentPruner getPartitionSegmentPruner(TableConfig tableConfig,
+  private static SegmentPruner getPartitionSegmentPruner(TableConfig tableConfig,
       ZkHelixPropertyStore<ZNRecord> propertyStore) {
     String tableNameWithType = tableConfig.getTableName();
     SegmentPartitionConfig segmentPartitionConfig = tableConfig.getIndexingConfig().getSegmentPartitionConfig();
@@ -102,17 +104,17 @@ public class SegmentPrunerFactory {
       LOGGER.warn("Cannot enable partition pruning without segment partition config for table: {}", tableNameWithType);
       return null;
     }
-    Map<String, ColumnPartitionConfig> columnPartitionMap = segmentPartitionConfig.getColumnPartitionMap();
-    if (columnPartitionMap.size() != 1) {
-      LOGGER.warn("Cannot enable partition pruning with other than exact one partition column for table: {}",
-          tableNameWithType);
+    if (MapUtils.isEmpty(segmentPartitionConfig.getColumnPartitionMap())) {
+      LOGGER.warn("Cannot enable partition pruning without column partition config for table: {}", tableNameWithType);
       return null;
-    } else {
-      String partitionColumn = columnPartitionMap.keySet().iterator().next();
-      LOGGER.info("Using PartitionSegmentPruner on partition column: {} for table: {}", partitionColumn,
-          tableNameWithType);
-      return new PartitionSegmentPruner(tableNameWithType, partitionColumn, propertyStore);
     }
+    Map<String, ColumnPartitionConfig> columnPartitionMap = segmentPartitionConfig.getColumnPartitionMap();
+    Set<String> partitionColumns = columnPartitionMap.keySet();
+    LOGGER.info("Using PartitionSegmentPruner on partition columns: {} for table: {}", partitionColumns,
+        tableNameWithType);
+    return partitionColumns.size() == 1 ? new SinglePartitionColumnSegmentPruner(tableNameWithType,
+        partitionColumns.iterator().next(), propertyStore)
+        : new MultiPartitionColumnsSegmentPruner(tableNameWithType, partitionColumns, propertyStore);
   }
 
   @Nullable
@@ -151,7 +153,8 @@ public class SegmentPrunerFactory {
       }
     }
     for (SegmentPruner pruner : pruners) {
-      if (pruner instanceof PartitionSegmentPruner) {
+      if (pruner instanceof SinglePartitionColumnSegmentPruner
+          || pruner instanceof MultiPartitionColumnsSegmentPruner) {
         sortedPruners.add(pruner);
       }
     }
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SinglePartitionColumnSegmentPruner.java
similarity index 95%
rename from pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java
rename to pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SinglePartitionColumnSegmentPruner.java
index 87755897d1..f695f2b225 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/PartitionSegmentPruner.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/routing/segmentpruner/SinglePartitionColumnSegmentPruner.java
@@ -49,11 +49,11 @@ import org.slf4j.LoggerFactory;
 
 
 /**
- * The {@code PartitionSegmentPruner} prunes segments based on the their partition metadata stored in ZK. The pruner
- * supports queries with filter (or nested filter) of EQUALITY and IN predicates.
+ * The {@code SinglePartitionColumnSegmentPruner} prunes segments based on their partition metadata stored in ZK. The
+ * pruner supports queries with filter (or nested filter) of EQUALITY and IN predicates.
  */
-public class PartitionSegmentPruner implements SegmentPruner {
-  private static final Logger LOGGER = LoggerFactory.getLogger(PartitionSegmentPruner.class);
+public class SinglePartitionColumnSegmentPruner implements SegmentPruner {
+  private static final Logger LOGGER = LoggerFactory.getLogger(SinglePartitionColumnSegmentPruner.class);
   private static final PartitionInfo INVALID_PARTITION_INFO = new PartitionInfo(null, null);
 
   private final String _tableNameWithType;
@@ -62,7 +62,7 @@ public class PartitionSegmentPruner implements SegmentPruner {
   private final String _segmentZKMetadataPathPrefix;
   private final Map<String, PartitionInfo> _partitionInfoMap = new ConcurrentHashMap<>();
 
-  public PartitionSegmentPruner(String tableNameWithType, String partitionColumn,
+  public SinglePartitionColumnSegmentPruner(String tableNameWithType, String partitionColumn,
       ZkHelixPropertyStore<ZNRecord> propertyStore) {
     _tableNameWithType = tableNameWithType;
     _partitionColumn = partitionColumn;
diff --git a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerTest.java b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerTest.java
index ad0f995376..0bd5ddca5c 100644
--- a/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerTest.java
+++ b/pinot-broker/src/test/java/org/apache/pinot/broker/routing/segmentpruner/SegmentPrunerTest.java
@@ -26,6 +26,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.helix.ZNRecord;
 import org.apache.helix.manager.zk.ZNRecordSerializer;
 import org.apache.helix.manager.zk.ZkBaseDataAccessor;
@@ -72,13 +74,15 @@ public class SegmentPrunerTest extends ControllerTest {
   private static final String RAW_TABLE_NAME = "testTable";
   private static final String OFFLINE_TABLE_NAME = "testTable_OFFLINE";
   private static final String REALTIME_TABLE_NAME = "testTable_REALTIME";
-  private static final String PARTITION_COLUMN = "memberId";
+  private static final String PARTITION_COLUMN_1 = "memberId";
+  private static final String PARTITION_COLUMN_2 = "memberName";
   private static final String TIME_COLUMN = "timeColumn";
   private static final String SDF_PATTERN = "yyyyMMdd";
 
   private static final String QUERY_1 = "SELECT * FROM testTable";
   private static final String QUERY_2 = "SELECT * FROM testTable where memberId = 0";
   private static final String QUERY_3 = "SELECT * FROM testTable where memberId IN (1, 2)";
+  private static final String QUERY_4 = "SELECT * FROM testTable where memberId = 0 AND memberName='xyz'";
 
   private static final String TIME_QUERY_1 = "SELECT * FROM testTable where timeColumn = 40";
   private static final String TIME_QUERY_2 = "SELECT * FROM testTable where timeColumn BETWEEN 20 AND 30";
@@ -151,18 +155,23 @@ public class SegmentPrunerTest extends ControllerTest {
     assertEquals(segmentPruners.size(), 0);
 
     // Partition-aware segment pruner should be returned
-    columnPartitionConfigMap.put(PARTITION_COLUMN, new ColumnPartitionConfig("Modulo", 5));
+    columnPartitionConfigMap.put(PARTITION_COLUMN_1, new ColumnPartitionConfig("Modulo", 5));
     segmentPruners = SegmentPrunerFactory.getSegmentPruners(tableConfig, _propertyStore);
     assertEquals(segmentPruners.size(), 1);
-    assertTrue(segmentPruners.get(0) instanceof PartitionSegmentPruner);
+    assertTrue(segmentPruners.get(0) instanceof SinglePartitionColumnSegmentPruner);
 
-    // Do not allow multiple partition columns
-    columnPartitionConfigMap.put("anotherPartitionColumn", new ColumnPartitionConfig("Modulo", 5));
+    // Multiple partition columns
+    columnPartitionConfigMap.put(PARTITION_COLUMN_2, new ColumnPartitionConfig("Modulo", 5));
     segmentPruners = SegmentPrunerFactory.getSegmentPruners(tableConfig, _propertyStore);
-    assertEquals(segmentPruners.size(), 0);
+    assertEquals(segmentPruners.size(), 1);
+    assertTrue(segmentPruners.get(0) instanceof MultiPartitionColumnsSegmentPruner);
+    MultiPartitionColumnsSegmentPruner partitionSegmentPruner =
+        (MultiPartitionColumnsSegmentPruner) segmentPruners.get(0);
+    assertEquals(partitionSegmentPruner.getPartitionColumns(),
+        Stream.of(PARTITION_COLUMN_1, PARTITION_COLUMN_2).collect(Collectors.toSet()));
 
     // Should be backward-compatible with legacy config
-    columnPartitionConfigMap.remove("anotherPartitionColumn");
+    columnPartitionConfigMap.remove(PARTITION_COLUMN_1);
     when(routingConfig.getSegmentPrunerTypes()).thenReturn(null);
     segmentPruners = SegmentPrunerFactory.getSegmentPruners(tableConfig, _propertyStore);
     assertEquals(segmentPruners.size(), 0);
@@ -170,13 +179,13 @@ public class SegmentPrunerTest extends ControllerTest {
     when(routingConfig.getRoutingTableBuilderName()).thenReturn(
         SegmentPrunerFactory.LEGACY_PARTITION_AWARE_OFFLINE_ROUTING);
     segmentPruners = SegmentPrunerFactory.getSegmentPruners(tableConfig, _propertyStore);
-    assertTrue(segmentPruners.get(0) instanceof PartitionSegmentPruner);
+    assertTrue(segmentPruners.get(0) instanceof SinglePartitionColumnSegmentPruner);
     when(tableConfig.getTableType()).thenReturn(TableType.REALTIME);
     when(routingConfig.getRoutingTableBuilderName()).thenReturn(
         SegmentPrunerFactory.LEGACY_PARTITION_AWARE_REALTIME_ROUTING);
     segmentPruners = SegmentPrunerFactory.getSegmentPruners(tableConfig, _propertyStore);
     assertEquals(segmentPruners.size(), 1);
-    assertTrue(segmentPruners.get(0) instanceof PartitionSegmentPruner);
+    assertTrue(segmentPruners.get(0) instanceof SinglePartitionColumnSegmentPruner);
   }
 
   @Test
@@ -256,25 +265,29 @@ public class SegmentPrunerTest extends ControllerTest {
     BrokerRequest brokerRequest1 = compiler.compileToBrokerRequest(QUERY_1);
     BrokerRequest brokerRequest2 = compiler.compileToBrokerRequest(QUERY_2);
     BrokerRequest brokerRequest3 = compiler.compileToBrokerRequest(QUERY_3);
+    BrokerRequest brokerRequest4 = compiler.compileToBrokerRequest(QUERY_4);
     // NOTE: Ideal state and external view are not used in the current implementation
     IdealState idealState = Mockito.mock(IdealState.class);
     ExternalView externalView = Mockito.mock(ExternalView.class);
 
-    PartitionSegmentPruner segmentPruner =
-        new PartitionSegmentPruner(OFFLINE_TABLE_NAME, PARTITION_COLUMN, _propertyStore);
+    SinglePartitionColumnSegmentPruner singlePartitionColumnSegmentPruner =
+        new SinglePartitionColumnSegmentPruner(OFFLINE_TABLE_NAME, PARTITION_COLUMN_1, _propertyStore);
     Set<String> onlineSegments = new HashSet<>();
-    segmentPruner.init(idealState, externalView, onlineSegments);
-    assertEquals(segmentPruner.prune(brokerRequest1, Collections.emptySet()), Collections.emptySet());
-    assertEquals(segmentPruner.prune(brokerRequest2, Collections.emptySet()), Collections.emptySet());
-    assertEquals(segmentPruner.prune(brokerRequest3, Collections.emptySet()), Collections.emptySet());
+    singlePartitionColumnSegmentPruner.init(idealState, externalView, onlineSegments);
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest1, Collections.emptySet()),
+        Collections.emptySet());
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest2, Collections.emptySet()),
+        Collections.emptySet());
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest3, Collections.emptySet()),
+        Collections.emptySet());
 
     // Segments without metadata (not updated yet) should not be pruned
     String newSegment = "newSegment";
-    assertEquals(segmentPruner.prune(brokerRequest1, Collections.singleton(newSegment)),
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest1, Collections.singleton(newSegment)),
         Collections.singletonList(newSegment));
-    assertEquals(segmentPruner.prune(brokerRequest2, Collections.singleton(newSegment)),
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest2, Collections.singleton(newSegment)),
         Collections.singletonList(newSegment));
-    assertEquals(segmentPruner.prune(brokerRequest3, Collections.singleton(newSegment)),
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest3, Collections.singleton(newSegment)),
         Collections.singletonList(newSegment));
 
     // Segments without partition metadata should not be pruned
@@ -284,15 +297,15 @@ public class SegmentPrunerTest extends ControllerTest {
         new SegmentZKMetadata(segmentWithoutPartitionMetadata);
     ZKMetadataProvider.setSegmentZKMetadata(_propertyStore, OFFLINE_TABLE_NAME,
         segmentZKMetadataWithoutPartitionMetadata);
-    segmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
-    assertEquals(
-        segmentPruner.prune(brokerRequest1, new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
+    singlePartitionColumnSegmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest1,
+            new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
         Collections.singletonList(segmentWithoutPartitionMetadata));
-    assertEquals(
-        segmentPruner.prune(brokerRequest2, new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest2,
+            new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
         Collections.singletonList(segmentWithoutPartitionMetadata));
-    assertEquals(
-        segmentPruner.prune(brokerRequest3, new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
+    assertEquals(singlePartitionColumnSegmentPruner.prune(brokerRequest3,
+            new HashSet<>(Collections.singletonList(segmentWithoutPartitionMetadata))),
         Collections.singletonList(segmentWithoutPartitionMetadata));
 
     // Test different partition functions and number of partitions
@@ -304,32 +317,79 @@ public class SegmentPrunerTest extends ControllerTest {
     String segment1 = "segment1";
     onlineSegments.add(segment1);
     setSegmentZKPartitionMetadata(OFFLINE_TABLE_NAME, segment1, "Murmur", 4, 0);
-    segmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
-    assertEquals(segmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
+    singlePartitionColumnSegmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Collections.singletonList(segment1)));
 
     // Updating partition metadata without refreshing should have no effect
     setSegmentZKPartitionMetadata(OFFLINE_TABLE_NAME, segment0, "Modulo", 4, 1);
-    segmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
-    assertEquals(segmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
+    singlePartitionColumnSegmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Collections.singletonList(segment1)));
 
     // Refreshing the changed segment should update the segment pruner
-    segmentPruner.refreshSegment(segment0);
-    assertEquals(segmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
+    singlePartitionColumnSegmentPruner.refreshSegment(segment0);
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Collections.singletonList(segment1)));
-    assertEquals(segmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
+    assertEquals(
+        singlePartitionColumnSegmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
         new HashSet<>(Arrays.asList(segment0, segment1)));
+
+    // Multi-column partitioned segment.
+    MultiPartitionColumnsSegmentPruner multiPartitionColumnsSegmentPruner =
+        new MultiPartitionColumnsSegmentPruner(OFFLINE_TABLE_NAME,
+            Stream.of(PARTITION_COLUMN_1, PARTITION_COLUMN_2).collect(Collectors.toSet()), _propertyStore);
+    multiPartitionColumnsSegmentPruner.init(idealState, externalView, onlineSegments);
+    assertEquals(multiPartitionColumnsSegmentPruner.prune(brokerRequest1, Collections.emptySet()),
+        Collections.emptySet());
+    assertEquals(multiPartitionColumnsSegmentPruner.prune(brokerRequest2, Collections.emptySet()),
+        Collections.emptySet());
+    assertEquals(multiPartitionColumnsSegmentPruner.prune(brokerRequest3, Collections.emptySet()),
+        Collections.emptySet());
+    assertEquals(multiPartitionColumnsSegmentPruner.prune(brokerRequest4, Collections.emptySet()),
+        Collections.emptySet());
+
+    String segment2 = "segment2";
+    onlineSegments.add(segment2);
+    Map<String, ColumnPartitionMetadata> columnPartitionMetadataMap = new HashMap<>();
+    columnPartitionMetadataMap.put(PARTITION_COLUMN_1,
+        new ColumnPartitionMetadata("Modulo", 4, Collections.singleton(0), null));
+    Map<String, String> partitionColumn2FunctionConfig = new HashMap<>();
+    partitionColumn2FunctionConfig.put("columnValues", "xyz|abc");
+    partitionColumn2FunctionConfig.put("columnValuesDelimiter", "|");
+    columnPartitionMetadataMap.put(PARTITION_COLUMN_2,
+        new ColumnPartitionMetadata("BoundedColumnValue", 3, Collections.singleton(1), partitionColumn2FunctionConfig));
+    setSegmentZKPartitionMetadata(OFFLINE_TABLE_NAME, segment2, columnPartitionMetadataMap);
+    multiPartitionColumnsSegmentPruner.onAssignmentChange(idealState, externalView, onlineSegments);
+    assertEquals(
+        multiPartitionColumnsSegmentPruner.prune(brokerRequest1, new HashSet<>(Arrays.asList(segment0, segment1))),
+        new HashSet<>(Arrays.asList(segment0, segment1)));
+    assertEquals(
+        multiPartitionColumnsSegmentPruner.prune(brokerRequest2, new HashSet<>(Arrays.asList(segment0, segment1))),
+        new HashSet<>(Collections.singletonList(segment1)));
+    assertEquals(
+        multiPartitionColumnsSegmentPruner.prune(brokerRequest3, new HashSet<>(Arrays.asList(segment0, segment1))),
+        new HashSet<>(Arrays.asList(segment0, segment1)));
+    assertEquals(multiPartitionColumnsSegmentPruner.prune(brokerRequest4,
+        new HashSet<>(Arrays.asList(segment0, segment1, segment2))), new HashSet<>(Arrays.asList(segment1, segment2)));
   }
 
   @Test(dataProvider = "compilerProvider")
@@ -680,11 +740,18 @@ public class SegmentPrunerTest extends ControllerTest {
   private void setSegmentZKPartitionMetadata(String tableNameWithType, String segment, String partitionFunction,
       int numPartitions, int partitionId) {
     SegmentZKMetadata segmentZKMetadata = new SegmentZKMetadata(segment);
-    segmentZKMetadata.setPartitionMetadata(new SegmentPartitionMetadata(Collections.singletonMap(PARTITION_COLUMN,
+    segmentZKMetadata.setPartitionMetadata(new SegmentPartitionMetadata(Collections.singletonMap(PARTITION_COLUMN_1,
         new ColumnPartitionMetadata(partitionFunction, numPartitions, Collections.singleton(partitionId), null))));
     ZKMetadataProvider.setSegmentZKMetadata(_propertyStore, tableNameWithType, segmentZKMetadata);
   }
 
+  private void setSegmentZKPartitionMetadata(String tableNameWithType, String segment,
+      Map<String, ColumnPartitionMetadata> columnPartitionMap) {
+    SegmentZKMetadata segmentZKMetadata = new SegmentZKMetadata(segment);
+    segmentZKMetadata.setPartitionMetadata(new SegmentPartitionMetadata(columnPartitionMap));
+    ZKMetadataProvider.setSegmentZKMetadata(_propertyStore, tableNameWithType, segmentZKMetadata);
+  }
+
   private void setSegmentZKTimeRangeMetadata(String tableNameWithType, String segment, long startTime, long endTime,
       TimeUnit unit) {
     SegmentZKMetadata segmentZKMetadata = new SegmentZKMetadata(segment);
diff --git a/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTest.java b/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTest.java
index 970765d3ea..09ae9ef662 100644
--- a/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTest.java
+++ b/pinot-integration-test-base/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTest.java
@@ -260,6 +260,11 @@ public abstract class BaseClusterIntegrationTest extends ClusterTest {
     return DEFAULT_NULL_HANDLING_ENABLED;
   }
 
+  @Nullable
+  protected SegmentPartitionConfig getSegmentPartitionConfig() {
+    return null;
+  }
+
   /**
    * The following methods are based on the getters. Override the getters for non-default settings before calling these
    * methods.
@@ -294,7 +299,7 @@ public abstract class BaseClusterIntegrationTest extends ClusterTest {
         .setFieldConfigList(getFieldConfigs()).setNumReplicas(getNumReplicas()).setSegmentVersion(getSegmentVersion())
         .setLoadMode(getLoadMode()).setTaskConfig(getTaskConfig()).setBrokerTenant(getBrokerTenant())
         .setServerTenant(getServerTenant()).setIngestionConfig(getIngestionConfig())
-        .setNullHandlingEnabled(getNullHandlingEnabled()).build();
+        .setNullHandlingEnabled(getNullHandlingEnabled()).setSegmentPartitionConfig(getSegmentPartitionConfig()).build();
   }
 
   /**
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MergeRollupMinionClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MergeRollupMinionClusterIntegrationTest.java
index 1f3d18910b..70a063153c 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MergeRollupMinionClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MergeRollupMinionClusterIntegrationTest.java
@@ -28,6 +28,7 @@ import java.util.Map;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import javax.annotation.Nullable;
 import org.apache.commons.io.FileUtils;
 import org.apache.helix.task.TaskState;
 import org.apache.pinot.common.lineage.SegmentLineageAccessHelper;
@@ -43,6 +44,8 @@ import org.apache.pinot.core.common.MinionConstants;
 import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
 import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
 import org.apache.pinot.segment.spi.creator.SegmentIndexCreationDriver;
+import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
+import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.TableTaskConfig;
 import org.apache.pinot.spi.config.table.TableType;
@@ -97,7 +100,8 @@ public class MergeRollupMinionClusterIntegrationTest extends BaseClusterIntegrat
     TableConfig singleLevelConcatTableConfig =
         createOfflineTableConfig(SINGLE_LEVEL_CONCAT_TEST_TABLE, getSingleLevelConcatTaskConfig());
     TableConfig singleLevelRollupTableConfig =
-        createOfflineTableConfig(SINGLE_LEVEL_ROLLUP_TEST_TABLE, getSingleLevelRollupTaskConfig());
+        createOfflineTableConfig(SINGLE_LEVEL_ROLLUP_TEST_TABLE, getSingleLevelRollupTaskConfig(),
+            getMultiColumnsSegmentPartitionConfig());
     TableConfig multiLevelConcatTableConfig =
         createOfflineTableConfig(MULTI_LEVEL_CONCAT_TEST_TABLE, getMultiLevelConcatTaskConfig());
     addTableConfig(singleLevelConcatTableConfig);
@@ -131,6 +135,11 @@ public class MergeRollupMinionClusterIntegrationTest extends BaseClusterIntegrat
   }
 
   private TableConfig createOfflineTableConfig(String tableName, TableTaskConfig taskConfig) {
+    return createOfflineTableConfig(tableName, taskConfig, null);
+  }
+
+  private TableConfig createOfflineTableConfig(String tableName, TableTaskConfig taskConfig,
+      @Nullable SegmentPartitionConfig partitionConfig) {
     return new TableConfigBuilder(TableType.OFFLINE).setTableName(tableName).setSchemaName(getSchemaName())
         .setTimeColumnName(getTimeColumnName()).setSortedColumn(getSortedColumn())
         .setInvertedIndexColumns(getInvertedIndexColumns()).setNoDictionaryColumns(getNoDictionaryColumns())
@@ -138,7 +147,7 @@ public class MergeRollupMinionClusterIntegrationTest extends BaseClusterIntegrat
         .setFieldConfigList(getFieldConfigs()).setNumReplicas(getNumReplicas()).setSegmentVersion(getSegmentVersion())
         .setLoadMode(getLoadMode()).setTaskConfig(taskConfig).setBrokerTenant(getBrokerTenant())
         .setServerTenant(getServerTenant()).setIngestionConfig(getIngestionConfig())
-        .setNullHandlingEnabled(getNullHandlingEnabled()).build();
+        .setNullHandlingEnabled(getNullHandlingEnabled()).setSegmentPartitionConfig(partitionConfig).build();
   }
 
   private TableTaskConfig getSingleLevelConcatTaskConfig() {
@@ -178,6 +187,15 @@ public class MergeRollupMinionClusterIntegrationTest extends BaseClusterIntegrat
     return new TableTaskConfig(Collections.singletonMap(MinionConstants.MergeRollupTask.TASK_TYPE, tableTaskConfigs));
   }
 
+  private SegmentPartitionConfig getMultiColumnsSegmentPartitionConfig() {
+    Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
+    ColumnPartitionConfig columnOneConfig = new ColumnPartitionConfig("murmur", 1);
+    columnPartitionConfigMap.put("AirlineID", columnOneConfig);
+    ColumnPartitionConfig columnTwoConfig = new ColumnPartitionConfig("murmur", 1);
+    columnPartitionConfigMap.put("Month", columnTwoConfig);
+    return new SegmentPartitionConfig(columnPartitionConfigMap);
+  }
+
   private static void buildSegmentsFromAvroWithPostfix(List<File> avroFiles, TableConfig tableConfig,
       org.apache.pinot.spi.data.Schema schema, int baseSegmentIndex, File segmentDir, File tarDir, String postfix)
       throws Exception {
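
With the config above, SINGLE_LEVEL_ROLLUP_TEST_TABLE is the only table in this test carrying a segment partition config: AirlineID and Month each use murmur with a single partition, so every segment maps to the same composite partition. That is presumably just enough to exercise the multi-column path in the merge/rollup task generator without changing the expected task counts.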
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RealtimeToOfflineSegmentsMinionClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RealtimeToOfflineSegmentsMinionClusterIntegrationTest.java
index 242195d681..4379dd88ec 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RealtimeToOfflineSegmentsMinionClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/RealtimeToOfflineSegmentsMinionClusterIntegrationTest.java
@@ -22,6 +22,8 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.commons.io.FileUtils;
 import org.apache.helix.ZNRecord;
 import org.apache.helix.task.TaskState;
@@ -32,6 +34,8 @@ import org.apache.pinot.controller.helix.core.PinotHelixResourceManager;
 import org.apache.pinot.controller.helix.core.minion.PinotHelixTaskResourceManager;
 import org.apache.pinot.controller.helix.core.minion.PinotTaskManager;
 import org.apache.pinot.core.common.MinionConstants;
+import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
+import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
 import org.apache.pinot.spi.config.table.TableTaskConfig;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
@@ -99,8 +103,16 @@ public class RealtimeToOfflineSegmentsMinionClusterIntegrationTest extends Realt
     List<SegmentZKMetadata> segmentsZKMetadata = _pinotHelixResourceManager.getSegmentsZKMetadata(_offlineTableName);
     Assert.assertTrue(segmentsZKMetadata.isEmpty());
 
+    // If segment partitioning is configured, the number of offline segments created per task equals the
+    // product of the number of partitions across all partition columns.
+    SegmentPartitionConfig segmentPartitionConfig =
+        getOfflineTableConfig().getIndexingConfig().getSegmentPartitionConfig();
+    int numOfflineSegmentsPerTask =
+        segmentPartitionConfig != null ? segmentPartitionConfig.getColumnPartitionMap().values().stream()
+            .map(ColumnPartitionConfig::getNumPartitions).reduce((a, b) -> a * b)
+            .orElseThrow(() -> new RuntimeException("Expected accumulated result but not found.")) : 1;
+
     long expectedWatermark = _dataSmallestTimeMs + 86400000;
-    int numOfflineSegments = 0;
     for (int i = 0; i < 3; i++) {
       // Schedule task
       Assert.assertNotNull(_taskManager.scheduleTasks().get(MinionConstants.RealtimeToOfflineSegmentsTask.TASK_TYPE));
@@ -113,12 +125,21 @@ public class RealtimeToOfflineSegmentsMinionClusterIntegrationTest extends Realt
       waitForTaskToComplete(expectedWatermark);
       // check segment is in offline
       segmentsZKMetadata = _pinotHelixResourceManager.getSegmentsZKMetadata(_offlineTableName);
-      numOfflineSegments++;
-      Assert.assertEquals(segmentsZKMetadata.size(), numOfflineSegments);
-      long expectedOfflineSegmentTimeMs = expectedWatermark - 86400000;
-      Assert.assertEquals(segmentsZKMetadata.get(i).getStartTimeMs(), expectedOfflineSegmentTimeMs);
-      Assert.assertEquals(segmentsZKMetadata.get(i).getEndTimeMs(), expectedOfflineSegmentTimeMs);
+      Assert.assertEquals(segmentsZKMetadata.size(), (numOfflineSegmentsPerTask * (i + 1)));
 
+      long expectedOfflineSegmentTimeMs = expectedWatermark - 86400000;
+      for (int j = (numOfflineSegmentsPerTask * i); j < segmentsZKMetadata.size(); j++) {
+        SegmentZKMetadata segmentZKMetadata = segmentsZKMetadata.get(j);
+        Assert.assertEquals(segmentZKMetadata.getStartTimeMs(), expectedOfflineSegmentTimeMs);
+        Assert.assertEquals(segmentZKMetadata.getEndTimeMs(), expectedOfflineSegmentTimeMs);
+        if (segmentPartitionConfig != null) {
+          Assert.assertEquals(segmentZKMetadata.getPartitionMetadata().getColumnPartitionMap().keySet(),
+              segmentPartitionConfig.getColumnPartitionMap().keySet());
+          for (String partitionColumn : segmentPartitionConfig.getColumnPartitionMap().keySet()) {
+            Assert.assertEquals(segmentZKMetadata.getPartitionMetadata().getPartitions(partitionColumn).size(), 1);
+          }
+        }
+      }
       expectedWatermark += 86400000;
     }
     this.testHardcodedQueries();
@@ -130,6 +151,17 @@ public class RealtimeToOfflineSegmentsMinionClusterIntegrationTest extends Realt
     verifyTableDelete(_realtimeTableName);
   }
 
+  @Nullable
+  @Override
+  protected SegmentPartitionConfig getSegmentPartitionConfig() {
+    Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
+    ColumnPartitionConfig columnOneConfig = new ColumnPartitionConfig("murmur", 3);
+    columnPartitionConfigMap.put("AirlineID", columnOneConfig);
+    ColumnPartitionConfig columnTwoConfig = new ColumnPartitionConfig("hashcode", 2);
+    columnPartitionConfigMap.put("OriginAirportID", columnTwoConfig);
+    return new SegmentPartitionConfig(columnPartitionConfigMap);
+  }
+
   protected void verifyTableDelete(String tableNameWithType) {
     TestUtils.waitForCondition(input -> {
       // Check if the task metadata is cleaned up
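
Worked example of the product above: with murmur(3) on AirlineID and hashcode(2) on OriginAirportID, numOfflineSegmentsPerTask = 3 x 2 = 6, so each scheduled RealtimeToOfflineSegmentsTask is expected to add six offline segments, one per (AirlineID partition, OriginAirportID partition) pair, all sharing the same day's start/end time.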
diff --git a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtils.java b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtils.java
index 340b143763..7c027af227 100644
--- a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtils.java
+++ b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtils.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.plugin.minion.tasks;
 
 import com.google.common.base.Preconditions;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -63,9 +64,8 @@ public class MergeTaskUtils {
       return null;
     }
     DateTimeFieldSpec fieldSpec = schema.getSpecForTimeColumn(timeColumn);
-    Preconditions
-        .checkState(fieldSpec != null, "No valid spec found for time column: %s in schema for table: %s", timeColumn,
-            tableConfig.getTableName());
+    Preconditions.checkState(fieldSpec != null, "No valid spec found for time column: %s in schema for table: %s",
+        timeColumn, tableConfig.getTableName());
 
     TimeHandlerConfig.Builder timeHandlerConfigBuilder = new TimeHandlerConfig.Builder(TimeHandler.Type.EPOCH);
 
@@ -97,17 +97,19 @@ public class MergeTaskUtils {
     if (segmentPartitionConfig == null) {
       return Collections.emptyList();
     }
+    List<PartitionerConfig> partitionerConfigs = new ArrayList<>();
     Map<String, ColumnPartitionConfig> columnPartitionMap = segmentPartitionConfig.getColumnPartitionMap();
-    Preconditions.checkState(columnPartitionMap.size() == 1, "Cannot partition on multiple columns for table: %s",
-        tableConfig.getTableName());
-    Map.Entry<String, ColumnPartitionConfig> entry = columnPartitionMap.entrySet().iterator().next();
-    String partitionColumn = entry.getKey();
-    Preconditions.checkState(schema.hasColumn(partitionColumn),
-        "Partition column: %s does not exist in the schema for table: %s", partitionColumn, tableConfig.getTableName());
-    PartitionerConfig partitionerConfig =
-        new PartitionerConfig.Builder().setPartitionerType(PartitionerFactory.PartitionerType.TABLE_PARTITION_CONFIG)
-            .setColumnName(partitionColumn).setColumnPartitionConfig(entry.getValue()).build();
-    return Collections.singletonList(partitionerConfig);
+    for (Map.Entry<String, ColumnPartitionConfig> entry : columnPartitionMap.entrySet()) {
+      String partitionColumn = entry.getKey();
+      Preconditions.checkState(schema.hasColumn(partitionColumn),
+          "Partition column: %s does not exist in the schema for table: %s", partitionColumn,
+          tableConfig.getTableName());
+      PartitionerConfig partitionerConfig =
+          new PartitionerConfig.Builder().setPartitionerType(PartitionerFactory.PartitionerType.TABLE_PARTITION_CONFIG)
+              .setColumnName(partitionColumn).setColumnPartitionConfig(entry.getValue()).build();
+      partitionerConfigs.add(partitionerConfig);
+    }
+    return partitionerConfigs;
   }
 
   /**
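
In effect, a two-column partition config now yields two PartitionerConfig entries from getPartitionerConfigs, one per column, so segments produced by merge tasks should come out partitioned on all configured columns instead of failing the old single-column precondition.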
diff --git a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGenerator.java b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGenerator.java
index de05a9c908..8b5b5baf3d 100644
--- a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGenerator.java
+++ b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/main/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGenerator.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.plugin.minion.tasks.mergerollup;
 
-import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -353,33 +352,41 @@ public class MergeRollupTaskGenerator extends BaseTaskGenerator {
                     mergeConfigs, taskConfigs));
           }
         } else {
-          // For partitioned table, schedule separate tasks for each partition
+          // For a partitioned table, schedule separate tasks per partition id, where the partition id is constructed
+          // from the partitions of all partition columns. A segment is grouped this way only when its partition
+          // columns exactly match the partition columns in the table config and its metadata records exactly one
+          // partition per column. Segments that do not meet these conditions are treated as outlier segments, and
+          // additional tasks are generated for them.
           Map<String, ColumnPartitionConfig> columnPartitionMap = segmentPartitionConfig.getColumnPartitionMap();
-          Preconditions.checkState(columnPartitionMap.size() == 1, "Cannot partition on multiple columns for table: %s",
-              tableConfig.getTableName());
-          Map.Entry<String, ColumnPartitionConfig> partitionEntry = columnPartitionMap.entrySet().iterator().next();
-          String partitionColumn = partitionEntry.getKey();
-
+          List<String> partitionColumns = new ArrayList<>(columnPartitionMap.keySet());
           for (List<SegmentZKMetadata> selectedSegmentsPerBucket : selectedSegmentsForAllBuckets) {
-            Map<Integer, List<SegmentZKMetadata>> partitionToSegments = new HashMap<>();
-            // Handle segments that have multiple partitions or no partition info
+            Map<List<Integer>, List<SegmentZKMetadata>> partitionToSegments = new HashMap<>();
             List<SegmentZKMetadata> outlierSegments = new ArrayList<>();
             for (SegmentZKMetadata selectedSegment : selectedSegmentsPerBucket) {
               SegmentPartitionMetadata segmentPartitionMetadata = selectedSegment.getPartitionMetadata();
-              if (segmentPartitionMetadata == null
-                  || segmentPartitionMetadata.getPartitions(partitionColumn).size() != 1) {
+              List<Integer> partitions = new ArrayList<>();
+              if (segmentPartitionMetadata != null && columnPartitionMap.keySet()
+                  .equals(segmentPartitionMetadata.getColumnPartitionMap().keySet())) {
+                for (String partitionColumn : partitionColumns) {
+                  if (segmentPartitionMetadata.getPartitions(partitionColumn).size() == 1) {
+                    partitions.add(segmentPartitionMetadata.getPartitions(partitionColumn).iterator().next());
+                  } else {
+                    partitions.clear();
+                    break;
+                  }
+                }
+              }
+              if (partitions.isEmpty()) {
                 outlierSegments.add(selectedSegment);
               } else {
-                int partition = segmentPartitionMetadata.getPartitions(partitionColumn).iterator().next();
-                partitionToSegments.computeIfAbsent(partition, k -> new ArrayList<>()).add(selectedSegment);
+                partitionToSegments.computeIfAbsent(partitions, k -> new ArrayList<>()).add(selectedSegment);
               }
             }
 
-            for (Map.Entry<Integer, List<SegmentZKMetadata>> partitionToSegmentsEntry
-                : partitionToSegments.entrySet()) {
+            for (List<SegmentZKMetadata> partitionedSegments : partitionToSegments.values()) {
               pinotTaskConfigsForTable.addAll(
-                  createPinotTaskConfigs(partitionToSegmentsEntry.getValue(), offlineTableName, maxNumRecordsPerTask,
-                      mergeLevel, mergeConfigs, taskConfigs));
+                  createPinotTaskConfigs(partitionedSegments, offlineTableName, maxNumRecordsPerTask, mergeLevel,
+                      mergeConfigs, taskConfigs));
             }
 
             if (!outlierSegments.isEmpty()) {
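The grouping key here is the ordered list of per-column partition ids, which works as a HashMap key because java.util.List defines value-based equals/hashCode. A hedged sketch of the bucketing with made-up segment names (java.util imports omitted):

    Map<List<Integer>, List<String>> partitionToSegments = new HashMap<>();
    // Segments whose metadata yields the same per-column partition ids share a bucket (and a task)
    partitionToSegments.computeIfAbsent(Arrays.asList(0, 2), k -> new ArrayList<>()).add("seg_a");
    partitionToSegments.computeIfAbsent(Arrays.asList(0, 2), k -> new ArrayList<>()).add("seg_b");
    partitionToSegments.computeIfAbsent(Arrays.asList(1, 2), k -> new ArrayList<>()).add("seg_c");
    // Two buckets -> two merge tasks: [0, 2] covers seg_a and seg_b; [1, 2] covers seg_c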
diff --git a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtilsTest.java b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtilsTest.java
index 61be5686af..bc6b897211 100644
--- a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtilsTest.java
+++ b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/MergeTaskUtilsTest.java
@@ -102,6 +102,20 @@ public class MergeTaskUtilsTest {
     assertEquals(columnPartitionConfig.getFunctionName(), "murmur");
     assertEquals(columnPartitionConfig.getNumPartitions(), 10);
 
+    // Table with multiple partition columns.
+    Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
+    columnPartitionConfigMap.put("memberId", new ColumnPartitionConfig("murmur", 10));
+    columnPartitionConfigMap.put("memberName", new ColumnPartitionConfig("HashCode", 5));
+    TableConfig tableConfigWithMultiplePartitionColumns =
+        new TableConfigBuilder(TableType.OFFLINE).setTableName("myTable")
+            .setSegmentPartitionConfig(new SegmentPartitionConfig(columnPartitionConfigMap)).build();
+    Schema schemaWithMultipleColumns = new Schema.SchemaBuilder().addSingleValueDimension("memberId", DataType.LONG)
+        .addSingleValueDimension("memberName", DataType.STRING).build();
+    partitionerConfigs =
+        MergeTaskUtils.getPartitionerConfigs(tableConfigWithMultiplePartitionColumns, schemaWithMultipleColumns,
+            taskConfig);
+    assertEquals(partitionerConfigs.size(), 2);
+
     // No partition column in table config
     TableConfig tableConfigWithoutPartitionColumn =
         new TableConfigBuilder(TableType.OFFLINE).setTableName("myTable").build();
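For the multi-column case added above, the returned entries can be checked further; a sketch assuming PartitionerConfig exposes getters mirroring its builder (getColumnName() is an assumption here, not asserted by this change):

    Set<String> partitionColumns = new HashSet<>();
    for (PartitionerConfig partitionerConfig : partitionerConfigs) {
      partitionColumns.add(partitionerConfig.getColumnName());
    }
    assertEquals(partitionColumns, new HashSet<>(Arrays.asList("memberId", "memberName")));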
diff --git a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGeneratorTest.java b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGeneratorTest.java
index 5c34c1b8d8..5e503729fc 100644
--- a/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGeneratorTest.java
+++ b/pinot-plugins/pinot-minion-tasks/pinot-minion-builtin-tasks/src/test/java/org/apache/pinot/plugin/minion/tasks/mergerollup/MergeRollupTaskGeneratorTest.java
@@ -107,8 +107,15 @@ public class MergeRollupTaskGeneratorTest {
   private void checkPinotTaskConfig(Map<String, String> pinotTaskConfig, String segments, String mergeLevel,
       String mergeType, String partitionBucketTimePeriod, String roundBucketTimePeriod,
       String maxNumRecordsPerSegments) {
-    assertEquals(pinotTaskConfig.get(MinionConstants.TABLE_NAME_KEY), OFFLINE_TABLE_NAME);
     assertEquals(pinotTaskConfig.get(MinionConstants.SEGMENT_NAME_KEY), segments);
+    checkPinotTaskConfig(pinotTaskConfig, mergeLevel, mergeType, partitionBucketTimePeriod, roundBucketTimePeriod,
+        maxNumRecordsPerSegments);
+  }
+
+  private void checkPinotTaskConfig(Map<String, String> pinotTaskConfig, String mergeLevel,
+      String mergeType, String partitionBucketTimePeriod, String roundBucketTimePeriod,
+      String maxNumRecordsPerSegments) {
+    assertEquals(pinotTaskConfig.get(MinionConstants.TABLE_NAME_KEY), OFFLINE_TABLE_NAME);
     assertTrue("true".equalsIgnoreCase(pinotTaskConfig.get(MinionConstants.ENABLE_REPLACE_SEGMENTS_KEY)));
     assertEquals(pinotTaskConfig.get(MinionConstants.MergeRollupTask.MERGE_LEVEL_KEY), mergeLevel);
     assertEquals(pinotTaskConfig.get(MinionConstants.MergeRollupTask.MERGE_TYPE_KEY), mergeType);
@@ -433,10 +440,24 @@ public class MergeRollupTaskGeneratorTest {
     generator.init(mockClusterInfoProvide);
     List<PinotTaskConfig> pinotTaskConfigs = generator.generateTasks(Lists.newArrayList(offlineTableConfig));
     assertEquals(pinotTaskConfigs.size(), 2);
-    checkPinotTaskConfig(pinotTaskConfigs.get(0).getConfigs(), segmentName1 + "," + segmentName2, DAILY, "concat", "1d",
-        null, "1000000");
-    checkPinotTaskConfig(pinotTaskConfigs.get(1).getConfigs(), segmentName3 + "," + segmentName4, DAILY, "concat", "1d",
-        null, "1000000");
+
+    String partitionedSegmentsGroup1 = segmentName1 + "," + segmentName2;
+    String partitionedSegmentsGroup2 = segmentName3 + "," + segmentName4;
+    boolean isPartitionedSegmentsGroup1Seen = false;
+    boolean isPartitionedSegmentsGroup2Seen = false;
+    for (PinotTaskConfig pinotTaskConfig : pinotTaskConfigs) {
+      if (!isPartitionedSegmentsGroup1Seen) {
+        isPartitionedSegmentsGroup1Seen =
+            pinotTaskConfig.getConfigs().get(MinionConstants.SEGMENT_NAME_KEY).equals(partitionedSegmentsGroup1);
+      }
+      if (!isPartitionedSegmentsGroup2Seen) {
+        isPartitionedSegmentsGroup2Seen =
+            pinotTaskConfig.getConfigs().get(MinionConstants.SEGMENT_NAME_KEY).equals(partitionedSegmentsGroup2);
+      }
+      assertTrue(isPartitionedSegmentsGroup1Seen || isPartitionedSegmentsGroup2Seen);
+      checkPinotTaskConfig(pinotTaskConfig.getConfigs(), DAILY, "concat", "1d", null, "1000000");
+    }
+    assertTrue(isPartitionedSegmentsGroup1Seen && isPartitionedSegmentsGroup2Seen);
 
     // With numMaxRecordsPerTask constraints
     tableTaskConfigs.put("daily.maxNumRecordsPerTask", "5000000");
@@ -447,10 +468,27 @@ public class MergeRollupTaskGeneratorTest {
 
     pinotTaskConfigs = generator.generateTasks(Lists.newArrayList(offlineTableConfig));
     assertEquals(pinotTaskConfigs.size(), 3);
-    checkPinotTaskConfig(pinotTaskConfigs.get(0).getConfigs(), segmentName1 + "," + segmentName2, DAILY, "concat", "1d",
-        null, "1000000");
-    checkPinotTaskConfig(pinotTaskConfigs.get(1).getConfigs(), segmentName3, DAILY, "concat", "1d", null, "1000000");
-    checkPinotTaskConfig(pinotTaskConfigs.get(2).getConfigs(), segmentName4, DAILY, "concat", "1d", null, "1000000");
+
+    isPartitionedSegmentsGroup1Seen = false;
+    isPartitionedSegmentsGroup2Seen = false;
+    boolean isPartitionedSegmentsGroup3Seen = false;
+    for (PinotTaskConfig pinotTaskConfig : pinotTaskConfigs) {
+      if (!isPartitionedSegmentsGroup1Seen) {
+        isPartitionedSegmentsGroup1Seen =
+            pinotTaskConfig.getConfigs().get(MinionConstants.SEGMENT_NAME_KEY).equals(partitionedSegmentsGroup1);
+      }
+      if (!isPartitionedSegmentsGroup2Seen) {
+        isPartitionedSegmentsGroup2Seen =
+            pinotTaskConfig.getConfigs().get(MinionConstants.SEGMENT_NAME_KEY).equals(segmentName3);
+      }
+      if (!isPartitionedSegmentsGroup3Seen) {
+        isPartitionedSegmentsGroup3Seen =
+            pinotTaskConfig.getConfigs().get(MinionConstants.SEGMENT_NAME_KEY).equals(segmentName4);
+      }
+      assertTrue(isPartitionedSegmentsGroup1Seen || isPartitionedSegmentsGroup2Seen || isPartitionedSegmentsGroup3Seen);
+      checkPinotTaskConfig(pinotTaskConfig.getConfigs(), DAILY, "concat", "1d", null, "1000000");
+    }
+    assertTrue(isPartitionedSegmentsGroup1Seen && isPartitionedSegmentsGroup2Seen && isPartitionedSegmentsGroup3Seen);
   }
 
   /**
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
index 10d8959727..c5f013a3d8 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
@@ -48,7 +48,6 @@ import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.segment.spi.memory.PinotDataBufferMemoryManager;
 import org.apache.pinot.segment.spi.partition.PartitionFunction;
 import org.apache.pinot.segment.spi.partition.PartitionFunctionFactory;
-import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
 import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.data.FieldSpec;
@@ -80,8 +79,6 @@ public class IntermediateSegment implements MutableSegment {
   private final Schema _schema;
   private final TableConfig _tableConfig;
   private final String _segmentName;
-  private final PartitionFunction _partitionFunction;
-  private final String _partitionColumn;
   private final Map<String, IntermediateIndexContainer> _indexContainerMap = new HashMap<>();
   private final PinotDataBufferMemoryManager _memoryManager;
   private final File _mmapDir;
@@ -103,19 +100,6 @@ public class IntermediateSegment implements MutableSegment {
       }
     }
 
-    SegmentPartitionConfig segmentPartitionConfig = segmentGeneratorConfig.getSegmentPartitionConfig();
-    if (segmentPartitionConfig != null) {
-      Map<String, ColumnPartitionConfig> segmentPartitionConfigColumnPartitionMap =
-          segmentPartitionConfig.getColumnPartitionMap();
-      _partitionColumn = segmentPartitionConfigColumnPartitionMap.keySet().iterator().next();
-      _partitionFunction = PartitionFunctionFactory
-          .getPartitionFunction(segmentPartitionConfig.getFunctionName(_partitionColumn),
-              segmentPartitionConfig.getNumPartitions(_partitionColumn),
-              segmentPartitionConfig.getFunctionConfig(_partitionColumn));
-    } else {
-      _partitionColumn = null;
-      _partitionFunction = null;
-    }
     String outputDir = segmentGeneratorConfig.getOutDir();
     _mmapDir = new File(outputDir, _segmentName + "_mmap_" + UUID.randomUUID().toString());
     _mmapDir.mkdir();
@@ -127,10 +111,13 @@ public class IntermediateSegment implements MutableSegment {
       String column = fieldSpec.getName();
 
       // Partition info
+      SegmentPartitionConfig segmentPartitionConfig = segmentGeneratorConfig.getSegmentPartitionConfig();
       PartitionFunction partitionFunction = null;
       Set<Integer> partitions = null;
-      if (column.equals(_partitionColumn)) {
-        partitionFunction = _partitionFunction;
+      if (segmentPartitionConfig != null && segmentPartitionConfig.getColumnPartitionMap().containsKey(column)) {
+        partitionFunction =
+            PartitionFunctionFactory.getPartitionFunction(segmentPartitionConfig.getFunctionName(column),
+                segmentPartitionConfig.getNumPartitions(column), segmentPartitionConfig.getFunctionConfig(column));
         partitions = new HashSet<>();
         partitions.add(segmentGeneratorConfig.getSequenceId());
       }
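Net effect: the old code resolved a single partition function from the map's first entry (keySet().iterator().next()), so with several partition columns only one ever received partition metadata; now each configured column is resolved inside the field-spec loop. A minimal sketch of the per-column resolution, using the factory call shown above with the config values from the test below:

    Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
    columnPartitionConfigMap.put("column7", new ColumnPartitionConfig("Murmur", 1));
    columnPartitionConfigMap.put("column11", new ColumnPartitionConfig("HashCode", 1));
    SegmentPartitionConfig partitionConfig = new SegmentPartitionConfig(columnPartitionConfigMap);
    for (String column : partitionConfig.getColumnPartitionMap().keySet()) {
      // Each configured column now gets its own function (Murmur for column7, HashCode for column11)
      PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(
          partitionConfig.getFunctionName(column), partitionConfig.getNumPartitions(column),
          partitionConfig.getFunctionConfig(column));
    }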
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegmentTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegmentTest.java
index ddc20bc0c0..acf7887574 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegmentTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegmentTest.java
@@ -22,6 +22,8 @@ import java.io.File;
 import java.io.IOException;
 import java.net.URL;
 import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import org.apache.commons.io.FileUtils;
@@ -30,12 +32,15 @@ import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoa
 import org.apache.pinot.segment.local.segment.creator.SegmentTestUtils;
 import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
 import org.apache.pinot.segment.local.segment.readers.IntermediateSegmentRecordReader;
+import org.apache.pinot.segment.spi.ColumnMetadata;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
 import org.apache.pinot.segment.spi.creator.SegmentIndexCreationDriver;
 import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.segment.spi.index.reader.Dictionary;
 import org.apache.pinot.segment.spi.index.reader.InvertedIndexReader;
+import org.apache.pinot.spi.config.table.ColumnPartitionConfig;
+import org.apache.pinot.spi.config.table.SegmentPartitionConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.data.FieldSpec;
@@ -130,6 +135,21 @@ public class IntermediateSegmentTest {
           assertEquals(actualInvertedIndexReader.getDocIds(j), expectedInvertedIndexReader.getDocIds(j));
         }
       }
+
+      // Check the partition metadata.
+      SegmentPartitionConfig segmentPartitionConfig = tableConfig.getIndexingConfig().getSegmentPartitionConfig();
+      if (segmentPartitionConfig != null && segmentPartitionConfig.getColumnPartitionMap().containsKey(column)) {
+        ColumnMetadata columnMetadata =
+            segmentFromIntermediateSegment.getSegmentMetadata().getColumnMetadataFor(column);
+        assertNotNull(columnMetadata.getPartitionFunction());
+        assertEquals(columnMetadata.getPartitionFunction().getName(), segmentPartitionConfig.getFunctionName(column));
+        assertEquals(columnMetadata.getPartitionFunction().getNumPartitions(),
+            segmentPartitionConfig.getNumPartitions(column));
+        assertEquals(columnMetadata.getPartitionFunction().getFunctionConfig(),
+            segmentPartitionConfig.getFunctionConfig(column));
+        assertNotNull(columnMetadata.getPartitions());
+        assertEquals(columnMetadata.getPartitions().size(), 1);
+      }
     }
   }
 
@@ -211,10 +231,20 @@ public class IntermediateSegmentTest {
     if (AVRO_DATA_SV.equals(inputFile)) {
       tableConfig =
           new TableConfigBuilder(TableType.OFFLINE).setTableName("testTable").setTimeColumnName("daysSinceEpoch")
-              .setInvertedIndexColumns(Arrays.asList("column6", "column7", "column11", "column17", "column18")).build();
+              .setInvertedIndexColumns(Arrays.asList("column6", "column7", "column11", "column17", "column18"))
+              .setSegmentPartitionConfig(getSegmentPartitionConfig()).build();
     } else {
       tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName("testTable").build();
     }
     return tableConfig;
   }
+
+  private static SegmentPartitionConfig getSegmentPartitionConfig() {
+    Map<String, ColumnPartitionConfig> columnPartitionConfigMap = new HashMap<>();
+    ColumnPartitionConfig columnOneConfig = new ColumnPartitionConfig("Murmur", 1);
+    columnPartitionConfigMap.put("column7", columnOneConfig);
+    ColumnPartitionConfig columnTwoConfig = new ColumnPartitionConfig("HashCode", 1);
+    columnPartitionConfigMap.put("column11", columnTwoConfig);
+    return new SegmentPartitionConfig(columnPartitionConfigMap);
+  }
 }

