You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@pinot.apache.org by kh...@apache.org on 2022/07/16 09:34:27 UTC

[pinot] branch master updated: Remove segments using valid doc ids instead of primary key (#8674)

This is an automated email from the ASF dual-hosted git repository.

kharekartik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 83410f69b4 Remove segments using valid doc ids instead of primary key (#8674)
83410f69b4 is described below

commit 83410f69b49988f4a1c1f0f373c8bafcd25404ba
Author: Kartik Khare <kh...@gmail.com>
AuthorDate: Sat Jul 16 15:04:19 2022 +0530

    Remove segments using valid doc ids instead of primary key (#8674)
    
    * Remove segments using valid doc ids instead of primary key
    
    * Handle concurrent updates
    
    * Refactor
    
    * Reuse primary key
    
    * Hash primary key before checking
    
    * Reuse primary key
    
    * Refactor: move methods for readability and cache values in local variable
    
    Co-authored-by: Kartik Khare <kh...@Kartiks-MacBook-Pro.local>
---
 .../manager/realtime/RealtimeTableDataManager.java |  17 +--
 ...adataAndDictionaryAggregationPlanMakerTest.java |   3 +-
 .../plan/maker/QueryOverrideWithHintsTest.java     |   5 +
 .../indexsegment/immutable/EmptyIndexSegment.java  |   6 +
 .../immutable/ImmutableSegmentImpl.java            |  14 +++
 .../indexsegment/mutable/IntermediateSegment.java  |  14 +++
 .../indexsegment/mutable/MutableSegmentImpl.java   |  13 ++
 .../segment/readers/PinotSegmentRecordReader.java  |  15 +++
 .../upsert/PartitionUpsertMetadataManager.java     |  38 ++++--
 .../local/upsert/TableUpsertMetadataManager.java   |   8 +-
 .../MutableSegmentImplUpsertComparisonColTest.java |   2 +-
 .../mutable/MutableSegmentImplUpsertTest.java      |   8 +-
 .../upsert/PartitionUpsertMetadataManagerTest.java | 132 ++++++++++++++-------
 .../org/apache/pinot/segment/spi/IndexSegment.java |  10 ++
 14 files changed, 222 insertions(+), 63 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java b/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
index 195e5e023e..4fa98e5720 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/manager/realtime/RealtimeTableDataManager.java
@@ -134,8 +134,9 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
     try {
       _statsHistory = RealtimeSegmentStatsHistory.deserialzeFrom(statsFile);
     } catch (IOException | ClassNotFoundException e) {
-      _logger.error("Error reading history object for table {} from {}", _tableNameWithType,
-          statsFile.getAbsolutePath(), e);
+      _logger
+          .error("Error reading history object for table {} from {}", _tableNameWithType, statsFile.getAbsolutePath(),
+              e);
       File savedFile = new File(_tableDataDir, STATS_FILE_NAME + "." + UUID.randomUUID());
       try {
         FileUtils.moveFile(statsFile, savedFile);
@@ -198,6 +199,7 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
       _primaryKeyColumns = schema.getPrimaryKeyColumns();
       Preconditions.checkState(!CollectionUtils.isEmpty(_primaryKeyColumns),
           "Primary key columns must be configured for upsert");
+
       String comparisonColumn = upsertConfig.getComparisonColumn();
       _upsertComparisonColumn =
           comparisonColumn != null ? comparisonColumn : tableConfig.getValidationConfig().getTimeColumnName();
@@ -211,7 +213,7 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
 
       _tableUpsertMetadataManager =
           new TableUpsertMetadataManager(_tableNameWithType, _serverMetrics, partialUpsertHandler,
-              upsertConfig.getHashFunction());
+              upsertConfig.getHashFunction(), _primaryKeyColumns);
     }
   }
 
@@ -369,6 +371,7 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
           }
         }
       }
+
       segmentDataManager =
           new LLRealtimeSegmentDataManager(segmentZKMetadata, tableConfig, this, _indexDir.getAbsolutePath(),
               indexLoadingConfig, schema, llcSegmentName, semaphore, _serverMetrics, partitionUpsertMetadataManager,
@@ -428,8 +431,8 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
     for (String primaryKeyColumn : _primaryKeyColumns) {
       columnToReaderMap.put(primaryKeyColumn, new PinotSegmentColumnReader(immutableSegment, primaryKeyColumn));
     }
-    columnToReaderMap.put(_upsertComparisonColumn,
-        new PinotSegmentColumnReader(immutableSegment, _upsertComparisonColumn));
+    columnToReaderMap
+        .put(_upsertComparisonColumn, new PinotSegmentColumnReader(immutableSegment, _upsertComparisonColumn));
     int numTotalDocs = immutableSegment.getSegmentMetadata().getTotalDocs();
     int numPrimaryKeyColumns = _primaryKeyColumns.size();
     Iterator<RecordInfo> recordInfoIterator = new Iterator<RecordInfo>() {
@@ -535,8 +538,8 @@ public class RealtimeTableDataManager extends BaseTableDataManager {
   private boolean isPeerSegmentDownloadEnabled(TableConfig tableConfig) {
     return
         CommonConstants.HTTP_PROTOCOL.equalsIgnoreCase(tableConfig.getValidationConfig().getPeerSegmentDownloadScheme())
-            || CommonConstants.HTTPS_PROTOCOL.equalsIgnoreCase(
-            tableConfig.getValidationConfig().getPeerSegmentDownloadScheme());
+            || CommonConstants.HTTPS_PROTOCOL
+            .equalsIgnoreCase(tableConfig.getValidationConfig().getPeerSegmentDownloadScheme());
   }
 
   private void downloadSegmentFromPeer(String segmentName, String downloadScheme,
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
index 612e686134..f6b55b08c3 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/MetadataAndDictionaryAggregationPlanMakerTest.java
@@ -22,6 +22,7 @@ import java.io.File;
 import java.net.URL;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 import org.apache.commons.io.FileUtils;
@@ -125,7 +126,7 @@ public class MetadataAndDictionaryAggregationPlanMakerTest {
     _upsertIndexSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.heap);
     ((ImmutableSegmentImpl) _upsertIndexSegment).enableUpsert(
         new PartitionUpsertMetadataManager("testTable_REALTIME", 0, serverMetrics, null,
-            HashFunction.NONE), new ThreadSafeMutableRoaringBitmap());
+            HashFunction.NONE, Collections.emptyList()), new ThreadSafeMutableRoaringBitmap());
   }
 
   @AfterClass
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
index 31d65af0da..bffcd6a974 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/plan/maker/QueryOverrideWithHintsTest.java
@@ -42,6 +42,7 @@ import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.sql.parsers.CalciteSqlParser;
 import org.testng.annotations.Test;
 
@@ -93,6 +94,10 @@ public class QueryOverrideWithHintsTest {
       return null;
     }
 
+    @Override
+    public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    }
+
     @Override
     public void destroy() {
     }
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
index 7ef97295f6..db19e7d7a8 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/EmptyIndexSegment.java
@@ -33,6 +33,7 @@ import org.apache.pinot.segment.spi.index.reader.ForwardIndexReader;
 import org.apache.pinot.segment.spi.index.reader.InvertedIndexReader;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 
 
 /**
@@ -94,6 +95,11 @@ public class EmptyIndexSegment implements ImmutableSegment {
     throw new UnsupportedOperationException("Cannot read record from empty segment");
   }
 
+  @Override
+  public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    throw new UnsupportedOperationException("Cannot read primary key from empty segment");
+  }
+
   @Override
   public Dictionary getDictionary(String column) {
     return null;
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
index 9f37d81b6b..06f7bc252f 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/immutable/ImmutableSegmentImpl.java
@@ -44,6 +44,7 @@ import org.apache.pinot.segment.spi.index.reader.InvertedIndexReader;
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.segment.spi.store.SegmentDirectory;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -233,4 +234,17 @@ public class ImmutableSegmentImpl implements ImmutableSegment {
       throw new RuntimeException("Failed to use PinotSegmentRecordReader to read immutable segment");
     }
   }
+
+  @Override
+  public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    try {
+      if (_pinotSegmentRecordReader == null) {
+        _pinotSegmentRecordReader = new PinotSegmentRecordReader();
+        _pinotSegmentRecordReader.init(this);
+      }
+      _pinotSegmentRecordReader.getPrimaryKey(docId, _partitionUpsertMetadataManager.getPrimaryKeyColumns(), reuse);
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to use PinotSegmentRecordReader to read primary key from immutable segment");
+    }
+  }
 }
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
index 481829c86e..7f2e8b9356 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/IntermediateSegment.java
@@ -55,6 +55,7 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.apache.pinot.spi.data.FieldSpec.FieldType;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.stream.RowMetadata;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.slf4j.Logger;
@@ -224,6 +225,19 @@ public class IntermediateSegment implements MutableSegment {
     return reuse;
   }
 
+  @Override
+  public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    int numPrimaryKeyColumns = _schema.getPrimaryKeyColumns().size();
+    Object[] values = reuse.getValues();
+    for (int i = 0; i < numPrimaryKeyColumns; i++) {
+      IntermediateIndexContainer indexContainer = _indexContainerMap.get(
+          _schema.getPrimaryKeyColumns().get(i));
+      Object value = getValue(docId, indexContainer.getForwardIndex(), indexContainer.getDictionary(),
+          indexContainer.getNumValuesInfo().getMaxNumValuesPerMVEntry());
+      values[i] = value;
+    }
+  }
+
   @Override
   public void destroy() {
     String segmentName = getSegmentName();
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
index 1b0c755d5c..04eb437899 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImpl.java
@@ -1017,6 +1017,19 @@ public class MutableSegmentImpl implements MutableSegment {
     return reuse;
   }
 
+  @Override
+  public void getPrimaryKey(int docId, PrimaryKey reuse) {
+    int numPrimaryKeyColumns = _partitionUpsertMetadataManager.getPrimaryKeyColumns().size();
+    Object[] values = reuse.getValues();
+    for (int i = 0; i < numPrimaryKeyColumns; i++) {
+      IndexContainer indexContainer = _indexContainerMap.get(
+          _partitionUpsertMetadataManager.getPrimaryKeyColumns().get(i));
+      Object value = getValue(docId, indexContainer._forwardIndex, indexContainer._dictionary,
+          indexContainer._numValuesInfo._maxNumValuesPerMVEntry);
+      values[i] = value;
+    }
+  }
+
   /**
    * Helper method to read the value for the given document id.
    */
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
index 1d3d438785..7173d9d241 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/readers/PinotSegmentRecordReader.java
@@ -33,8 +33,10 @@ import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.MutableSegment;
 import org.apache.pinot.spi.data.Schema;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.data.readers.RecordReader;
 import org.apache.pinot.spi.data.readers.RecordReaderConfig;
+import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.spi.utils.ReadMode;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -230,6 +232,19 @@ public class PinotSegmentRecordReader implements RecordReader {
     }
   }
 
+  public void getPrimaryKey(int docId, List<String> primaryKeyColumns, PrimaryKey reuse) {
+    int numPrimaryKeyColumns = primaryKeyColumns.size();
+    Object[] values = reuse.getValues();
+    for (int i = 0; i < numPrimaryKeyColumns; i++) {
+      PinotSegmentColumnReader columnReader = _columnReaderMap.get(primaryKeyColumns.get(i));
+      Object value = columnReader.getValue(docId);
+      if (value instanceof byte[]) {
+        value = new ByteArray((byte[]) value);
+      }
+      values[i] = value;
+    }
+  }
+
   @Override
   public void rewind() {
     _nextDocId = 0;
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
index 8214162fe9..c279511a4b 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManager.java
@@ -20,6 +20,7 @@ package org.apache.pinot.segment.local.upsert;
 
 import com.google.common.annotations.VisibleForTesting;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
@@ -35,6 +36,9 @@ import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.spi.config.table.HashFunction;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
+import org.roaringbitmap.PeekableIntIterator;
+import org.roaringbitmap.buffer.MutableRoaringBitmap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -75,6 +79,7 @@ public class PartitionUpsertMetadataManager {
   private final ServerMetrics _serverMetrics;
   private final PartialUpsertHandler _partialUpsertHandler;
   private final HashFunction _hashFunction;
+  private final List<String> _primaryKeyColumns;
 
   // TODO(upsert): consider an off-heap KV store to persist this mapping to improve the recovery speed.
   @VisibleForTesting
@@ -87,12 +92,18 @@ public class PartitionUpsertMetadataManager {
   private int _numOutOfOrderEvents = 0;
 
   public PartitionUpsertMetadataManager(String tableNameWithType, int partitionId, ServerMetrics serverMetrics,
-      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction) {
+      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction,
+      List<String> primaryKeyColumns) {
     _tableNameWithType = tableNameWithType;
     _partitionId = partitionId;
     _serverMetrics = serverMetrics;
     _partialUpsertHandler = partialUpsertHandler;
     _hashFunction = hashFunction;
+    _primaryKeyColumns = primaryKeyColumns;
+  }
+
+  public List<String> getPrimaryKeyColumns() {
+    return _primaryKeyColumns;
   }
 
   /**
@@ -248,14 +259,23 @@ public class PartitionUpsertMetadataManager {
     String segmentName = segment.getSegmentName();
     LOGGER.info("Removing upsert metadata for segment: {}", segmentName);
 
-    if (!Objects.requireNonNull(segment.getValidDocIds()).getMutableRoaringBitmap().isEmpty()) {
-      // Remove all the record locations that point to the removed segment
-      _primaryKeyToRecordLocationMap.forEach((primaryKey, recordLocation) -> {
-        if (recordLocation.getSegment() == segment) {
-          // Check and remove to prevent removing the key that is just updated
-          _primaryKeyToRecordLocationMap.remove(primaryKey, recordLocation);
-        }
-      });
+    MutableRoaringBitmap mutableRoaringBitmap =
+        Objects.requireNonNull(segment.getValidDocIds()).getMutableRoaringBitmap();
+
+    if (!mutableRoaringBitmap.isEmpty()) {
+      PrimaryKey reuse = new PrimaryKey(new Object[_primaryKeyColumns.size()]);
+      PeekableIntIterator iterator = mutableRoaringBitmap.getIntIterator();
+      while (iterator.hasNext()) {
+        int docId = iterator.next();
+        segment.getPrimaryKey(docId, reuse);
+        _primaryKeyToRecordLocationMap.computeIfPresent(HashUtils.hashPrimaryKey(reuse, _hashFunction),
+            (pk, recordLocation) -> {
+              if (recordLocation.getSegment() == segment) {
+                return null;
+              }
+              return recordLocation;
+        });
+      }
     }
     // Update metrics
     _serverMetrics.setValueOfPartitionGauge(_tableNameWithType, _partitionId, ServerGauge.UPSERT_PRIMARY_KEYS_COUNT,
diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
index 384268a494..6e09192c56 100644
--- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
+++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/upsert/TableUpsertMetadataManager.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.segment.local.upsert;
 
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import javax.annotation.Nullable;
@@ -36,19 +37,22 @@ public class TableUpsertMetadataManager {
   private final ServerMetrics _serverMetrics;
   private final PartialUpsertHandler _partialUpsertHandler;
   private final HashFunction _hashFunction;
+  private final List<String> _primaryKeyColumns;
 
   public TableUpsertMetadataManager(String tableNameWithType, ServerMetrics serverMetrics,
-      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction) {
+      @Nullable PartialUpsertHandler partialUpsertHandler, HashFunction hashFunction,
+      List<String> primaryKeyColumns) {
     _tableNameWithType = tableNameWithType;
     _serverMetrics = serverMetrics;
     _partialUpsertHandler = partialUpsertHandler;
     _hashFunction = hashFunction;
+    _primaryKeyColumns = primaryKeyColumns;
   }
 
   public PartitionUpsertMetadataManager getOrCreatePartitionManager(int partitionId) {
     return _partitionMetadataManagerMap.computeIfAbsent(partitionId,
         k -> new PartitionUpsertMetadataManager(_tableNameWithType, k, _serverMetrics, _partialUpsertHandler,
-            _hashFunction));
+            _hashFunction, _primaryKeyColumns));
   }
 
   public boolean isPartialUpsertEnabled() {
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
index de521a8821..bd02b851be 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertComparisonColTest.java
@@ -63,7 +63,7 @@ public class MutableSegmentImplUpsertComparisonColTest {
     File jsonFile = new File(dataResourceUrl.getFile());
     _partitionUpsertMetadataManager =
         new TableUpsertMetadataManager("testTable_REALTIME", Mockito.mock(ServerMetrics.class), null,
-            HashFunction.NONE).getOrCreatePartitionManager(0);
+            HashFunction.NONE, _schema.getPrimaryKeyColumns()).getOrCreatePartitionManager(0);
     _mutableSegmentImpl = MutableSegmentImplTestUtils
         .createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(),
             false, true, new UpsertConfig(UpsertConfig.Mode.FULL, null, null, "offset", null), "secondsSinceEpoch",
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
index 7f0de8b9fe..07033e20b6 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/indexsegment/mutable/MutableSegmentImplUpsertTest.java
@@ -60,15 +60,17 @@ public class MutableSegmentImplUpsertTest {
     _recordTransformer = CompositeTransformer.getDefaultTransformer(_tableConfig, _schema);
     File jsonFile = new File(dataResourceUrl.getFile());
     _partitionUpsertMetadataManager =
-        new TableUpsertMetadataManager("testTable_REALTIME", Mockito.mock(ServerMetrics.class), null, hashFunction)
+        new TableUpsertMetadataManager("testTable_REALTIME", Mockito.mock(ServerMetrics.class), null, hashFunction,
+            _schema.getPrimaryKeyColumns())
             .getOrCreatePartitionManager(0);
     _mutableSegmentImpl = MutableSegmentImplTestUtils
         .createMutableSegmentImpl(_schema, Collections.emptySet(), Collections.emptySet(), Collections.emptySet(),
             false, true, new UpsertConfig(UpsertConfig.Mode.FULL, null, null, null, hashFunction), "secondsSinceEpoch",
             _partitionUpsertMetadataManager, null);
+
     GenericRow reuse = new GenericRow();
-    try (RecordReader recordReader = RecordReaderFactory
-        .getRecordReader(FileFormat.JSON, jsonFile, _schema.getColumnNames(), null)) {
+    try (RecordReader recordReader = RecordReaderFactory.getRecordReader(FileFormat.JSON, jsonFile,
+        _schema.getColumnNames(), null)) {
       while (recordReader.hasNext()) {
         recordReader.next(reuse);
         GenericRow transformedRow = _recordTransformer.transform(reuse);
diff --git a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
index bbb15b0132..702ecf1b2b 100644
--- a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
+++ b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/upsert/PartitionUpsertMetadataManagerTest.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.segment.local.upsert;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.common.metrics.ServerMetrics;
@@ -29,13 +30,18 @@ import org.apache.pinot.segment.local.utils.RecordInfo;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap;
 import org.apache.pinot.spi.config.table.HashFunction;
+import org.apache.pinot.spi.data.readers.GenericRow;
 import org.apache.pinot.spi.data.readers.PrimaryKey;
 import org.apache.pinot.spi.utils.ByteArray;
 import org.apache.pinot.spi.utils.BytesUtils;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
+import org.mockito.ArgumentMatchers;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertEquals;
@@ -57,19 +63,19 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyAddSegment(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction);
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
+            Collections.emptyList());
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add the first segment
-    List<RecordInfo> recordInfoList1 = new ArrayList<>();
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(0), 0, new IntWrapper(100)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(1), 1, new IntWrapper(100)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(2), 2, new IntWrapper(100)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(0), 3, new IntWrapper(80)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(1), 4, new IntWrapper(120)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(0), 5, new IntWrapper(100)));
+
+    int numRecords = 6;
+    int[] primaryKeys = new int[]{0, 1, 2, 0, 1, 0};
+    int[] timestamps = new int[]{100, 100, 100, 80, 120, 100};
+    List<RecordInfo> recordInfoList1 =
+        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1);
+    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
     upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
     // segment1: 0 -> {5, 100}, 1 -> {4, 120}, 2 -> {2, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 5, 100, hashFunction);
@@ -78,12 +84,11 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{2, 4, 5});
 
     // Add the second segment
-    ArrayList<RecordInfo> recordInfoList2 = new ArrayList<>();
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(0), 0, new IntWrapper(100)));
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(1), 1, new IntWrapper(100)));
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(2), 2, new IntWrapper(120)));
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(3), 3, new IntWrapper(80)));
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(0), 4, new IntWrapper(80)));
+    numRecords = 5;
+    primaryKeys = new int[]{0, 1, 2, 3, 0};
+    timestamps = new int[]{100, 100, 120, 80, 80};
+    List<RecordInfo> recordInfoList2 =
+        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds2 = new ThreadSafeMutableRoaringBitmap();
     ImmutableSegmentImpl segment2 = mockSegment(2, validDocIds2);
     upsertMetadataManager.addSegment(segment2, recordInfoList2.iterator());
@@ -110,7 +115,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
     assertEquals(newValidDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
-    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(getPrimaryKey(1), hashFunction))
+    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction))
         .getSegment(), newSegment1);
 
     // Remove the original segment1
@@ -123,15 +128,54 @@ public class PartitionUpsertMetadataManagerTest {
     checkRecordLocation(recordLocationMap, 3, segment2, 3, 80, hashFunction);
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 2, 3});
     assertEquals(newValidDocIds1.getMutableRoaringBitmap().toArray(), new int[]{4});
-    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(getPrimaryKey(1), hashFunction))
+    assertSame(recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(1), hashFunction))
         .getSegment(), newSegment1);
   }
 
+  private List<RecordInfo> getRecordInfoList(int numRecords, int[] primaryKeys,
+      int[] timestamps) {
+    List<RecordInfo> recordInfoList = new ArrayList<>();
+    for (int i = 0; i < numRecords; i++) {
+      recordInfoList.add(new RecordInfo(makePrimaryKey(primaryKeys[i]), i,
+          new IntWrapper(timestamps[i])));
+    }
+    return recordInfoList;
+  }
+
+  private List<PrimaryKey> getPrimaryKeyList(int numRecords, int[] primaryKeys) {
+    List<PrimaryKey> primaryKeyList = new ArrayList<>();
+    for (int i = 0; i < numRecords; i++) {
+      primaryKeyList.add(makePrimaryKey(primaryKeys[i]));
+    }
+    return primaryKeyList;
+  }
+
   private static ImmutableSegmentImpl mockSegment(int sequenceNumber, ThreadSafeMutableRoaringBitmap validDocIds) {
     ImmutableSegmentImpl segment = mock(ImmutableSegmentImpl.class);
     String segmentName = getSegmentName(sequenceNumber);
     when(segment.getSegmentName()).thenReturn(segmentName);
     when(segment.getValidDocIds()).thenReturn(validDocIds);
+    when(segment.getRecord(anyInt(), ArgumentMatchers.any(GenericRow.class))).thenReturn(new GenericRow());
+    return segment;
+  }
+
+  private static ImmutableSegmentImpl mockSegment(int sequenceNumber, ThreadSafeMutableRoaringBitmap validDocIds,
+      List<PrimaryKey> primaryKeys) {
+    ImmutableSegmentImpl segment = mock(ImmutableSegmentImpl.class);
+
+    String segmentName = getSegmentName(sequenceNumber);
+    when(segment.getSegmentName()).thenReturn(segmentName);
+    when(segment.getValidDocIds()).thenReturn(validDocIds);
+    doAnswer((invocation) -> {
+      PrimaryKey pk = primaryKeys.get(invocation.getArgument(0));
+      PrimaryKey reuse = invocation.getArgument(1, PrimaryKey.class);
+      Object[] reuseValues = reuse.getValues();
+      for (int i = 0; i < reuseValues.length; i++) {
+        reuseValues[i] = pk.getValues()[i];
+      }
+        return null;
+      }).when(segment).getPrimaryKey(anyInt(), any(PrimaryKey.class));
+
     return segment;
   }
 
@@ -139,14 +183,14 @@ public class PartitionUpsertMetadataManagerTest {
     return new LLCSegmentName(RAW_TABLE_NAME, 0, sequenceNumber, System.currentTimeMillis()).toString();
   }
 
-  private static PrimaryKey getPrimaryKey(int value) {
+  private static PrimaryKey makePrimaryKey(int value) {
     return new PrimaryKey(new Object[]{value});
   }
 
   private static void checkRecordLocation(Map<Object, RecordLocation> recordLocationMap, int keyValue,
       IndexSegment segment, int docId, int comparisonValue, HashFunction hashFunction) {
     RecordLocation recordLocation =
-        recordLocationMap.get(HashUtils.hashPrimaryKey(getPrimaryKey(keyValue), hashFunction));
+        recordLocationMap.get(HashUtils.hashPrimaryKey(makePrimaryKey(keyValue), hashFunction));
     assertNotNull(recordLocation);
     assertSame(recordLocation.getSegment(), segment);
     assertEquals(recordLocation.getDocId(), docId);
@@ -162,15 +206,17 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyAddRecord(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction);
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
+            Collections.emptyList());
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add the first segment
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}, 2 -> {2, 100}
-    List<RecordInfo> recordInfoList1 = new ArrayList<>();
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(0), 0, new IntWrapper(100)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(1), 1, new IntWrapper(120)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(2), 2, new IntWrapper(100)));
+    int numRecords = 3;
+    int[] primaryKeys = new int[]{0, 1, 2};
+    int[] timestamps = new int[]{100, 120, 100};
+    List<RecordInfo> recordInfoList1 =
+        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
     ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1);
     upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
@@ -180,7 +226,8 @@ public class PartitionUpsertMetadataManagerTest {
     IndexSegment segment2 = mockSegment(1, validDocIds2);
 
     upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(getPrimaryKey(3), 0, new IntWrapper(100)));
+        new RecordInfo(makePrimaryKey(3), 0, new IntWrapper(100)));
+
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}, 2 -> {2, 100}
     // segment2: 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 0, 100, hashFunction);
@@ -191,7 +238,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0});
 
     upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(getPrimaryKey(2), 1, new IntWrapper(120)));
+        new RecordInfo(makePrimaryKey(2), 1, new IntWrapper(120)));
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}
     // segment2: 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 0, 100, hashFunction);
@@ -202,7 +249,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
 
     upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(getPrimaryKey(1), 2, new IntWrapper(100)));
+        new RecordInfo(makePrimaryKey(1), 2, new IntWrapper(100)));
     // segment1: 0 -> {0, 100}, 1 -> {1, 120}
     // segment2: 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment1, 0, 100, hashFunction);
@@ -213,7 +260,7 @@ public class PartitionUpsertMetadataManagerTest {
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
 
     upsertMetadataManager.addRecord(segment2,
-        new RecordInfo(getPrimaryKey(0), 3, new IntWrapper(100)));
+        new RecordInfo(makePrimaryKey(0), 3, new IntWrapper(100)));
     // segment1: 1 -> {1, 120}
     // segment2: 0 -> {3, 100}, 2 -> {1, 120}, 3 -> {0, 100}
     checkRecordLocation(recordLocationMap, 0, segment2, 3, 100, hashFunction);
@@ -233,30 +280,34 @@ public class PartitionUpsertMetadataManagerTest {
 
   private void verifyRemoveSegment(HashFunction hashFunction) {
     PartitionUpsertMetadataManager upsertMetadataManager =
-        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction);
+        new PartitionUpsertMetadataManager(REALTIME_TABLE_NAME, 0, mock(ServerMetrics.class), null, hashFunction,
+            Collections.singletonList("primaryKey"));
     Map<Object, RecordLocation> recordLocationMap = upsertMetadataManager._primaryKeyToRecordLocationMap;
 
     // Add 2 segments
     // segment1: 0 -> {0, 100}, 1 -> {1, 100}
     // segment2: 2 -> {0, 100}, 3 -> {0, 100}
-    List<RecordInfo> recordInfoList1 = new ArrayList<>();
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(0), 0, new IntWrapper(100)));
-    recordInfoList1.add(new RecordInfo(getPrimaryKey(1), 1, new IntWrapper(100)));
+    int numRecords = 2;
+    int[] primaryKeys = new int[]{0, 1};
+    int[] timestamps = new int[]{100, 100};
+    List<RecordInfo> recordInfoList1 =
+        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds1 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1);
+    ImmutableSegmentImpl segment1 = mockSegment(1, validDocIds1, getPrimaryKeyList(numRecords, primaryKeys));
     upsertMetadataManager.addSegment(segment1, recordInfoList1.iterator());
-    List<RecordInfo> recordInfoList2 = new ArrayList<>();
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(2), 0, new IntWrapper(100)));
-    recordInfoList2.add(new RecordInfo(getPrimaryKey(3), 1, new IntWrapper(100)));
+
+    primaryKeys = new int[]{2, 3};
+    List<RecordInfo> recordInfoList2 =
+        getRecordInfoList(numRecords, primaryKeys, timestamps);
     ThreadSafeMutableRoaringBitmap validDocIds2 = new ThreadSafeMutableRoaringBitmap();
-    ImmutableSegmentImpl segment2 = mockSegment(2, validDocIds2);
+    ImmutableSegmentImpl segment2 = mockSegment(2, validDocIds2, getPrimaryKeyList(numRecords, primaryKeys));
     upsertMetadataManager.addSegment(segment2, recordInfoList2.iterator());
 
     // Remove the first segment
     upsertMetadataManager.removeSegment(segment1);
     // segment2: 2 -> {0, 100}, 3 -> {0, 100}
-    assertNull(recordLocationMap.get(getPrimaryKey(0)));
-    assertNull(recordLocationMap.get(getPrimaryKey(1)));
+    assertNull(recordLocationMap.get(makePrimaryKey(0)));
+    assertNull(recordLocationMap.get(makePrimaryKey(1)));
     checkRecordLocation(recordLocationMap, 2, segment2, 0, 100, hashFunction);
     checkRecordLocation(recordLocationMap, 3, segment2, 1, 100, hashFunction);
     assertEquals(validDocIds2.getMutableRoaringBitmap().toArray(), new int[]{0, 1});
@@ -266,7 +317,8 @@ public class PartitionUpsertMetadataManagerTest {
   public void testHashPrimaryKey() {
     PrimaryKey pk = new PrimaryKey(new Object[]{"uuid-1", "uuid-2", "uuid-3"});
     Assert.assertEquals(BytesUtils.toHexString(
-        ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).getBytes()),
+            ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MD5)).
+                getBytes()),
         "58de44997505014e02982846a4d1cbbd");
     Assert.assertEquals(BytesUtils.toHexString(
         ((ByteArray) HashUtils.hashPrimaryKey(pk, HashFunction.MURMUR3)).getBytes()),
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
index a28f6a5dcc..9984e89405 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/IndexSegment.java
@@ -26,6 +26,7 @@ import org.apache.pinot.segment.spi.index.mutable.ThreadSafeMutableRoaringBitmap
 import org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 import org.apache.pinot.spi.annotations.InterfaceAudience;
 import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.data.readers.PrimaryKey;
 
 
 @InterfaceAudience.Private
@@ -86,6 +87,15 @@ public interface IndexSegment {
    */
   GenericRow getRecord(int docId, GenericRow reuse);
 
+  /**
+   * Reads the primary key of the given document into the supplied reusable {@link PrimaryKey}.
+   * This method returns nothing; the result is written into {@code reuse}.
+   *
+   * @param docId Document Id
+   * @param reuse Reusable primary key that receives the values of the given document
+   */
+  void getPrimaryKey(int docId, PrimaryKey reuse);
+
   /**
    * Hints the segment to begin prefetching buffers for specified columns.
    * Typically, this should be an async call made before operating on the segment.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org