You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2022/07/23 22:28:54 UTC

[pinot] branch master updated: Enhance upsert metadata handling (#9095)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 8fe333119d Enhance upsert metadata handling (#9095)
8fe333119d is described below

commit 8fe333119de79218211a0aad68c31efc13de211e
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Sat Jul 23 15:28:48 2022 -0700

    Enhance upsert metadata handling (#9095)
    
    Make the following enhancements to the upsert metadata manager:
    - Add replace segment support
      - Log an error and emit a metric (`PARTIAL_UPSERT_ROWS_NOT_REPLACED`) for segments that are not fully replaced, which can potentially cause inconsistency between replicas of a partial upsert table
      - Remove the remaining primary keys from the replaced segment immediately so that the new consuming segment is not affected
    - Handle empty segments properly
    - Enhance logging to include the table name, partition id and primary key count
    - Clean up the code and move the upsert-related logic, such as creating the record info iterator, into the metadata manager
    - In `IndexSegment`, replace `getPrimaryKey()` with the more general `getValue()`, which can also be used to read the comparison column
    - Fix the bug of assuming that the first primary key column is the partition column when fetching a segment's partition id
    - Fix the bug of using `byte[]` as the primary key value when the column type is `BYTES`
---
 .../apache/pinot/common/metrics/ServerMeter.java   |   1 +
 .../apache/pinot/common/utils/SegmentUtils.java    |  14 +-
 .../manager/realtime/RealtimeTableDataManager.java | 112 ++++-------
 ...adataAndDictionaryAggregationPlanMakerTest.java |   4 +-
 .../plan/maker/QueryOverrideWithHintsTest.java     |   4 +-
 .../indexsegment/immutable/EmptyIndexSegment.java  |   5 +-
 .../immutable/ImmutableSegmentImpl.java            |  10 +-
 .../indexsegment/mutable/IntermediateSegment.java  |  15 +-
 .../indexsegment/mutable/MutableSegmentImpl.java   |  45 +++--
 .../segment/readers/PinotSegmentRecordReader.java  |  15 +-
 .../upsert/PartitionUpsertMetadataManager.java     | 218 ++++++++++++++++-----
 .../local/upsert/TableUpsertMetadataManager.java   |  23 +--
 .../MutableSegmentImplUpsertComparisonColTest.java |  19 +-
 .../mutable/MutableSegmentImplUpsertTest.java      |  15 +-
 .../upsert/PartitionUpsertMetadataManagerTest.java | 156 +++++++--------
 .../org/apache/pinot/segment/spi/IndexSegment.java |   9 +-
 16 files changed, 364 insertions(+), 301 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java b/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
index 776a4e1aae..21133fddc8 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
@@ -42,6 +42,7 @@ public enum ServerMeter implements AbstractMetrics.Meter {
   REALTIME_PARTITION_MISMATCH("mismatch", false),
   REALTIME_DEDUP_DROPPED("rows", false),
   PARTIAL_UPSERT_OUT_OF_ORDER("rows", false),
+  PARTIAL_UPSERT_ROWS_NOT_REPLACED("rows", false),
   ROWS_WITH_ERRORS("rows", false),
   LLC_CONTROLLER_RESPONSE_NOT_SENT("messages", true),
   LLC_CONTROLLER_RESPONSE_COMMIT("messages", true),
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/SegmentUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/SegmentUtils.java
index ad0f49cc2b..b458f5b511 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/SegmentUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/SegmentUtils.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.common.utils;
 
 import com.google.common.base.Preconditions;
+import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.helix.HelixManager;
 import org.apache.pinot.common.metadata.ZKMetadataProvider;
@@ -37,7 +38,7 @@ public class SegmentUtils {
   // path.
   @Nullable
   public static Integer getRealtimeSegmentPartitionId(String segmentName, String realtimeTableName,
-      HelixManager helixManager, String partitionColumn) {
+      HelixManager helixManager, @Nullable String partitionColumn) {
     // A fast path if the segmentName is an LLC segment name: get the partition id from the name directly
     LLCSegmentName llcSegmentName = LLCSegmentName.of(segmentName);
     if (llcSegmentName != null) {
@@ -50,8 +51,15 @@ public class SegmentUtils {
         "Failed to find segment ZK metadata for segment: %s of table: %s", segmentName, realtimeTableName);
     SegmentPartitionMetadata segmentPartitionMetadata = segmentZKMetadata.getPartitionMetadata();
     if (segmentPartitionMetadata != null) {
-      ColumnPartitionMetadata columnPartitionMetadata =
-          segmentPartitionMetadata.getColumnPartitionMap().get(partitionColumn);
+      Map<String, ColumnPartitionMetadata> columnPartitionMap = segmentPartitionMetadata.getColumnPartitionMap();
+      ColumnPartitionMetadata columnPartitionMetadata = null;
+      if (partitionColumn != null) {
+        columnPartitionMetadata = columnPartitionMap.get(partitionColumn);
+      } else {
+        if (columnPartitionMap.size() == 1) {
+          columnPartitionMetadata = columnPartitionMap.values().iterator().next();
+        }
+      }
       if (columnPartitionMetadata != null && columnPartitionMetadata.getPartitions().size() == 1) {
         return columnPartitionMetadata.getPartitions().iterator().next();
       }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java b/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
index 4fa98e5720..0a09ff4087 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
@@ -22,8 +22,6 @@ import com.google.common.base.Preconditions;
 import java.io.File;
 import java.io.IOException;
 import java.net.URI;
-import java.util.HashMap;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
@@ -48,6 +46,7 @@ import org.apache.pinot.common.utils.SegmentUtils;
 import org.apache.pinot.common.utils.TarGzCompressionUtils;
 import org.apache.pinot.common.utils.fetcher.SegmentFetcherFactory;
 import org.apache.pinot.core.data.manager.BaseTableDataManager;
+import org.apache.pinot.core.data.manager.offline.ImmutableSegmentDataManager;
 import org.apache.pinot.core.util.PeerServerSegmentFinder;
 import org.apache.pinot.segment.local.data.manager.SegmentDataManager;
 import org.apache.pinot.segment.local.dedup.PartitionDedupMetadataManager;
@@ -57,24 +56,20 @@ import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoa
 import org.apache.pinot.segment.local.realtime.impl.RealtimeSegmentStatsHistory;
 import org.apache.pinot.segment.local.segment.index.loader.IndexLoadingConfig;
 import org.apache.pinot.segment.local.segment.index.loader.LoaderUtils;
-import org.apache.pinot.segment.local.segment.readers.PinotSegmentColumnReader;
 import org.apache.pinot.segment.local.segment.virtualcolumn.VirtualColumnProviderFactory;
 import org.apache.pinot.segment.local.upsert.PartialUpsertHandler;
 import org.apache.pinot.segment.local.upsert.PartitionUpsertMetadataManager;
 import org.apache.pinot.segment.local.upsert.TableUpsertMetadataManager;
-import org.apache.pinot.segment.local.utils.RecordInfo;
 import org.apache.pinot.segment.local.utils.SchemaUtils;
 import org.apache.pinot.segment.local.utils.tablestate.TableStateUtils;
 import org.apache.pinot.segment.spi.ImmutableSegment;
-import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
+import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.spi.config.table.DedupConfig;
 import org.apache.pinot.spi.config.table.IndexingConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.UpsertConfig;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.data.Schema;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
-import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.CommonConstants.Segment.Realtime.Status;
 
@@ -119,8 +114,6 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
 
   private TableDedupMetadataManager _tableDedupMetadataManager;
   private TableUpsertMetadataManager _tableUpsertMetadataManager;
-  private List<String> _primaryKeyColumns;
-  private String _upsertComparisonColumn;
 
   public RealtimeTableDataManager(Semaphore segmentBuildSemaphore) {
     _segmentBuildSemaphore = segmentBuildSemaphore;
@@ -134,9 +127,8 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
     try {
       _statsHistory = RealtimeSegmentStatsHistory.deserialzeFrom(statsFile);
     } catch (IOException | ClassNotFoundException e) {
-      _logger
-          .error("Error reading history object for table {} from {}", _tableNameWithType, statsFile.getAbsolutePath(),
-              e);
+      _logger.error("Error reading history object for table {} from {}", _tableNameWithType,
+          statsFile.getAbsolutePath(), e);
       File savedFile = new File(_tableDataDir, STATS_FILE_NAME + "." + UUID.randomUUID());
       try {
         FileUtils.moveFile(statsFile, savedFile);
@@ -182,10 +174,10 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
       Schema schema = ZKMetadataProvider.getTableSchema(_propertyStore, _tableNameWithType);
       Preconditions.checkState(schema != null, "Failed to find schema for table: %s", _tableNameWithType);
 
-      _primaryKeyColumns = schema.getPrimaryKeyColumns();
-      Preconditions.checkState(!CollectionUtils.isEmpty(_primaryKeyColumns),
+      List<String> primaryKeyColumns = schema.getPrimaryKeyColumns();
+      Preconditions.checkState(!CollectionUtils.isEmpty(primaryKeyColumns),
           "Primary key columns must be configured for dedup");
-      _tableDedupMetadataManager = new TableDedupMetadataManager(_tableNameWithType, _primaryKeyColumns, _serverMetrics,
+      _tableDedupMetadataManager = new TableDedupMetadataManager(_tableNameWithType, primaryKeyColumns, _serverMetrics,
           dedupConfig.getHashFunction());
     }
 
@@ -196,24 +188,25 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
       Schema schema = ZKMetadataProvider.getTableSchema(_propertyStore, _tableNameWithType);
       Preconditions.checkState(schema != null, "Failed to find schema for table: %s", _tableNameWithType);
 
-      _primaryKeyColumns = schema.getPrimaryKeyColumns();
-      Preconditions.checkState(!CollectionUtils.isEmpty(_primaryKeyColumns),
+      List<String> primaryKeyColumns = schema.getPrimaryKeyColumns();
+      Preconditions.checkState(!CollectionUtils.isEmpty(primaryKeyColumns),
           "Primary key columns must be configured for upsert");
 
       String comparisonColumn = upsertConfig.getComparisonColumn();
-      _upsertComparisonColumn =
-          comparisonColumn != null ? comparisonColumn : tableConfig.getValidationConfig().getTimeColumnName();
+      if (comparisonColumn == null) {
+        comparisonColumn = tableConfig.getValidationConfig().getTimeColumnName();
+      }
 
       PartialUpsertHandler partialUpsertHandler = null;
       if (upsertConfig.getMode() == UpsertConfig.Mode.PARTIAL) {
         assert upsertConfig.getPartialUpsertStrategies() != null;
         partialUpsertHandler = new PartialUpsertHandler(schema, upsertConfig.getPartialUpsertStrategies(),
-            upsertConfig.getDefaultPartialUpsertStrategy(), _upsertComparisonColumn);
+            upsertConfig.getDefaultPartialUpsertStrategy(), comparisonColumn);
       }
 
       _tableUpsertMetadataManager =
-          new TableUpsertMetadataManager(_tableNameWithType, _serverMetrics, partialUpsertHandler,
-              upsertConfig.getHashFunction(), _primaryKeyColumns);
+          new TableUpsertMetadataManager(_tableNameWithType, primaryKeyColumns, comparisonColumn,
+              upsertConfig.getHashFunction(), partialUpsertHandler, _serverMetrics);
     }
   }
 
@@ -390,9 +383,11 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
   @Override
   public void addSegment(ImmutableSegment immutableSegment) {
     if (isUpsertEnabled()) {
-      handleUpsert((ImmutableSegmentImpl) immutableSegment);
+      handleUpsert(immutableSegment);
+      return;
     }
 
+    // TODO: Change dedup handling to handle segment replacement
     if (isDedupEnabled()) {
       buildDedupMeta((ImmutableSegmentImpl) immutableSegment);
     }
@@ -403,8 +398,7 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
     // TODO(saurabh) refactor commons code with handleUpsert
     String segmentName = immutableSegment.getSegmentName();
     Integer partitionGroupId =
-        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager,
-            _primaryKeyColumns.get(0));
+        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager, null);
     Preconditions.checkNotNull(partitionGroupId,
         String.format("PartitionGroupId is not available for segment: '%s' (dedup-enabled table: %s)", segmentName,
             _tableNameWithType));
@@ -414,53 +408,33 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
     partitionDedupMetadataManager.addSegment(immutableSegment);
   }
 
-  private void handleUpsert(ImmutableSegmentImpl immutableSegment) {
+  private void handleUpsert(ImmutableSegment immutableSegment) {
     String segmentName = immutableSegment.getSegmentName();
-    Integer partitionGroupId =
-        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager,
-            _primaryKeyColumns.get(0));
-    Preconditions.checkNotNull(partitionGroupId,
-        String.format("PartitionGroupId is not available for segment: '%s' (upsert-enabled table: %s)", segmentName,
+    _logger.info("Adding immutable segment: {} to upsert-enabled table: {}", segmentName, _tableNameWithType);
+
+    Integer partitionId =
+        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager, null);
+    Preconditions.checkNotNull(partitionId,
+        String.format("Failed to get partition id for segment: %s (upsert-enabled table: %s)", segmentName,
             _tableNameWithType));
     PartitionUpsertMetadataManager partitionUpsertMetadataManager =
-        _tableUpsertMetadataManager.getOrCreatePartitionManager(partitionGroupId);
-    ThreadSafeMutableRoaringBitmap validDocIds = new ThreadSafeMutableRoaringBitmap();
-    immutableSegment.enableUpsert(partitionUpsertMetadataManager, validDocIds);
+        _tableUpsertMetadataManager.getOrCreatePartitionManager(partitionId);
 
-    Map<String, PinotSegmentColumnReader> columnToReaderMap = new HashMap<>();
-    for (String primaryKeyColumn : _primaryKeyColumns) {
-      columnToReaderMap.put(primaryKeyColumn, new PinotSegmentColumnReader(immutableSegment, primaryKeyColumn));
+    _serverMetrics.addValueToTableGauge(_tableNameWithType, ServerGauge.DOCUMENT_COUNT,
+        immutableSegment.getSegmentMetadata().getTotalDocs());
+    _serverMetrics.addValueToTableGauge(_tableNameWithType, ServerGauge.SEGMENT_COUNT, 1L);
+    ImmutableSegmentDataManager newSegmentManager = new ImmutableSegmentDataManager(immutableSegment);
+    SegmentDataManager oldSegmentManager = _segmentDataManagerMap.put(segmentName, newSegmentManager);
+    if (oldSegmentManager == null) {
+      partitionUpsertMetadataManager.addSegment(immutableSegment);
+      _logger.info("Added new immutable segment: {} to upsert-enabled table: {}", segmentName, _tableNameWithType);
+    } else {
+      IndexSegment oldSegment = oldSegmentManager.getSegment();
+      partitionUpsertMetadataManager.replaceSegment(immutableSegment, oldSegment);
+      _logger.info("Replaced {} segment: {} of upsert-enabled table: {}",
+          oldSegment instanceof ImmutableSegment ? "immutable" : "mutable", segmentName, _tableNameWithType);
+      releaseSegment(oldSegmentManager);
     }
-    columnToReaderMap
-        .put(_upsertComparisonColumn, new PinotSegmentColumnReader(immutableSegment, _upsertComparisonColumn));
-    int numTotalDocs = immutableSegment.getSegmentMetadata().getTotalDocs();
-    int numPrimaryKeyColumns = _primaryKeyColumns.size();
-    Iterator<RecordInfo> recordInfoIterator = new Iterator<RecordInfo>() {
-      private int _docId = 0;
-
-      @Override
-      public boolean hasNext() {
-        return _docId < numTotalDocs;
-      }
-
-      @Override
-      public RecordInfo next() {
-        Object[] values = new Object[numPrimaryKeyColumns];
-        for (int i = 0; i < numPrimaryKeyColumns; i++) {
-          Object value = columnToReaderMap.get(_primaryKeyColumns.get(i)).getValue(_docId);
-          if (value instanceof byte[]) {
-            value = new ByteArray((byte[]) value);
-          }
-          values[i] = value;
-        }
-        PrimaryKey primaryKey = new PrimaryKey(values);
-        Object upsertComparisonValue = columnToReaderMap.get(_upsertComparisonColumn).getValue(_docId);
-        Preconditions.checkState(upsertComparisonValue instanceof Comparable,
-            "Upsert comparison column: %s must be comparable", _upsertComparisonColumn);
-        return new RecordInfo(primaryKey, _docId++, (Comparable) upsertComparisonValue);
-      }
-    };
-    partitionUpsertMetadataManager.addSegment(immutableSegment, recordInfoIterator);
   }
 
   @Override
@@ -538,8 +512,8 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
   private boolean isPeerSegmentDownloadEnabled(TableConfig tableConfig) {
     return
         CommonConstants.HTTP_PROTOCOL.equalsIgnoreCase(tableConfig.getValidationConfig().getPeerSegmentDownloadScheme())
-            || CommonConstants.HTTPS_PROTOCOL
-            .equalsIgnoreCase(tableConfig.getValidationConfig().getPeerSegmentDownloadScheme());
+            || CommonConstants.HTTPS_PROTOCOL.equalsIgnoreCase(
+            tableConfig.getValidationConfig().getPeerSegmentDownloadScheme());
   }
 
   private void downloadSegmentFromPeer(String segmentName, String downloadScheme,
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
index f6b55b08c3..4bfb67dfbe 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
@@ -125,8 +125,8 @@ public class MetadataAndDictionaryAggregationPlanMakerTest {
     ServerMetrics serverMetrics = Mockito.mock(ServerMetrics.class);
     _upsertIndexSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.heap);
     ((ImmutableSegmentImpl) _upsertIndexSegment).enableUpsert(
-        new PartitionUpsertMetadataManager("testTable_REALTIME", 0, serverMetrics, null,
-            HashFunction.NONE, Collections.emptyList()), new ThreadSafeMutableRoaringBitmap());
+        new PartitionUpsertMetadataManager("testTable_REALTIME", 0, Collections.singletonList("column6"),
+            "daysSinceEpoch", HashFunction.NONE, null, serverMetrics), new ThreadSafeMutableRoaringBitmap());
   }
 
   @AfterClass
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
index bffcd6a974..e725dfa4ff 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
@@ -42,7 +42,6 @@ import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.sql.parsers.CalciteSqlParser;
 import org.testng.annotations.Test;
 
@@ -95,7 +94,8 @@ public class QueryOverrideWithHintsTest {
     }
 
     @Override
-    public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    public Object getValue(int docId, String column) {
+      return null;
     }
 
     @Override
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
index db19e7d7a8..742a417a3e 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
@@ -33,7 +33,6 @@ import org.apache.pinot.segment.spi.index.reader.ForwardIndexReader;
 import org.apache.pinot.segment.spi.index.reader.InvertedIndexReader;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 
 
 /**
@@ -96,8 +95,8 @@ public class EmptyIndexSegment implements ImmutableSegment {
   }
 
   @Override
-  public void getPrimaryKey(int docId, PrimaryKey reuse) {
-    throw new UnsupportedOperationException("Cannot read primary key from empty segment");
+  public Object getValue(int docId, String column) {
+    throw new UnsupportedOperationException("Cannot read value from empty segment");
   }
 
   @Override
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
index 06f7bc252f..b4fafbbc23 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
@@ -44,7 +44,6 @@ import org.apache.pinot.segment.spi.index.reader.InvertedIndexReader;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.segment.spi.store.SegmentDirectory;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -138,8 +137,7 @@ public class ImmutableSegmentImpl implements ImmutableSegment {
   public DataSource getDataSource(String column) {
     DataSource result = _dataSources.get(column);
     Preconditions.checkNotNull(result,
-        "DataSource for %s should not be null. Potentially invalid column name specified.",
-        column);
+        "DataSource for %s should not be null. Potentially invalid column name specified.", column);
     return result;
   }
 
@@ -236,15 +234,15 @@ public class ImmutableSegmentImpl implements ImmutableSegment {
   }
 
   @Override
-  public void getPrimaryKey(int docId, PrimaryKey reuse) {
+  public Object getValue(int docId, String column) {
     try {
       if (_pinotSegmentRecordReader == null) {
         _pinotSegmentRecordReader = new PinotSegmentRecordReader();
         _pinotSegmentRecordReader.init(this);
       }
-      _pinotSegmentRecordReader.getPrimaryKey(docId, _partitionUpsertMetadataManager.getPrimaryKeyColumns(), reuse);
+      return _pinotSegmentRecordReader.getValue(docId, column);
     } catch (Exception e) {
-      throw new RuntimeException("Failed to use PinotSegmentRecordReader to read primary key from immutable segment");
+      throw new RuntimeException("Failed to use PinotSegmentRecordReader to read value from immutable segment");
     }
   }
 }
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
index 7f2e8b9356..fbf5c95804 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
@@ -55,7 +55,6 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.data.FieldSpec.FieldType;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.stream.RowMetadata;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.slf4j.Logger;
@@ -226,16 +225,10 @@ public class IntermediateSegment implements MutableSegment {
   }
 
   @Override
-  public void getPrimaryKey(int docId, PrimaryKey reuse) {
-    int numPrimaryKeyColumns = _schema.getPrimaryKeyColumns().size();
-    Object[] values = reuse.getValues();
-    for (int i = 0; i < numPrimaryKeyColumns; i++) {
-      IntermediateIndexContainer indexContainer = _indexContainerMap.get(
-          _schema.getPrimaryKeyColumns().get(i));
-      Object value = getValue(docId, indexContainer.getForwardIndex(), indexContainer.getDictionary(),
-          indexContainer.getNumValuesInfo().getMaxNumValuesPerMVEntry());
-      values[i] = value;
-    }
+  public Object getValue(int docId, String column) {
+    IntermediateIndexContainer indexContainer = _indexContainerMap.get(column);
+    return getValue(docId, indexContainer.getForwardIndex(), indexContainer.getDictionary(),
+        indexContainer.getNumValuesInfo().getMaxNumValuesPerMVEntry());
   }
 
   @Override
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
index 04eb437899..775fe06f5c 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
@@ -433,7 +433,7 @@ public class MutableSegmentImpl implements MutableSegment {
           || dataType == BYTES)) {
         _logger.info(
             "Aggregate metrics is enabled. Will create dictionary in consuming segment for column {} of type {}",
-            column, dataType.toString());
+            column, dataType);
         return false;
       }
       // So don't create dictionary if the column (1) is member of noDictionary, and (2) is single-value or multi-value
@@ -488,7 +488,7 @@ public class MutableSegmentImpl implements MutableSegment {
         }
       }
     }
-    _logger.info("Newly added columns: " + _newlyAddedColumnsFieldMap.toString());
+    _logger.info("Newly added columns: " + _newlyAddedColumnsFieldMap);
   }
 
   @Override
@@ -998,35 +998,34 @@ public class MutableSegmentImpl implements MutableSegment {
   @Override
   public GenericRow getRecord(int docId, GenericRow reuse) {
     for (Map.Entry<String, IndexContainer> entry : _indexContainerMap.entrySet()) {
+      String column = entry.getKey();
+      IndexContainer indexContainer = entry.getValue();
+      Object value;
       try {
-        String column = entry.getKey();
-        IndexContainer indexContainer = entry.getValue();
-        Object value = getValue(docId, indexContainer._forwardIndex, indexContainer._dictionary,
+        value = getValue(docId, indexContainer._forwardIndex, indexContainer._dictionary,
             indexContainer._numValuesInfo._maxNumValuesPerMVEntry);
-        if (_nullHandlingEnabled && indexContainer._nullValueVector.isNull(docId)) {
-          reuse.putDefaultNullValue(column, value);
-        } else {
-          reuse.putValue(column, value);
-        }
       } catch (Exception e) {
-        _logger.error("error encountered when getting record for {} on indexContainer: {}", docId, entry.getKey());
-        throw new RuntimeException("error encountered when getting record for " + docId + " on indexContainer: "
-            + entry.getKey(), e);
+        throw new RuntimeException(
+            String.format("Caught exception while reading value for docId: %d, column: %s", docId, column), e);
+      }
+      if (_nullHandlingEnabled && indexContainer._nullValueVector.isNull(docId)) {
+        reuse.putDefaultNullValue(column, value);
+      } else {
+        reuse.putValue(column, value);
       }
     }
     return reuse;
   }
 
   @Override
-  public void getPrimaryKey(int docId, PrimaryKey reuse) {
-    int numPrimaryKeyColumns = _partitionUpsertMetadataManager.getPrimaryKeyColumns().size();
-    Object[] values = reuse.getValues();
-    for (int i = 0; i < numPrimaryKeyColumns; i++) {
-      IndexContainer indexContainer = _indexContainerMap.get(
-          _partitionUpsertMetadataManager.getPrimaryKeyColumns().get(i));
-      Object value = getValue(docId, indexContainer._forwardIndex, indexContainer._dictionary,
+  public Object getValue(int docId, String column) {
+    try {
+      IndexContainer indexContainer = _indexContainerMap.get(column);
+      return getValue(docId, indexContainer._forwardIndex, indexContainer._dictionary,
           indexContainer._numValuesInfo._maxNumValuesPerMVEntry);
-      values[i] = value;
+    } catch (Exception e) {
+      throw new RuntimeException(
+          String.format("Caught exception while reading value for docId: %d, column: %s", docId, column), e);
     }
   }
 
@@ -1108,8 +1107,8 @@ public class MutableSegmentImpl implements MutableSegment {
             }
             return value;
           default:
-            throw new IllegalStateException("No support for MV no dictionary column of type "
-                + forwardIndex.getStoredType());
+            throw new IllegalStateException(
+                "No support for MV no dictionary column of type " + forwardIndex.getStoredType());
         }
       }
     }
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
index 7173d9d241..feb0be4b7c 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
@@ -33,10 +33,8 @@ import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.MutableSegment;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.data.readers.RecordReader;
 import org.apache.pinot.spi.data.readers.RecordReaderConfig;
-import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.spi.utils.ReadMode;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -232,17 +230,8 @@ public class PinotSegmentRecordReader implements RecordReader {
     }
   }
 
-  public void getPrimaryKey(int docId, List<String> primaryKeyColumns, PrimaryKey reuse) {
-    int numPrimaryKeyColumns = primaryKeyColumns.size();
-    Object[] values = reuse.getValues();
-    for (int i = 0; i < numPrimaryKeyColumns; i++) {
-      PinotSegmentColumnReader columnReader = _columnReaderMap.get(primaryKeyColumns.get(i));
-      Object value = columnReader.getValue(docId);
-      if (value instanceof byte[]) {
-        value = new ByteArray((byte[]) value);
-      }
-      values[i] = value;
-    }
+  public Object getValue(int docId, String column) {
+    return _columnReaderMap.get(column).getValue(docId);
   }
 
   @Override
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
index c279511a4b..1a435fc3f0 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.segment.local.upsert;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Objects;
@@ -30,13 +31,18 @@ import org.apache.pinot.common.metrics.ServerGauge;
 import org.apache.pinot.common.metrics.ServerMeter;
 import org.apache.pinot.common.metrics.ServerMetrics;
 import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.segment.local.indexsegment.immutable.EmptyIndexSegment;
+import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentImpl;
 import org.apache.pinot.segment.local.utils.HashUtils;
 import org.apache.pinot.segment.local.utils.RecordInfo;
+import org.apache.pinot.segment.spi.ImmutableSegment;
 import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.MutableSegment;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.spi.config.table.HashFunction;
 import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.data.readers.PrimaryKey;
+import org.apache.pinot.spi.utils.ByteArray;
 import org.roaringbitmap.PeekableIntIterator;
 import org.roaringbitmap.buffer.MutableRoaringBitmap;
 import org.slf4j.Logger;
@@ -67,19 +73,19 @@ import org.slf4j.LoggerFactory;
  *   </li>
  * </ul>
  */
-@SuppressWarnings("unchecked")
+@SuppressWarnings({"rawtypes", "unchecked"})
 @ThreadSafe
 public class PartitionUpsertMetadataManager {
-  private static final Logger LOGGER = LoggerFactory.getLogger(PartitionUpsertMetadataManager.class);
-
   private static final long OUT_OF_ORDER_EVENT_MIN_REPORT_INTERVAL_NS = TimeUnit.MINUTES.toNanos(1);
 
   private final String _tableNameWithType;
   private final int _partitionId;
-  private final ServerMetrics _serverMetrics;
-  private final PartialUpsertHandler _partialUpsertHandler;
-  private final HashFunction _hashFunction;
   private final List<String> _primaryKeyColumns;
+  private final String _comparisonColumn;
+  private final HashFunction _hashFunction;
+  private final PartialUpsertHandler _partialUpsertHandler;
+  private final ServerMetrics _serverMetrics;
+  private final Logger _logger;
 
   // TODO(upsert): consider an off-heap KV store to persist this mapping to improve the recovery speed.
   @VisibleForTesting
@@ -91,17 +97,22 @@ public class PartitionUpsertMetadataManager {
   private long _lastOutOfOrderEventReportTimeNs = Long.MIN_VALUE;
   private int _numOutOfOrderEvents = 0;
 
-  public PartitionUpsertMetadataManager(String tableNameWithType, int partitionId, ServerMetrics serverMetrics,
-      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction,
-      List<String> primaryKeyColumns) {
+  public PartitionUpsertMetadataManager(String tableNameWithType, int partitionId, List<String> primaryKeyColumns,
+      String comparisonColumn, HashFunction hashFunction, @Nullable PartialUpsertHandler partialUpsertHandler,
+      ServerMetrics serverMetrics) {
     _tableNameWithType = tableNameWithType;
     _partitionId = partitionId;
-    _serverMetrics = serverMetrics;
-    _partialUpsertHandler = partialUpsertHandler;
-    _hashFunction = hashFunction;
     _primaryKeyColumns = primaryKeyColumns;
+    _comparisonColumn = comparisonColumn;
+    _hashFunction = hashFunction;
+    _partialUpsertHandler = partialUpsertHandler;
+    _serverMetrics = serverMetrics;
+    _logger = LoggerFactory.getLogger(tableNameWithType + "-" + partitionId + "-" + getClass().getSimpleName());
   }
 
+  /**
+   * Returns the primary key columns whose values are combined into each record's upsert primary key.
+   */
   public List<String> getPrimaryKeyColumns() {
     return _primaryKeyColumns;
   }
@@ -109,11 +120,70 @@ public class PartitionUpsertMetadataManager {
   /**
    * Initializes the upsert metadata for the given immutable segment.
    */
-  public void addSegment(IndexSegment segment, Iterator<RecordInfo> recordInfoIterator) {
+  public void addSegment(ImmutableSegment segment) {
     String segmentName = segment.getSegmentName();
-    LOGGER.info("Adding upsert metadata for segment: {}", segmentName);
+    _logger.info("Adding segment: {}, current primary key count: {}", segmentName,
+        _primaryKeyToRecordLocationMap.size());
 
-    ThreadSafeMutableRoaringBitmap validDocIds = Objects.requireNonNull(segment.getValidDocIds());
+    if (segment instanceof EmptyIndexSegment) {
+      _logger.info("Skip adding empty segment: {}", segmentName);
+      return;
+    }
+
+    Preconditions.checkArgument(segment instanceof ImmutableSegmentImpl,
+        "Got unsupported segment implementation: %s for segment: %s, table: %s", segment.getClass(), segmentName,
+        _tableNameWithType);
+    addSegment((ImmutableSegmentImpl) segment, new ThreadSafeMutableRoaringBitmap(), getRecordInfoIterator(segment));
+
+    // Update metrics
+    int numPrimaryKeys = _primaryKeyToRecordLocationMap.size();
+    _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
+        numPrimaryKeys);
+
+    _logger.info("Finished adding segment: {}, current primary key count: {}", segmentName, numPrimaryKeys);
+  }
+
+  private Iterator<RecordInfo> getRecordInfoIterator(ImmutableSegment segment) {
+    int numTotalDocs = segment.getSegmentMetadata().getTotalDocs();
+    return new Iterator<RecordInfo>() {
+      private int _docId = 0;
+
+      @Override
+      public boolean hasNext() {
+        return _docId < numTotalDocs;
+      }
+
+      @Override
+      public RecordInfo next() {
+        PrimaryKey primaryKey = new PrimaryKey(new Object[_primaryKeyColumns.size()]);
+        getPrimaryKey(segment, _docId, primaryKey);
+
+        Object comparisonValue = segment.getValue(_docId, _comparisonColumn);
+        if (comparisonValue instanceof byte[]) {
+          comparisonValue = new ByteArray((byte[]) comparisonValue);
+        }
+        return new RecordInfo(primaryKey, _docId++, (Comparable) comparisonValue);
+      }
+    };
+  }
+
+  private void getPrimaryKey(IndexSegment segment, int docId, PrimaryKey buffer) {
+    Object[] values = buffer.getValues();
+    int numPrimaryKeyColumns = values.length;
+    for (int i = 0; i < numPrimaryKeyColumns; i++) {
+      Object value = segment.getValue(docId, _primaryKeyColumns.get(i));
+      if (value instanceof byte[]) {
+        value = new ByteArray((byte[]) value);
+      }
+      values[i] = value;
+    }
+  }
+
+  @VisibleForTesting
+  void addSegment(ImmutableSegmentImpl segment, ThreadSafeMutableRoaringBitmap validDocIds,
+      Iterator<RecordInfo> recordInfoIterator) {
+    String segmentName = segment.getSegmentName();
+    segment.enableUpsert(this, validDocIds);
     while (recordInfoIterator.hasNext()) {
       RecordInfo recordInfo = recordInfoIterator.next();
       _primaryKeyToRecordLocationMap.compute(HashUtils.hashPrimaryKey(recordInfo.getPrimaryKey(), _hashFunction),
@@ -172,15 +242,12 @@ public class PartitionUpsertMetadataManager {
             }
           });
     }
-    // Update metrics
-    _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
-        _primaryKeyToRecordLocationMap.size());
   }
 
   /**
    * Updates the upsert metadata for a new consumed record in the given consuming segment.
    */
-  public void addRecord(IndexSegment segment, RecordInfo recordInfo) {
+  public void addRecord(MutableSegment segment, RecordInfo recordInfo) {
     ThreadSafeMutableRoaringBitmap validDocIds = Objects.requireNonNull(segment.getValidDocIds());
     _primaryKeyToRecordLocationMap.compute(HashUtils.hashPrimaryKey(recordInfo.getPrimaryKey(), _hashFunction),
         (primaryKey, currentRecordLocation) -> {
@@ -208,11 +275,86 @@ public class PartitionUpsertMetadataManager {
             return new RecordLocation(segment, recordInfo.getDocId(), recordInfo.getComparisonValue());
           }
         });
+
     // Update metrics
     _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
         _primaryKeyToRecordLocationMap.size());
   }
 
+  /**
+   * Adds the new immutable segment, then removes primary keys still pointing to the old one (names must match).
+   */
+  public void replaceSegment(ImmutableSegment newSegment, IndexSegment oldSegment) {
+    String segmentName = newSegment.getSegmentName();
+    Preconditions.checkArgument(segmentName.equals(oldSegment.getSegmentName()),
+        "Cannot replace segment with different name for table: %s, old segment: %s, new segment: %s",
+        _tableNameWithType, oldSegment.getSegmentName(), segmentName);
+    _logger.info("Replacing {} segment: {}", oldSegment instanceof ImmutableSegment ? "immutable" : "mutable",
+        segmentName);
+
+    addSegment(newSegment);
+
+    MutableRoaringBitmap validDocIds =
+        oldSegment.getValidDocIds() != null ? oldSegment.getValidDocIds().getMutableRoaringBitmap() : null;
+    if (validDocIds != null && !validDocIds.isEmpty()) {
+      int numDocsNotReplaced = validDocIds.getCardinality();
+      if (_partialUpsertHandler != null) {
+        // For partial-upsert table, because we do not restore the original record location when removing the primary
+        // keys not replaced, it can potentially cause inconsistency between replicas. This can happen when a consuming
+        // segment is replaced by a committed segment that is consumed from a different server with different records
+        // (some stream consumers cannot guarantee consuming messages in the same order).
+        _logger.error("{} primary keys not replaced when replacing segment: {} for partial-upsert table. This can "
+            + "potentially cause inconsistency between replicas", numDocsNotReplaced, segmentName);
+        _serverMetrics.addMeteredTableValue(_tableNameWithType, ServerMeter.PARTIAL_UPSERT_ROWS_NOT_REPLACED,
+            numDocsNotReplaced);
+      } else {
+        _logger.info("{} primary keys not replaced when replacing segment: {}", numDocsNotReplaced, segmentName);
+      }
+      removeSegment(oldSegment);
+    }
+
+    _logger.info("Finished replacing segment: {}", segmentName);
+  }
+
+  /**
+   * Removes the primary keys whose current record location is in the given segment, found via its valid doc ids.
+   */
+  public void removeSegment(IndexSegment segment) {
+    String segmentName = segment.getSegmentName();
+    _logger.info("Removing {} segment: {}, current primary key count: {}",
+        segment instanceof ImmutableSegment ? "immutable" : "mutable", segmentName,
+        _primaryKeyToRecordLocationMap.size());
+
+    MutableRoaringBitmap validDocIds =
+        segment.getValidDocIds() != null ? segment.getValidDocIds().getMutableRoaringBitmap() : null;
+    if (validDocIds == null || validDocIds.isEmpty()) {
+      _logger.info("Skip removing segment without valid docs: {}", segmentName);
+      return;
+    }
+
+    _logger.info("Removing {} primary keys for segment: {}", validDocIds.getCardinality(), segmentName);
+    PrimaryKey primaryKey = new PrimaryKey(new Object[_primaryKeyColumns.size()]);
+    PeekableIntIterator iterator = validDocIds.getIntIterator();
+    while (iterator.hasNext()) {
+      int docId = iterator.next();
+      getPrimaryKey(segment, docId, primaryKey);
+      _primaryKeyToRecordLocationMap.computeIfPresent(HashUtils.hashPrimaryKey(primaryKey, _hashFunction),
+          (pk, recordLocation) -> {
+            if (recordLocation.getSegment() == segment) {
+              return null;
+            }
+            return recordLocation;
+          });
+    }
+
+    // Update metrics: refresh the primary key count gauge after the removals above
+    int numPrimaryKeys = _primaryKeyToRecordLocationMap.size();
+    _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
+        numPrimaryKeys);
+
+    _logger.info("Finished removing segment: {}, current primary key count: {}", segmentName, numPrimaryKeys);
+  }
+
   /**
    * Returns the merged record when partial-upsert is enabled.
    */
@@ -236,10 +378,9 @@ public class PartitionUpsertMetadataManager {
         _numOutOfOrderEvents++;
         long currentTimeNs = System.nanoTime();
         if (currentTimeNs - _lastOutOfOrderEventReportTimeNs > OUT_OF_ORDER_EVENT_MIN_REPORT_INTERVAL_NS) {
-          LOGGER.warn("Skipped {} out-of-order events for partial-upsert table: {} "
-                  + "(the last event has current comparison value: {}, record comparison value: {})",
-              _numOutOfOrderEvents,
-              _tableNameWithType, currentRecordLocation.getComparisonValue(), recordInfo.getComparisonValue());
+          _logger.warn("Skipped {} out-of-order events for partial-upsert table (the last event has current comparison "
+                  + "value: {}, record comparison value: {})", _numOutOfOrderEvents,
+              currentRecordLocation.getComparisonValue(), recordInfo.getComparisonValue());
           _lastOutOfOrderEventReportTimeNs = currentTimeNs;
           _numOutOfOrderEvents = 0;
         }
@@ -250,35 +391,4 @@ public class PartitionUpsertMetadataManager {
       return record;
     }
   }
-
-  /**
-   * Removes the upsert metadata for the given immutable segment. No need to remove the upsert metadata for the
-   * consuming segment because it should be replaced by the committed segment.
-   */
-  public void removeSegment(IndexSegment segment) {
-    String segmentName = segment.getSegmentName();
-    LOGGER.info("Removing upsert metadata for segment: {}", segmentName);
-
-    MutableRoaringBitmap mutableRoaringBitmap =
-        Objects.requireNonNull(segment.getValidDocIds()).getMutableRoaringBitmap();
-
-    if (!mutableRoaringBitmap.isEmpty()) {
-      PrimaryKey reuse = new PrimaryKey(new Object[_primaryKeyColumns.size()]);
-      PeekableIntIterator iterator = mutableRoaringBitmap.getIntIterator();
-      while (iterator.hasNext()) {
-        int docId = iterator.next();
-        segment.getPrimaryKey(docId, reuse);
-        _primaryKeyToRecordLocationMap.computeIfPresent(HashUtils.hashPrimaryKey(reuse, _hashFunction),
-            (pk, recordLocation) -> {
-              if (recordLocation.getSegment() == segment) {
-                return null;
-              }
-              return recordLocation;
-        });
-      }
-    }
-    // Update metrics
-    _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
-        _primaryKeyToRecordLocationMap.size());
-  }
 }
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
index 6e09192c56..108438e95e 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
@@ -34,25 +34,26 @@ import org.apache.pinot.spi.config.table.HashFunction;
 public class TableUpsertMetadataManager {
   private final Map<Integer, PartitionUpsertMetadataManager> _partitionMetadataManagerMap = new ConcurrentHashMap<>();
   private final String _tableNameWithType;
-  private final ServerMetrics _serverMetrics;
-  private final PartialUpsertHandler _partialUpsertHandler;
-  private final HashFunction _hashFunction;
   private final List<String> _primaryKeyColumns;
+  private final String _comparisonColumn;
+  private final HashFunction _hashFunction;
+  private final PartialUpsertHandler _partialUpsertHandler;
+  private final ServerMetrics _serverMetrics;
 
-  public TableUpsertMetadataManager(String tableNameWithType, ServerMetrics serverMetrics,
-      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction,
-      List<String> primaryKeyColumns) {
+  public TableUpsertMetadataManager(String tableNameWithType, List<String> primaryKeyColumns, String comparisonColumn,
+      HashFunction hashFunction, @Nullable PartialUpsertHandler partialUpsertHandler, ServerMetrics serverMetrics) {
     _tableNameWithType = tableNameWithType;
-    _serverMetrics = serverMetrics;
-    _partialUpsertHandler = partialUpsertHandler;
-    _hashFunction = hashFunction;
     _primaryKeyColumns = primaryKeyColumns;
+    _comparisonColumn = comparisonColumn;
+    _hashFunction = hashFunction;
+    _partialUpsertHandler = partialUpsertHandler;
+    _serverMetrics = serverMetrics;
   }
 
   public PartitionUpsertMetadataManager getOrCreatePartitionManager(int partitionId) {
     return _partitionMetadataManagerMap.computeIfAbsent(partitionId,
-        k -> new PartitionUpsertMetadataManager(_tableNameWithType, k, _serverMetrics, _partialUpsertHandler,
-            _hashFunction, _primaryKeyColumns));
+        k -> new PartitionUpsertMetadataManager(_tableNameWithType, k, _primaryKeyColumns, _comparisonColumn,
+            _hashFunction, _partialUpsertHandler, _serverMetrics));
   }
 
   public boolean isPartialUpsertEnabled() {
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
index bd02b851be..60e336b7ec 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
@@ -35,12 +35,13 @@ import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.data.readers.RecordReader;
 import org.apache.pinot.spi.data.readers.RecordReaderFactory;
 import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
-import org.mockito.Mockito;
 import org.roaringbitmap.buffer.ImmutableRoaringBitmap;
 import org.testng.Assert;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
 
+import static org.mockito.Mockito.mock;
+
 
 public class MutableSegmentImplUpsertComparisonColTest {
   private static final String SCHEMA_FILE_PATH = "data/test_upsert_comparison_col_schema.json";
@@ -62,15 +63,15 @@ public class MutableSegmentImplUpsertComparisonColTest {
     _recordTransformer = CompositeTransformer.getDefaultTransformer(_tableConfig, _schema);
     File jsonFile = new File(dataResourceUrl.getFile());
     _partitionUpsertMetadataManager =
-        new TableUpsertMetadataManager("testTable_REALTIME", Mockito.mock(ServerMetrics.class), null,
-            HashFunction.NONE, _schema.getPrimaryKeyColumns()).getOrCreatePartitionManager(0);
-    _mutableSegmentImpl = MutableSegmentImplTestUtils
-        .createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(),
-            false, true, new UpsertConfig(UpsertConfig.Mode.FULL, null, null, "offset", null), "secondsSinceEpoch",
-            _partitionUpsertMetadataManager, null);
+        new TableUpsertMetadataManager("testTable_REALTIME", _schema.getPrimaryKeyColumns(), "offset",
+            HashFunction.NONE, null, mock(ServerMetrics.class)).getOrCreatePartitionManager(0);
+    _mutableSegmentImpl =
+        MutableSegmentImplTestUtils.createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(),
+            Collections.emptySet(), false, true, new UpsertConfig(UpsertConfig.Mode.FULL, null, null, "offset", null),
+            "secondsSinceEpoch", _partitionUpsertMetadataManager, null);
     GenericRow reuse = new GenericRow();
-    try (RecordReader recordReader = RecordReaderFactory
-        .getRecordReader(FileFormat.JSON, jsonFile, _schema.getColumnNames(), null)) {
+    try (RecordReader recordReader = RecordReaderFactory.getRecordReader(FileFormat.JSON, jsonFile,
+        _schema.getColumnNames(), null)) {
       while (recordReader.hasNext()) {
         recordReader.next(reuse);
         GenericRow transformedRow = _recordTransformer.transform(reuse);
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
index 07033e20b6..a603303bc8 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
@@ -35,11 +35,12 @@ import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.data.readers.RecordReader;
 import org.apache.pinot.spi.data.readers.RecordReaderFactory;
 import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
-import org.mockito.Mockito;
 import org.roaringbitmap.buffer.ImmutableRoaringBitmap;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
+import static org.mockito.Mockito.mock;
+
 
 public class MutableSegmentImplUpsertTest {
   private static final String SCHEMA_FILE_PATH = "data/test_upsert_schema.json";
@@ -60,12 +61,12 @@ public class MutableSegmentImplUpsertTest {
     _recordTransformer = CompositeTransformer.getDefaultTransformer(_tableConfig, _schema);
     File jsonFile = new File(dataResourceUrl.getFile());
     _partitionUpsertMetadataManager =
-        new TableUpsertMetadataManager("testTable_REALTIME", Mockito.mock(ServerMetrics.class), null, hashFunction,
-            _schema.getPrimaryKeyColumns())
-            .getOrCreatePartitionManager(0);
-    _mutableSegmentImpl = MutableSegmentImplTestUtils
-        .createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(),
-            false, true, new UpsertConfig(UpsertConfig.Mode.FULL, null, null, null, hashFunction), "secondsSinceEpoch",
+        new TableUpsertMetadataManager("testTable_REALTIME", _schema.getPrimaryKeyColumns(), "secondsSinceEpoch",
+            hashFunction, null, mock(ServerMetrics.class)).getOrCreatePartitionManager(0);
+    _mutableSegmentImpl =
+        MutableSegmentImplTestUtils.createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(),
+            Collections.emptySet(), false, true,
+            new UpsertConfig(UpsertConfig.Mode.FULL, null, null, null, hashFunction), "secondsSinceEpoch",
             _partitionUpsertMetadataManager, null);
 
     GenericRow reuse = new GenericRow();
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
index 702ecf1b2b..639f6321e4 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
@@ -24,24 +24,23 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.metrics.ServerMetrics;
 import org.apache.pinot.common.utils.LLCSegmentName;
+import org.apache.pinot.segment.local.indexsegment.immutable.EmptyIndexSegment;
 import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentImpl;
 import org.apache.pinot.segment.local.utils.HashUtils;
 import org.apache.pinot.segment.local.utils.RecordInfo;
 import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.MutableSegment;
+import org.apache.pinot.segment.spi.index.metadata.SegmentMetadataImpl;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.spi.config.table.HashFunction;
-import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.spi.utils.BytesUtils;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
-import org.mockito.ArgumentMatchers;
-import org.testng.Assert;
 import org.testng.annotations.Test;
 
-import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
-import static org.mockito.Mockito.doAnswer;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertEquals;
@@ -63,20 +62,19 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyAddSegment(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
-            Collections.emptyList());
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, Collections.singletonList("pk"), "timeCol",
+            hashFunction, null, mock(ServerMetrics.class));
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add the first segment
-
     int numRecords = 6;
     int[] primaryKeys = new int[]{0, 1, 2, 0, 1, 0};
     int[] timestamps = new int[]{100, 100, 100, 80, 120, 100};
-    List<RecordInfo> recordInfoList1 =
-        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
-    upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
+    List<PrimaryKey> primaryKeys1 = getPrimaryKeyList(numRecords, primaryKeys);
+    ImmutableSegmentImpl segment1 = mockImmutableSegment(1, validDocIds1, primaryKeys1);
+    List<RecordInfo> recordInfoList1 = getRecordInfoList(numRecords, primaryKeys, timestamps);
+    upsertMetadataManager.addSegment(segment1, validDocIds1, recordInfoList1.iterator());
     // segment1: 0 -> {5, 100}, 1 -> {4, 120}, 2 -> {2, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 5, 100, hashFunction);
     checkRecordLocation(recordLocationMap, 1, segment1, 4, 120, hashFunction);
@@ -87,11 +85,21 @@ public class PartitionUpsertMetadataManagerTest {
     numRecords = 5;
     primaryKeys = new int[]{0, 1, 2, 3, 0};
     timestamps = new int[]{100, 100, 120, 80, 80};
-    List<RecordInfo> recordInfoList2 =
-        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds2 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment2 = mockSegment(2, validDocIds2);
-    upsertMetadataManager.addSegment(segment2, recordInfoList2.iterator());
+    ImmutableSegmentImpl segment2 = mockImmutableSegment(2, validDocIds2, getPrimaryKeyList(numRecords, primaryKeys));
+    upsertMetadataManager.addSegment(segment2, validDocIds2,
+        getRecordInfoList(numRecords, primaryKeys, timestamps).iterator());
+    // segment1: 1 -> {4, 120}
+    // segment2: 0 -> {0, 100}, 2 -> {2, 120}, 3 -> {3, 80}
+    checkRecordLocation(recordLocationMap, 0, segment2, 0, 100, hashFunction);
+    checkRecordLocation(recordLocationMap, 1, segment1, 4, 120, hashFunction);
+    checkRecordLocation(recordLocationMap, 2, segment2, 2, 120, hashFunction);
+    checkRecordLocation(recordLocationMap, 3, segment2, 3, 80, hashFunction);
+    assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
+    assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
+
+    // Add an empty segment
+    upsertMetadataManager.addSegment(new EmptyIndexSegment(mock(SegmentMetadataImpl.class)));
     // segment1: 1 -> {4, 120}
     // segment2: 0 -> {0, 100}, 2 -> {2, 120}, 3 -> {3, 80}
     checkRecordLocation(recordLocationMap, 0, segment2, 0, 100, hashFunction);
@@ -103,8 +111,8 @@ public class PartitionUpsertMetadataManagerTest {
 
     // Replace (reload) the first segment
     ThreadSafeMutableRoaringBitmap newValidDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl newSegment1 = mockSegment(1, newValidDocIds1);
-    upsertMetadataManager.addSegment(newSegment1, recordInfoList1.iterator());
+    ImmutableSegmentImpl newSegment1 = mockImmutableSegment(1, newValidDocIds1, primaryKeys1);
+    upsertMetadataManager.addSegment(newSegment1, newValidDocIds1, recordInfoList1.iterator());
     // original segment1: 1 -> {4, 120}
     // segment2: 0 -> {0, 100}, 2 -> {2, 120}, 3 -> {3, 80}
     // new segment1: 1 -> {4, 120}
@@ -115,8 +123,8 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
     assertEquals(newValidDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
-    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction))
-        .getSegment(), newSegment1);
+    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction)).getSegment(),
+        newSegment1);
 
     // Remove the original segment1
     upsertMetadataManager.removeSegment(segment1);
@@ -128,16 +136,27 @@ public class PartitionUpsertMetadataManagerTest {
     checkRecordLocation(recordLocationMap, 3, segment2, 3, 80, hashFunction);
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
     assertEquals(newValidDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
-    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction))
-        .getSegment(), newSegment1);
+    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction)).getSegment(),
+        newSegment1);
+
+    // Remove an empty segment
+    upsertMetadataManager.removeSegment(new EmptyIndexSegment(mock(SegmentMetadataImpl.class)));
+    // segment2: 0 -> {0, 100}, 2 -> {2, 120}, 3 -> {3, 80}
+    // new segment1: 1 -> {4, 120}
+    checkRecordLocation(recordLocationMap, 0, segment2, 0, 100, hashFunction);
+    checkRecordLocation(recordLocationMap, 1, newSegment1, 4, 120, hashFunction);
+    checkRecordLocation(recordLocationMap, 2, segment2, 2, 120, hashFunction);
+    checkRecordLocation(recordLocationMap, 3, segment2, 3, 80, hashFunction);
+    assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
+    assertEquals(newValidDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
+    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction)).getSegment(),
+        newSegment1);
   }
 
-  private List<RecordInfo> getRecordInfoList(int numRecords, int[] primaryKeys,
-      int[] timestamps) {
+  private List<RecordInfo> getRecordInfoList(int numRecords, int[] primaryKeys, int[] timestamps) {
     List<RecordInfo> recordInfoList = new ArrayList<>();
     for (int i = 0; i < numRecords; i++) {
-      recordInfoList.add(new RecordInfo(makePrimaryKey(primaryKeys[i]), i,
-          new IntWrapper(timestamps[i])));
+      recordInfoList.add(new RecordInfo(makePrimaryKey(primaryKeys[i]), i, new IntWrapper(timestamps[i])));
     }
     return recordInfoList;
   }
@@ -150,32 +169,20 @@ public class PartitionUpsertMetadataManagerTest {
     return primaryKeyList;
   }
 
-  private static ImmutableSegmentImpl mockSegment(int sequenceNumber, ThreadSafeMutableRoaringBitmap validDocIds) {
+  private static ImmutableSegmentImpl mockImmutableSegment(int sequenceNumber,
+      ThreadSafeMutableRoaringBitmap validDocIds, List<PrimaryKey> primaryKeys) {
     ImmutableSegmentImpl segment = mock(ImmutableSegmentImpl.class);
-    String segmentName = getSegmentName(sequenceNumber);
-    when(segment.getSegmentName()).thenReturn(segmentName);
+    when(segment.getSegmentName()).thenReturn(getSegmentName(sequenceNumber));
     when(segment.getValidDocIds()).thenReturn(validDocIds);
-    when(segment.getRecord(anyInt(), ArgumentMatchers.any(GenericRow.class))).thenReturn(new GenericRow());
+    when(segment.getValue(anyInt(), anyString())).thenAnswer(
+        invocation -> primaryKeys.get(invocation.getArgument(0)).getValues()[0]);
     return segment;
   }
 
-  private static ImmutableSegmentImpl mockSegment(int sequenceNumber, ThreadSafeMutableRoaringBitmap validDocIds,
-      List<PrimaryKey> primaryKeys) {
-    ImmutableSegmentImpl segment = mock(ImmutableSegmentImpl.class);
-
-    String segmentName = getSegmentName(sequenceNumber);
-    when(segment.getSegmentName()).thenReturn(segmentName);
+  private static MutableSegment mockMutableSegment(int sequenceNumber, ThreadSafeMutableRoaringBitmap validDocIds) {
+    MutableSegment segment = mock(MutableSegment.class);
+    when(segment.getSegmentName()).thenReturn(getSegmentName(sequenceNumber));
     when(segment.getValidDocIds()).thenReturn(validDocIds);
-    doAnswer((invocation) -> {
-      PrimaryKey pk = primaryKeys.get(invocation.getArgument(0));
-      PrimaryKey reuse = invocation.getArgument(1, PrimaryKey.class);
-      Object[] reuseValues = reuse.getValues();
-      for (int i = 0; i < reuseValues.length; i++) {
-        reuseValues[i] = pk.getValues()[i];
-      }
-        return null;
-      }).when(segment).getPrimaryKey(anyInt(), any(PrimaryKey.class));
-
     return segment;
   }
 
@@ -206,8 +213,8 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyAddRecord(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
-            Collections.emptyList());
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, Collections.singletonList("pk"), "timeCol",
+            hashFunction, null, mock(ServerMetrics.class));
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add the first segment
@@ -215,18 +222,15 @@ public class PartitionUpsertMetadataManagerTest {
     int numRecords = 3;
     int[] primaryKeys = new int[]{0, 1, 2};
     int[] timestamps = new int[]{100, 120, 100};
-    List<RecordInfo> recordInfoList1 =
-        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1);
-    upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
+    ImmutableSegmentImpl segment1 = mockImmutableSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
+    upsertMetadataManager.addSegment(segment1, validDocIds1,
+        getRecordInfoList(numRecords, primaryKeys, timestamps).iterator());
 
     // Update records from the second segment
     ThreadSafeMutableRoaringBitmap validDocIds2 = new ThreadSafeMutableRoaringBitmap();
-    IndexSegment segment2 = mockSegment(1, validDocIds2);
-
-    upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(makePrimaryKey(3), 0, new IntWrapper(100)));
+    MutableSegment segment2 = mockMutableSegment(1, validDocIds2);
+    upsertMetadataManager.addRecord(segment2, new RecordInfo(makePrimaryKey(3), 0, new IntWrapper(100)));
 
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}, 2 -> {2, 100}
     // segment2: 3 -> {0, 100}
@@ -237,8 +241,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{0, 1, 2});
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0});
 
-    upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(makePrimaryKey(2), 1, new IntWrapper(120)));
+    upsertMetadataManager.addRecord(segment2, new RecordInfo(makePrimaryKey(2), 1, new IntWrapper(120)));
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}
     // segment2: 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 0, 100, hashFunction);
@@ -248,8 +251,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
 
-    upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(makePrimaryKey(1), 2, new IntWrapper(100)));
+    upsertMetadataManager.addRecord(segment2, new RecordInfo(makePrimaryKey(1), 2, new IntWrapper(100)));
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}
     // segment2: 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 0, 100, hashFunction);
@@ -259,8 +261,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
 
-    upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(makePrimaryKey(0), 3, new IntWrapper(100)));
+    upsertMetadataManager.addRecord(segment2, new RecordInfo(makePrimaryKey(0), 3, new IntWrapper(100)));
     // segment1: 1 -> {1, 120}
     // segment2: 0 -> {3, 100}, 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment2, 3, 100, hashFunction);
@@ -280,8 +281,8 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyRemoveSegment(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
-            Collections.singletonList("primaryKey"));
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, Collections.singletonList("pk"), "timeCol",
+            hashFunction, null, mock(ServerMetrics.class));
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add 2 segments
@@ -290,18 +291,16 @@ public class PartitionUpsertMetadataManagerTest {
     int numRecords = 2;
     int[] primaryKeys = new int[]{0, 1};
     int[] timestamps = new int[]{100, 100};
-    List<RecordInfo> recordInfoList1 =
-        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
-    upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
+    ImmutableSegmentImpl segment1 = mockImmutableSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
+    upsertMetadataManager.addSegment(segment1, validDocIds1,
+        getRecordInfoList(numRecords, primaryKeys, timestamps).iterator());
 
     primaryKeys = new int[]{2, 3};
-    List<RecordInfo> recordInfoList2 =
-        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds2 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment2 = mockSegment(2, validDocIds2, getPrimaryKeyList(numRecords, primaryKeys));
-    upsertMetadataManager.addSegment(segment2, recordInfoList2.iterator());
+    ImmutableSegmentImpl segment2 = mockImmutableSegment(2, validDocIds2, getPrimaryKeyList(numRecords, primaryKeys));
+    upsertMetadataManager.addSegment(segment2, validDocIds2,
+        getRecordInfoList(numRecords, primaryKeys, timestamps).iterator());
 
     // Remove the first segment
     upsertMetadataManager.removeSegment(segment1);
@@ -316,20 +315,15 @@ public class PartitionUpsertMetadataManagerTest {
   @Test
   public void testHashPrimaryKey() {
     PrimaryKey pk = new PrimaryKey(new Object[]{"uuid-1", "uuid-2", "uuid-3"});
-    Assert.assertEquals(BytesUtils.toHexString(
-            ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).
-                getBytes()),
+    assertEquals(BytesUtils.toHexString(((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).getBytes()),
         "58de44997505014e02982846a4d1cbbd");
-    Assert.assertEquals(BytesUtils.toHexString(
-        ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MURMUR3)).getBytes()),
+    assertEquals(BytesUtils.toHexString(((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MURMUR3)).getBytes()),
         "7e6b4a98296292a4012225fff037fa8c");
     // reorder
     pk = new PrimaryKey(new Object[]{"uuid-3", "uuid-2", "uuid-1"});
-    Assert.assertEquals(BytesUtils.toHexString(
-        ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).getBytes()),
+    assertEquals(BytesUtils.toHexString(((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).getBytes()),
         "d2df12c6dea7b83f965613614eee58e2");
-    Assert.assertEquals(BytesUtils.toHexString(
-        ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MURMUR3)).getBytes()),
+    assertEquals(BytesUtils.toHexString(((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MURMUR3)).getBytes()),
         "8d68b314cc0c8de4dbd55f4dad3c3e66");
   }
 
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
index 9984e89405..bbd4af2545 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
@@ -26,7 +26,6 @@ import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.annotations.InterfaceAudience;
 import org.apache.pinot.spi.data.readers.GenericRow;
-import org.apache.pinot.spi.data.readers.PrimaryKey;
 
 
 @InterfaceAudience.Private
@@ -88,13 +87,9 @@ public interface IndexSegment {
   GenericRow getRecord(int docId, GenericRow reuse);
 
   /**
-   * Returns the primaryKey for a given docId
-   *
-   * @param docId Document Id
-   * @param reuse Reusable buffer for the primary key
-   * @return Primary key for the given document Id
+   * Returns the value for the column at the document id. Returns byte[] for BYTES data type.
    */
-  void getPrimaryKey(int docId, PrimaryKey reuse);
+  Object getValue(int docId, String column);
 
   /**
    * Hints the segment to begin prefetching buffers for specified columns.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org