Posted to commits@druid.apache.org by gi...@apache.org on 2023/03/08 22:19:53 UTC

[druid] branch master updated: Sort-merge join and hash shuffles for MSQ. (#13506)

This is an automated email from the ASF dual-hosted git repository.

gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 82f7a56475 Sort-merge join and hash shuffles for MSQ. (#13506)
82f7a56475 is described below

commit 82f7a5647571ab38659739e0fce806e3f0cfb688
Author: Gian Merlino <gi...@gmail.com>
AuthorDate: Wed Mar 8 14:19:39 2023 -0800

    Sort-merge join and hash shuffles for MSQ. (#13506)
    
    * Sort-merge join and hash shuffles for MSQ.
    
    The main changes are in the processing, multi-stage-query, and sql modules.
    
    processing module:
    
    1) Rename SortColumn to KeyColumn, replace boolean descending with KeyOrder.
       This makes it nicer to model hash keys, which use KeyOrder.NONE. (See the
       sketch after this list.)
    
    2) Add nullability checkers to the FieldReader interface, and an
       "isPartiallyNullKey" method to FrameComparisonWidget. The join
       processor uses this to detect null keys.
    
    3) Add WritableFrameChannel.isClosed and OutputChannel.isReadableChannelReady
       so callers can tell which OutputChannels are ready for reading and which
       aren't.
    
    4) Specialize FrameProcessors.makeCursor to return FrameCursor, a random-access
       implementation. The join processor uses this to rewind when it needs to
       replay a set of rows with a particular key.
    
    5) Add MemoryAllocatorFactory, which is embedded inside FrameWriterFactory
       instead of a particular MemoryAllocator. This allows FrameWriterFactory
       to be shared in more scenarios.
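
    As a loose illustration of item 1 above, the sketch below builds a sort key
    and a hash key under the new naming. The KeyColumn and ClusterBy constructors
    shown are assumptions inferred from the accessors this change uses elsewhere
    (columnName(), order(), getColumns()); treat it as a sketch, not a reference.

        import org.apache.druid.frame.key.ClusterBy;
        import org.apache.druid.frame.key.KeyColumn;
        import org.apache.druid.frame.key.KeyOrder;

        import java.util.Arrays;
        import java.util.Collections;

        public class KeyColumnSketch
        {
          public static void main(String[] args)
          {
            // A sort key: each column carries an explicit order (previously a
            // boolean "descending" flag on SortColumn).
            final ClusterBy sortKey = new ClusterBy(
                Arrays.asList(
                    new KeyColumn("user_id", KeyOrder.ASCENDING),
                    new KeyColumn("__time", KeyOrder.ASCENDING)
                ),
                0 // bucketByCount
            );

            // A hash key: columns use KeyOrder.NONE, since hashing implies no order.
            final ClusterBy hashKey = new ClusterBy(
                Collections.singletonList(new KeyColumn("user_id", KeyOrder.NONE)),
                0
            );

            for (KeyColumn column : sortKey.getColumns()) {
              // columnName() and order() replace the old SortColumn accessors.
              System.out.println(column.columnName() + " -> " + column.order());
            }
            System.out.println("hash key columns: " + hashKey.getColumns());
          }
        }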
    
    multi-stage-query module:
    
    1) ShuffleSpec: Add hash-based shuffles. New enum ShuffleKind helps callers
       figure out what kind of shuffle is happening. The change from SortColumn
       to KeyColumn allows ClusterBy to be used for both hash-based and sort-based
       shuffling. (See the sketch after this list.)
    
    2) WorkerImpl: Add ability to handle hash-based shuffles. Refactor the logic
       to be more readable by moving the work-order-running code to the inner
       class RunWorkOrder, and the shuffle-pipeline-building code to the inner
       class ShufflePipelineBuilder.
    
    3) Add SortMergeJoinFrameProcessor and factory.
    
    4) WorkerMemoryParameters: Adjust logic to reserve space for output frames
       for hash partitioning. (We need one frame per partition.)
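
    To illustrate item 1 above, here is a sketch of constructing the two flavors
    of shuffle spec. The GlobalSortTargetSizeShuffleSpec arguments and the
    clusterBy()/doesAggregate() accessors mirror their usage in ControllerImpl;
    the HashShuffleSpec constructor shown here is an assumption for illustration
    only.

        import org.apache.druid.frame.key.ClusterBy;
        import org.apache.druid.frame.key.KeyColumn;
        import org.apache.druid.frame.key.KeyOrder;
        import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec;
        import org.apache.druid.msq.kernel.HashShuffleSpec;
        import org.apache.druid.msq.kernel.ShuffleSpec;

        import java.util.Collections;

        public class ShuffleSpecSketch
        {
          public static void main(String[] args)
          {
            // Sort-based shuffle: rows are range-partitioned and sorted on the key,
            // targeting a number of rows per partition.
            final ShuffleSpec globalSort = new GlobalSortTargetSizeShuffleSpec(
                new ClusterBy(
                    Collections.singletonList(new KeyColumn("user_id", KeyOrder.ASCENDING)),
                    0
                ),
                3_000_000, // target rows per partition
                false      // doesAggregate
            );

            // Hash-based shuffle: the key is unsorted (KeyOrder.NONE) and the number
            // of partitions is fixed up front. Constructor arguments are assumed.
            final ShuffleSpec hash = new HashShuffleSpec(
                new ClusterBy(
                    Collections.singletonList(new KeyColumn("user_id", KeyOrder.NONE)),
                    0
                ),
                8 // numPartitions
            );

            // Both expose the same key-based interface; the new ShuffleKind enum tells
            // callers which kind of shuffle they are looking at.
            System.out.println(globalSort.clusterBy() + " aggregates=" + globalSort.doesAggregate());
            System.out.println(hash.clusterBy() + " aggregates=" + hash.doesAggregate());
          }
        }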
    
    sql module:
    
    1) Add sqlJoinAlgorithm context parameter; can be "broadcast" or
       "sortMerge". With native, it must always be "broadcast", or it's a
       validation error. MSQ supports both. Default is "broadcast" in
       both engines. (See the sketch after this list.)
    
    2) Validate that MSQs do not use broadcast join with RIGHT or FULL join,
       as results are not correct for broadcast join with those types. Allow
       this in native for two reasons: legacy (the docs caution against it,
       but it's always been allowed), and the fact that it actually *does*
       generate correct results in native when the join is processed on the
       Broker. It is much less likely that MSQ will plan in such a way that
       generates correct results.
    
    3) Remove subquery penalty in DruidJoinQueryRel when using sort-merge
       join, because subqueries are always required, so there's no reason
       to penalize them.
    
    4) Move previously-disabled join reordering and manipulation rules to
       FANCY_JOIN_RULES, and enable them when using sort-merge join. Helps
       get to better plans where projections and filters are pushed down.
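
    As an example of item 1, a sketch of submitting a query with the new context
    parameter through the MSQ task API. The endpoint path and port follow the
    standard Druid SQL task API; the host and table names are placeholders
    (tables borrowed from the docs example added in this commit).

        import java.net.URI;
        import java.net.http.HttpClient;
        import java.net.http.HttpRequest;
        import java.net.http.HttpResponse;

        public class SortMergeJoinSubmitSketch
        {
          public static void main(String[] args) throws Exception
          {
            // sqlJoinAlgorithm applies to every join in the statement.
            final String body =
                "{\n"
                + "  \"query\": \"SELECT eventstream.user_id, users.signup_date "
                + "FROM eventstream LEFT JOIN users ON eventstream.user_id = users.id\",\n"
                + "  \"context\": {\"sqlJoinAlgorithm\": \"sortMerge\"}\n"
                + "}";

            final HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8888/druid/v2/sql/task")) // placeholder Router address
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();

            final HttpResponse<String> response =
                HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.statusCode() + ": " + response.body());
          }
        }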
    
    * Work around compiler problem.
    
    * Updates from static analysis.
    
    * Fix @param tag.
    
    * Fix declared exception.
    
    * Fix spelling.
    
    * Minor adjustments.
    
    * wip
    
    * Merge fixups
    
    * fixes
    
    * Fix CalciteSelectQueryMSQTest
    
    * Empty keys are sortable.
    
    * Address comments from code review. Rename mux -> mix.
    
    * Restore inspection config.
    
    * Restore original doc.
    
    * Reorder imports.
    
    * Adjustments
    
    * Fix.
    
    * Fix imports.
    
    * Adjustments from review.
    
    * Update header.
    
    * Adjust docs.
---
 docs/multi-stage-query/reference.md                |   89 ++
 docs/querying/datasource.md                        |   16 +-
 docs/querying/joins.md                             |    4 +-
 .../org/apache/druid/msq/exec/ControllerImpl.java  |   23 +-
 .../java/org/apache/druid/msq/exec/Limits.java     |    6 +
 .../org/apache/druid/msq/exec/WorkerContext.java   |    6 +
 .../java/org/apache/druid/msq/exec/WorkerImpl.java | 1322 +++++++++++++-------
 .../druid/msq/exec/WorkerMemoryParameters.java     |   78 +-
 .../apache/druid/msq/exec/WorkerSketchFetcher.java |    1 -
 .../apache/druid/msq/guice/MSQIndexingModule.java  |    4 +
 .../msq/indexing/CountingWritableFrameChannel.java |    6 +
 .../druid/msq/indexing/InputChannelsImpl.java      |   13 +-
 .../error/BroadcastTablesTooLargeFault.java        |   13 +-
 ...Fault.java => TooManyRowsWithSameKeyFault.java} |   53 +-
 .../org/apache/druid/msq/input/InputSlices.java    |   82 ++
 .../org/apache/druid/msq/input/ReadableInput.java  |   81 +-
 .../org/apache/druid/msq/input/ReadableInputs.java |    2 +
 .../msq/input/table/SegmentWithDescriptor.java     |   26 +-
 ...pec.java => GlobalSortMaxCountShuffleSpec.java} |   68 +-
 ...ShuffleSpec.java => GlobalSortShuffleSpec.java} |   37 +-
 ...c.java => GlobalSortTargetSizeShuffleSpec.java} |   36 +-
 .../apache/druid/msq/kernel/HashShuffleSpec.java   |   74 ++
 .../apache/druid/msq/kernel/MixShuffleSpec.java    |   85 ++
 .../druid/msq/kernel/QueryDefinitionBuilder.java   |   12 +
 .../org/apache/druid/msq/kernel/ShuffleKind.java   |   87 ++
 .../org/apache/druid/msq/kernel/ShuffleSpec.java   |   48 +-
 .../apache/druid/msq/kernel/StageDefinition.java   |  101 +-
 .../druid/msq/kernel/StageDefinitionBuilder.java   |    5 +
 .../kernel/controller/ControllerStageTracker.java  |   54 +-
 .../druid/msq/kernel/worker/WorkerStageKernel.java |    4 +-
 .../querykit/BaseLeafFrameProcessorFactory.java    |   18 +-
 .../apache/druid/msq/querykit/DataSourcePlan.java  |  175 ++-
 .../apache/druid/msq/querykit/QueryKitUtils.java   |   13 +-
 .../druid/msq/querykit/ShuffleSpecFactories.java   |   19 +-
 .../druid/msq/querykit/ShuffleSpecFactory.java     |    2 +-
 .../querykit/common/OffsetLimitFrameProcessor.java |   26 +-
 .../common/OffsetLimitFrameProcessorFactory.java   |    3 +
 .../common/SortMergeJoinFrameProcessor.java        | 1075 ++++++++++++++++
 .../common/SortMergeJoinFrameProcessorFactory.java |  277 ++++
 .../groupby/GroupByPostShuffleFrameProcessor.java  |   22 +-
 .../GroupByPostShuffleFrameProcessorFactory.java   |   10 +-
 .../GroupByPreShuffleFrameProcessorFactory.java    |   20 +-
 .../msq/querykit/groupby/GroupByQueryKit.java      |   20 +-
 .../scan/ScanQueryFrameProcessorFactory.java       |   20 +-
 .../druid/msq/querykit/scan/ScanQueryKit.java      |   22 +-
 .../shuffle/DurableStorageInputChannelFactory.java |    1 -
 .../DurableStorageOutputChannelFactory.java        |    5 +-
 .../java/org/apache/druid/msq/sql/MSQMode.java     |    1 -
 .../apache/druid/msq/sql/MSQTaskQueryMaker.java    |    1 -
 .../org/apache/druid/msq/sql/MSQTaskSqlEngine.java |    1 +
 .../org/apache/druid/msq/exec/MSQSelectTest.java   |  196 +--
 .../druid/msq/exec/WorkerMemoryParametersTest.java |   94 +-
 .../msq/indexing/error/MSQFaultSerdeTest.java      |    2 +
 .../msq/indexing/report/MSQTaskReportTest.java     |    9 +-
 .../druid/msq/kernel/QueryDefinitionTest.java      |    7 +-
 .../druid/msq/kernel/StageDefinitionTest.java      |   24 +-
 .../controller/MockQueryDefinitionBuilder.java     |    9 +-
 .../common/SortMergeJoinFrameProcessorTest.java    | 1080 ++++++++++++++++
 .../querykit/scan/ScanQueryFrameProcessorTest.java |   77 +-
 .../DurableStorageOutputChannelFactoryTest.java    |    3 +-
 .../ClusterByStatisticsCollectorImplTest.java      |   13 +-
 .../statistics/DelegateOrMinKeyCollectorTest.java  |    5 +-
 .../msq/statistics/DistinctKeyCollectorTest.java   |    5 +-
 .../msq/statistics/KeyCollectorTestUtils.java      |    7 +-
 .../QuantilesSketchKeyCollectorTest.java           |   26 +-
 .../druid/msq/test/CalciteSelectQueryTestMSQ.java  |    5 +-
 .../druid/msq/test/LimitedFrameWriterFactory.java  |  114 ++
 .../org/apache/druid/msq/test/MSQTestBase.java     |    1 +
 .../ArenaMemoryAllocatorFactory.java}              |   30 +-
 .../MemoryAllocatorFactory.java}                   |   18 +-
 .../apache/druid/frame/allocation/MemoryRange.java |    8 +-
 .../allocation/SingleMemoryAllocatorFactory.java   |   59 +
 .../frame/channel/BlockingQueueFrameChannel.java   |   15 +-
 .../channel/ComposingWritableFrameChannel.java     |    7 +
 .../druid/frame/channel/WritableFrameChannel.java  |    5 +
 .../frame/channel/WritableFrameFileChannel.java    |    8 +
 .../druid/frame/field/ComplexFieldReader.java      |    6 +
 .../druid/frame/field/DoubleFieldReader.java       |    6 +
 .../org/apache/druid/frame/field/FieldReader.java  |    5 +
 .../apache/druid/frame/field/FloatFieldReader.java |    6 +
 .../apache/druid/frame/field/LongFieldReader.java  |    6 +
 .../druid/frame/field/StringFieldReader.java       |   10 +
 .../druid/frame/key/ByteRowKeyComparator.java      |   19 +-
 .../java/org/apache/druid/frame/key/ClusterBy.java |   44 +-
 .../druid/frame/key/FrameComparisonWidget.java     |    7 +-
 .../druid/frame/key/FrameComparisonWidgetImpl.java |   75 +-
 .../frame/key/{SortColumn.java => KeyColumn.java}  |   26 +-
 .../java/org/apache/druid/frame/key/KeyOrder.java  |   61 +
 .../apache/druid/frame/key/RowKeyComparator.java   |    2 +-
 .../BlockingQueueOutputChannelFactory.java         |    4 +-
 .../processor/FrameChannelHashPartitioner.java     |  348 ++++++
 .../druid/frame/processor/FrameChannelMerger.java  |   41 +-
 ...ameChannelMuxer.java => FrameChannelMixer.java} |   53 +-
 .../druid/frame/processor/FrameProcessor.java      |    2 +-
 .../druid/frame/processor/FrameProcessors.java     |   62 +-
 .../processor/MultiColumnSelectorFactory.java      |   30 +-
 .../druid/frame/processor/OutputChannel.java       |   79 +-
 .../druid/frame/processor/OutputChannels.java      |   20 +
 .../druid/frame/processor/ReturnOrAwait.java       |    9 +-
 .../apache/druid/frame/processor/SuperSorter.java  |   23 +-
 .../org/apache/druid/frame/read/FrameReader.java   |   19 +-
 .../apache/druid/frame/segment/FrameCursor.java    |   28 +-
 .../druid/frame/segment/FrameFilteredOffset.java   |   21 +-
 .../frame/segment/columnar/FrameCursorFactory.java |    7 +-
 .../frame/segment/row/FrameCursorFactory.java      |    6 +-
 .../org/apache/druid/frame/write/FrameSort.java    |   12 +-
 .../druid/frame/write/FrameWriterFactory.java      |    6 +
 .../apache/druid/frame/write/FrameWriterUtils.java |   28 +-
 .../org/apache/druid/frame/write/FrameWriters.java |   27 +-
 .../druid/frame/write/RowBasedFrameWriter.java     |    6 +-
 .../frame/write/RowBasedFrameWriterFactory.java    |   30 +-
 .../frame/write/columnar/ColumnarFrameWriter.java  |   10 +-
 .../write/columnar/ColumnarFrameWriterFactory.java |   43 +-
 .../apache/druid/segment/join/JoinPrefixUtils.java |   15 +
 .../apache/druid/segment/join/JoinableClause.java  |    7 +-
 .../java/org/apache/druid/frame/FrameTest.java     |    9 +-
 .../druid/frame/field/ComplexFieldReaderTest.java  |   14 +
 .../druid/frame/field/DoubleFieldReaderTest.java   |   14 +
 .../druid/frame/field/FloatFieldReaderTest.java    |   14 +
 .../druid/frame/field/LongFieldReaderTest.java     |   14 +
 .../druid/frame/field/StringFieldReaderTest.java   |   33 +
 .../druid/frame/key/ByteRowKeyComparatorTest.java  |   82 +-
 .../org/apache/druid/frame/key/ClusterByTest.java  |   53 +-
 .../frame/key/FrameComparisonWidgetImplTest.java   |   99 +-
 .../{SortColumnTest.java => KeyColumnTest.java}    |    4 +-
 .../org/apache/druid/frame/key/KeyTestUtils.java   |   11 +-
 .../druid/frame/key/RowKeyComparatorTest.java      |   82 +-
 .../ComposingOutputChannelFactoryTest.java         |   26 +-
 .../processor/FrameProcessorExecutorTest.java      |    2 +-
 .../druid/frame/processor/OutputChannelTest.java   |   53 +-
 .../druid/frame/processor/OutputChannelsTest.java  |    7 +-
 .../druid/frame/processor/SuperSorterTest.java     |   81 +-
 .../druid/frame/testutil/FrameSequenceBuilder.java |   18 +-
 .../apache/druid/frame/write/FrameWriterTest.java  |   60 +-
 .../druid/frame/write/FrameWriterTestData.java     |   12 +-
 .../apache/druid/frame/write/FrameWritersTest.java |   31 +-
 .../apache/druid/segment/join/JoinTestHelper.java  |   30 +-
 .../src/test/resources/wikipedia/regions.json      |    1 +
 .../sql/calcite/planner/CalciteRulesManager.java   |   41 +-
 .../druid/sql/calcite/planner/IngestHandler.java   |   12 -
 .../druid/sql/calcite/planner/JoinAlgorithm.java   |   97 ++
 .../druid/sql/calcite/planner/PlannerContext.java  |   45 +-
 .../druid/sql/calcite/planner/QueryHandler.java    |   11 +-
 .../sql/calcite/planner/QueryValidations.java      |   89 ++
 .../sql/calcite/planner/SqlStatementHandler.java   |    2 -
 .../sql/calcite/rel/DruidCorrelateUnnestRel.java   |    4 +-
 .../druid/sql/calcite/rel/DruidJoinQueryRel.java   |   44 +-
 .../druid/sql/calcite/rule/DruidJoinRule.java      |   31 +-
 .../druid/sql/calcite/run/EngineFeature.java       |   11 +-
 .../druid/sql/calcite/run/NativeSqlEngine.java     |   23 +
 .../apache/druid/sql/calcite/CalciteQueryTest.java |    1 -
 .../druid/sql/calcite/IngestionTestSqlEngine.java  |    1 +
 .../druid/sql/calcite/rule/DruidJoinRuleTest.java  |    2 +
 153 files changed, 6679 insertions(+), 1688 deletions(-)

diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md
index 5c05b68bf3..25eb827efc 100644
--- a/docs/multi-stage-query/reference.md
+++ b/docs/multi-stage-query/reference.md
@@ -592,6 +592,7 @@ The following table lists the context parameters for the MSQ task engine:
 | `maxNumTasks` | SELECT, INSERT, REPLACE<br /><br />The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.<br /><br />May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority.| 2 |
 | `taskAssignment` | SELECT, INSERT, REPLACE<br /><br />Determines how many tasks to use. Possible values include: <ul><li>`max`: Uses as many tasks as possible, up to `maxNumTasks`.</li><li>`auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 10 GiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When file sizes cannot be determined  [...]
 | `finalizeAggregations` | SELECT, INSERT, REPLACE<br /><br />Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggr [...]
+| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. | `broadcast` |
 | `rowsInMemory` | INSERT or REPLACE<br /><br />Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. | 100,000 |
 | `segmentSortOrder` | INSERT or REPLACE<br /><br />Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.<br /><br />You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, [...]
 | `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
@@ -604,6 +605,92 @@ The following table lists the context parameters for the MSQ task engine:
 | `intermediateSuperSorterStorageMaxLocalBytes` | SELECT, INSERT, REPLACE<br /><br /> Whether to enable a byte limit on local storage for sorting's intermediate data. If that limit is crossed, the task fails with `ResourceLimitExceededException`.| `9223372036854775807` |
 | `maxInputBytesPerWorker` | Should be used in conjunction with taskAssignment `auto` mode. When dividing the input of a stage among the workers, this parameter determines the maximum size in bytes that are given to a single worker before the next worker is chosen. This parameter is only used as a guideline during input slicing, and does not guarantee that the input cannot be larger. For example, if we have 3 files of 3, 7, and 12 GB each, then we would end up using 2 workers: worker 1 -> 3, 7 a [...]
 
+## Joins
+
+Joins in multi-stage queries use one of two algorithms, based on the [context parameter](#context-parameters)
+`sqlJoinAlgorithm`. This context parameter applies to the entire SQL statement, so it is not possible to mix different
+join algorithms in the same query.
+
+### Broadcast
+
+Set `sqlJoinAlgorithm` to `broadcast`.
+
+The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how
+[joins are executed with native queries](../querying/query-execution.md#join). First, any adjacent joins are flattened
+into a structure with a "base" input (the bottom-leftmost one) and other leaf inputs (the rest). Next, any subqueries
+that are inputs to the join (either base or other leaf inputs) are planned into independent stages. Then, the non-base leaf
+inputs are all connected as broadcast inputs to the "base" stage.
+
+Together, all of these non-base leaf inputs must not exceed the [limit on broadcast table footprint](#limits). There
+is no limit on the size of the base (leftmost) input.
+
+Only LEFT JOIN, INNER JOIN, and CROSS JOIN are supported with `broadcast`.
+
+Join conditions, if present, must be equalities. It is not necessary to include a join condition; for example,
+`CROSS JOIN` and comma join do not require join conditions.
+
+As an example, the following statement has a single join chain where `orders` is the base input, and `products` and
+`customers` are non-base leaf inputs. The query will first read `products` and `customers`, then broadcast both to
+the stage that reads `orders`. That stage loads the broadcast inputs (`products` and `customers`) in memory, and walks
+through `orders` row by row. The results are then aggregated and written to the table `orders_enriched`. The broadcast
+inputs (`products` and `customers`) must fall under the limit on broadcast table footprint, but the base `orders` input
+can be unlimited in size.
+
+```
+REPLACE INTO orders_enriched
+OVERWRITE ALL
+SELECT
+  orders.__time,
+  products.name AS product_name,
+  customers.name AS customer_name,
+  SUM(orders.amount) AS amount
+FROM orders
+LEFT JOIN products ON orders.product_id = products.id
+LEFT JOIN customers ON orders.customer_id = customers.id
+GROUP BY 1, 2, 3
+PARTITIONED BY HOUR
+CLUSTERED BY product_name
+```
+
+### Sort-merge
+
+Set `sqlJoinAlgorithm` to `sortMerge`.
+
+Multi-stage queries can use a sort-merge join algorithm. With this algorithm, each pairwise join is planned into its own
+stage with two inputs. The two inputs are hash-partitioned on the same key and sorted within each partition. This
+approach is generally less performant, but more scalable, than `broadcast`. There are various scenarios where broadcast
+join would return a [`BroadcastTablesTooLarge`](#errors) error, but a sort-merge join would succeed.
+
+There is no limit on the overall size of either input, so sort-merge is a good choice for performing a join of two large
+inputs, or for performing a self-join of a large input with itself.
+
+There is a limit on the amount of data associated with each individual key. If _both_ sides of the join exceed this
+limit, the query returns a [`TooManyRowsWithSameKey`](#errors) error. If only one side exceeds the limit, the query
+does not return this error.
+
+Join conditions, if present, must be equalities. It is not necessary to include a join condition; for example,
+`CROSS JOIN` and comma join do not require join conditions.
+
+All join types are supported with `sortMerge`: LEFT, RIGHT, INNER, FULL, and CROSS.
+
+As an example, the following statement runs using a single sort-merge join stage that receives `eventstream`
+(partitioned on `user_id`) and `users` (partitioned on `id`) as inputs. There is no limit on the size of either input.
+
+```
+REPLACE INTO eventstream_enriched
+OVERWRITE ALL
+SELECT
+  eventstream.__time,
+  eventstream.user_id,
+  eventstream.event_type,
+  eventstream.event_details,
+  users.signup_date AS user_signup_date
+FROM eventstream
+LEFT JOIN users ON eventstream.user_id = users.id
+PARTITIONED BY HOUR
+CLUSTERED BY user_id
+```
+
 ## Sketch Merging Mode
 This section details the advantages and performance of various Cluster By Statistics Merge Modes.
 
@@ -656,6 +743,7 @@ The following table lists query limits:
 | Number of cluster by columns that can appear in a stage | 1,500 | [`TooManyClusteredByColumns`](#error_TooManyClusteredByColumns) |
 | Number of workers for any one stage. | Hard limit is 1,000. Memory-dependent soft limit may be lower. | [`TooManyWorkers`](#error_TooManyWorkers) |
 | Maximum memory occupied by broadcasted tables. | 30% of each [processor memory bundle](concepts.md#memory-usage). | [`BroadcastTablesTooLarge`](#error_BroadcastTablesTooLarge) |
+| Maximum memory occupied by buffered data during sort-merge join. Only relevant when `sqlJoinAlgorithm` is `sortMerge`. | 10 MB | `TooManyRowsWithSameKey` |
 | Maximum relaunch attempts per worker. Initial run is not a relaunch. The worker will be spawned 1 + `workerRelaunchLimit` times before the job fails. | 2 | `TooManyAttemptsForWorker` |
 | Maximum relaunch attempts for a job across all workers. | 100 | `TooManyAttemptsForJob` |
 <a name="errors"></a>
@@ -687,6 +775,7 @@ The following table describes error codes you may encounter in the `multiStageQu
 | <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br  [...]
 | <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |
 | <a name="error_TooManyClusteredByColumns">`TooManyClusteredByColumns`</a>  | Exceeded the maximum number of clustering columns for a stage (1,500 columns).<br /><br />This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit<br /><br /> |
+| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
 | <a name="error_TooManyColumns">`TooManyColumns`</a> | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded. |
 | <a name="error_TooManyWarnings">`TooManyWarnings`</a> | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit. <br /><br />`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. |
 | <a name="error_TooManyWorkers">`TooManyWorkers`</a> | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously. <br /><br />`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1, [...]
diff --git a/docs/querying/datasource.md b/docs/querying/datasource.md
index 0a16e1d16b..5419b7f0c3 100644
--- a/docs/querying/datasource.md
+++ b/docs/querying/datasource.md
@@ -289,10 +289,10 @@ GROUP BY
 Join datasources allow you to do a SQL-style join of two datasources. Stacking joins on top of each other allows
 you to join arbitrarily many datasources.
 
-In Druid {{DRUIDVERSION}}, joins are implemented with a broadcast hash-join algorithm. This means that all datasources
-other than the leftmost "base" datasource must fit in memory. It also means that the join condition must be an equality. This
-feature is intended mainly to allow joining regular Druid tables with [lookup](#lookup), [inline](#inline), and
-[query](#query) datasources.
+In Druid {{DRUIDVERSION}}, joins in native queries are implemented with a broadcast hash-join algorithm. This means
+that all datasources other than the leftmost "base" datasource must fit in memory. It also means that the join condition
+must be an equality. This feature is intended mainly to allow joining regular Druid tables with [lookup](#lookup),
+[inline](#inline), and [query](#query) datasources.
 
 Refer to the [Query execution](query-execution.md#join) page for more details on how queries are executed when you
 use join datasources.
@@ -362,13 +362,11 @@ Also, as a result of this, comma joins should be avoided.
 Joins are an area of active development in Druid. The following features are missing today but may appear in
 future versions:
 
-- Reordering of predicates and filters (pushing up and/or pushing down) to get the most performant plan.
+- Reordering of join operations to get the most performant plan.
 - Preloaded dimension tables that are wider than lookups (i.e. supporting more than a single key and single value).
-- RIGHT OUTER and FULL OUTER joins. Currently, they are partially implemented. Queries will run but results will not
-always be correct.
+- RIGHT OUTER and FULL OUTER joins in the native query engine. Currently, they are partially implemented. Queries run
+  but results are not always correct.
 - Performance-related optimizations as mentioned in the [previous section](#join-performance).
-- Join algorithms other than broadcast hash-joins.
-- Join condition on a column compared to a constant value.
 - Join conditions on a column containing a multi-value dimension.
 
 ### `unnest`
diff --git a/docs/querying/joins.md b/docs/querying/joins.md
index 380ea16bf2..d200b757e4 100644
--- a/docs/querying/joins.md
+++ b/docs/querying/joins.md
@@ -26,7 +26,9 @@ Apache Druid has two features related to joining of data:
 
 1. [Join](datasource.md#join) operators. These are available using a [join datasource](datasource.md#join) in native
 queries, or using the [JOIN operator](sql.md) in Druid SQL. Refer to the
-[join datasource](datasource.md#join) documentation for information about how joins work in Druid.
+[join datasource](datasource.md#join) documentation for information about how joins work in Druid native queries,
+or the [multi-stage query join documentation](../multi-stage-query/reference.md#joins) for information about how joins
+work in multi-stage query tasks.
 2. [Query-time lookups](lookups.md), simple key-to-value mappings. These are preloaded on all servers that are involved
 in queries and can be accessed with or without an explicit join operator. Refer to the [lookups](lookups.md)
 documentation for more details.
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
index 180323caab..9a928ffa34 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
@@ -47,9 +47,10 @@ import org.apache.druid.frame.channel.FrameChannelSequence;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.RowKey;
 import org.apache.druid.frame.key.RowKeyReader;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.frame.processor.FrameProcessorExecutor;
 import org.apache.druid.frame.processor.FrameProcessors;
 import org.apache.druid.frame.util.DurableStorageUtils;
@@ -132,12 +133,12 @@ import org.apache.druid.msq.input.stage.StageInputSpec;
 import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
 import org.apache.druid.msq.input.table.TableInputSpec;
 import org.apache.druid.msq.input.table.TableInputSpecSlicer;
+import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.StagePartition;
-import org.apache.druid.msq.kernel.TargetSizeShuffleSpec;
 import org.apache.druid.msq.kernel.WorkOrder;
 import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
 import org.apache.druid.msq.kernel.controller.ControllerStagePhase;
@@ -698,8 +699,8 @@ public class ControllerImpl implements Controller
           final StageDefinition stageDef = queryKernel.getStageDefinition(stageId);
           final ObjectMapper mapper = MSQTasks.decorateObjectMapperForKeyCollectorSnapshot(
               context.jsonMapper(),
-              stageDef.getShuffleSpec().get().getClusterBy(),
-              stageDef.getShuffleSpec().get().doesAggregateByClusterKey()
+              stageDef.getShuffleSpec().clusterBy(),
+              stageDef.getShuffleSpec().doesAggregate()
           );
 
           final PartialKeyStatisticsInformation partialKeyStatisticsInformation;
@@ -1502,7 +1503,7 @@ public class ControllerImpl implements Controller
 
     if (MSQControllerTask.isIngestion(querySpec)) {
       shuffleSpecFactory = (clusterBy, aggregate) ->
-          new TargetSizeShuffleSpec(
+          new GlobalSortTargetSizeShuffleSpec(
               clusterBy,
               tuningConfig.getRowsPerSegment(),
               aggregate
@@ -1728,7 +1729,7 @@ public class ControllerImpl implements Controller
       final ColumnMappings columnMappings
   )
   {
-    final List<SortColumn> clusterByColumns = clusterBy.getColumns();
+    final List<KeyColumn> clusterByColumns = clusterBy.getColumns();
     final List<String> shardColumns = new ArrayList<>();
     final boolean boosted = isClusterByBoosted(clusterBy);
     final int numShardColumns = clusterByColumns.size() - clusterBy.getBucketByCount() - (boosted ? 1 : 0);
@@ -1738,11 +1739,11 @@ public class ControllerImpl implements Controller
     }
 
     for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) {
-      final SortColumn column = clusterByColumns.get(i);
+      final KeyColumn column = clusterByColumns.get(i);
       final List<String> outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName());
 
       // DimensionRangeShardSpec only handles ascending order.
-      if (column.descending()) {
+      if (column.order() != KeyOrder.ASCENDING) {
         return Collections.emptyList();
       }
 
@@ -1824,8 +1825,8 @@ public class ControllerImpl implements Controller
     // Note: this doesn't work when CLUSTERED BY specifies an expression that is not being selected.
     // Such fields in CLUSTERED BY still control partitioning as expected, but do not affect sort order of rows
     // within an individual segment.
-    for (final SortColumn clusterByColumn : queryClusterBy.getColumns()) {
-      if (clusterByColumn.descending()) {
+    for (final KeyColumn clusterByColumn : queryClusterBy.getColumns()) {
+      if (clusterByColumn.order() == KeyOrder.DESCENDING) {
         throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName()));
       }
 
@@ -2400,7 +2401,7 @@ public class ControllerImpl implements Controller
           segmentsToGenerate = generateSegmentIdsWithShardSpecs(
               (DataSourceMSQDestination) task.getQuerySpec().getDestination(),
               queryKernel.getStageDefinition(shuffleStageId).getSignature(),
-              queryKernel.getStageDefinition(shuffleStageId).getShuffleSpec().get().getClusterBy(),
+              queryKernel.getStageDefinition(shuffleStageId).getClusterBy(),
               partitionBoundaries,
               mayHaveMultiValuedClusterByFields
           );
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
index 8bce9a7ba5..c946fd796c 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
@@ -69,6 +69,12 @@ public class Limits
    */
   public static final int MAX_KERNEL_MANIPULATION_QUEUE_SIZE = 100_000;
 
+  /**
+   * Maximum number of bytes buffered for each side of a
+   * {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor}, not counting the most recent frame read.
+   */
+  public static final int MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN = 10_000_000;
+
   /**
    * Maximum relaunches across all workers.
    */
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java
index 6b4a387b8d..d017feb099 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java
@@ -22,6 +22,7 @@ package org.apache.druid.msq.exec;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.inject.Injector;
 import org.apache.druid.frame.processor.Bouncer;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.kernel.QueryDefinition;
@@ -73,4 +74,9 @@ public interface WorkerContext
   DruidNode selfNode();
 
   Bouncer processorBouncer();
+
+  default File tempDir(int stageNumber, String id)
+  {
+    return new File(StringUtils.format("%s/stage_%02d/%s", tempDir(), stageNumber, id));
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
index fa29cd2c01..6bfea2fc3d 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
@@ -25,41 +25,50 @@ import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.AsyncFunction;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.SettableFuture;
 import it.unimi.dsi.fastutil.bytes.ByteArrays;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.frame.FrameType;
 import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory;
 import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
 import org.apache.druid.frame.channel.ByteTracker;
+import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFileFrameChannel;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.ReadableNilFrameChannel;
 import org.apache.druid.frame.file.FrameFile;
 import org.apache.druid.frame.file.FrameFileWriter;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory;
 import org.apache.druid.frame.processor.Bouncer;
 import org.apache.druid.frame.processor.ComposingOutputChannelFactory;
-import org.apache.druid.frame.processor.DurableStorageOutputChannelFactory;
 import org.apache.druid.frame.processor.FileOutputChannelFactory;
-import org.apache.druid.frame.processor.FrameChannelMuxer;
+import org.apache.druid.frame.processor.FrameChannelHashPartitioner;
+import org.apache.druid.frame.processor.FrameChannelMixer;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.FrameProcessorExecutor;
 import org.apache.druid.frame.processor.OutputChannel;
 import org.apache.druid.frame.processor.OutputChannelFactory;
 import org.apache.druid.frame.processor.OutputChannels;
+import org.apache.druid.frame.processor.PartitionedOutputChannel;
 import org.apache.druid.frame.processor.SuperSorter;
 import org.apache.druid.frame.processor.SuperSorterProgressTracker;
 import org.apache.druid.frame.util.DurableStorageUtils;
+import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.indexer.TaskStatus;
 import org.apache.druid.java.util.common.FileUtils;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.RE;
 import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.UOE;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Sequences;
 import org.apache.druid.java.util.common.io.Closer;
@@ -98,15 +107,15 @@ import org.apache.druid.msq.input.table.SegmentsInputSliceReader;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.kernel.FrameProcessorFactory;
 import org.apache.druid.msq.kernel.ProcessorsAndChannels;
-import org.apache.druid.msq.kernel.QueryDefinition;
+import org.apache.druid.msq.kernel.ShuffleSpec;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.StagePartition;
 import org.apache.druid.msq.kernel.WorkOrder;
 import org.apache.druid.msq.kernel.worker.WorkerStageKernel;
 import org.apache.druid.msq.kernel.worker.WorkerStagePhase;
-import org.apache.druid.msq.querykit.DataSegmentProvider;
 import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory;
+import org.apache.druid.msq.shuffle.DurableStorageOutputChannelFactory;
 import org.apache.druid.msq.shuffle.WorkerInputChannelFactory;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
 import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
@@ -137,7 +146,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
@@ -357,17 +365,22 @@ public class WorkerImpl implements Worker
 
           // Start working on this stage immediately.
           kernel.startReading();
+
+          final RunWorkOrder runWorkOrder = new RunWorkOrder(
+              kernel,
+              inputChannelFactory,
+              stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()),
+              workerExec,
+              cancellationId,
+              context.threadCount(),
+              stageFrameContexts.get(stageDefinition.getId()),
+              msqWarningReportPublisher
+          );
+
+          runWorkOrder.start();
+
           final SettableFuture<ClusterByPartitions> partitionBoundariesFuture =
-              startWorkOrder(
-                  kernel,
-                  inputChannelFactory,
-                  stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()),
-                  workerExec,
-                  cancellationId,
-                  context.threadCount(),
-                  stageFrameContexts.get(stageDefinition.getId()),
-                  msqWarningReportPublisher
-              );
+              runWorkOrder.getStagePartitionBoundariesFuture();
 
           if (partitionBoundariesFuture != null) {
             if (partitionBoundariesFutureMap.put(stageDefinition.getId(), partitionBoundariesFuture) != null) {
@@ -408,7 +421,10 @@ public class WorkerImpl implements Worker
         }
 
         if (kernel.getPhase() == WorkerStagePhase.RESULTS_READY
-            && kernel.addPostedResultsComplete(Pair.of(stageDefinition.getId(), kernel.getWorkOrder().getWorkerNumber()))) {
+            && kernel.addPostedResultsComplete(Pair.of(
+            stageDefinition.getId(),
+            kernel.getWorkOrder().getWorkerNumber()
+        ))) {
           if (controllerAlive) {
             controllerClient.postResultsComplete(
                 stageDefinition.getId(),
@@ -711,8 +727,8 @@ public class WorkerImpl implements Worker
     final FileOutputChannelFactory fileOutputChannelFactory =
         new FileOutputChannelFactory(fileChannelDirectory, frameSize, intermediateSuperSorterLocalStorageTracker);
 
-    if (MultiStageQueryContext.isComposedIntermediateSuperSorterStorageEnabled(QueryContext.of(task.getContext())) &&
-        durableStageStorageEnabled) {
+    if (MultiStageQueryContext.isComposedIntermediateSuperSorterStorageEnabled(QueryContext.of(task.getContext()))
+        && durableStageStorageEnabled) {
       return new ComposingOutputChannelFactory(
           ImmutableList.of(
               fileOutputChannelFactory,
@@ -844,458 +860,6 @@ public class WorkerImpl implements Worker
     }
   }
 
-  @SuppressWarnings({"rawtypes", "unchecked"})
-  @Nullable
-  private SettableFuture<ClusterByPartitions> startWorkOrder(
-      final WorkerStageKernel kernel,
-      final InputChannelFactory inputChannelFactory,
-      final CounterTracker counters,
-      final FrameProcessorExecutor exec,
-      final String cancellationId,
-      final int parallelism,
-      final FrameContext frameContext,
-      final MSQWarningReportPublisher MSQWarningReportPublisher
-  ) throws IOException
-  {
-    final WorkOrder workOrder = kernel.getWorkOrder();
-    final int workerNumber = workOrder.getWorkerNumber();
-    final StageDefinition stageDef = workOrder.getStageDefinition();
-
-    final InputChannels inputChannels =
-        new InputChannelsImpl(
-            workOrder.getQueryDefinition(),
-            InputSlices.allReadablePartitions(workOrder.getInputs()),
-            inputChannelFactory,
-            () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()),
-            exec,
-            cancellationId
-        );
-
-    final InputSliceReader inputSliceReader = makeInputSliceReader(
-        workOrder.getQueryDefinition(),
-        inputChannels,
-        frameContext.tempDir(),
-        frameContext.dataSegmentProvider()
-    );
-
-    final OutputChannelFactory workerOutputChannelFactory;
-
-    if (stageDef.doesShuffle()) {
-      // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame
-      // size, since we may be writing to a SuperSorter, and we'll generate fewer temp files if we use larger frames.
-      // Note: it's not *guaranteed* that we're writing to a SuperSorter, but it's harmless to use large frames
-      // even if not.
-      workerOutputChannelFactory =
-          new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getLargeFrameSize());
-    } else {
-      // Writing stage output.
-      workerOutputChannelFactory = makeStageOutputChannelFactory(frameContext, stageDef.getStageNumber());
-    }
-
-    final ResultAndChannels<?> workerResultAndOutputChannels =
-        makeAndRunWorkers(
-            workerNumber,
-            workOrder.getStageDefinition().getProcessorFactory(),
-            workOrder.getExtraInfo(),
-            new CountingOutputChannelFactory(
-                workerOutputChannelFactory,
-                counters.channel(CounterNames.outputChannel())
-            ),
-            stageDef,
-            workOrder.getInputs(),
-            inputSliceReader,
-            frameContext,
-            exec,
-            cancellationId,
-            parallelism,
-            processorBouncer,
-            counters,
-            MSQWarningReportPublisher
-        );
-
-    final ListenableFuture<ClusterByPartitions> stagePartitionBoundariesFuture;
-    final ListenableFuture<OutputChannels> outputChannelsFuture;
-
-    if (stageDef.doesShuffle()) {
-      final ClusterBy clusterBy = workOrder.getStageDefinition().getShuffleSpec().get().getClusterBy();
-
-      final CountingOutputChannelFactory shuffleOutputChannelFactory =
-          new CountingOutputChannelFactory(
-              makeStageOutputChannelFactory(frameContext, stageDef.getStageNumber()),
-              counters.channel(CounterNames.shuffleChannel())
-          );
-
-      if (stageDef.doesSortDuringShuffle()) {
-        if (stageDef.mustGatherResultKeyStatistics()) {
-          stagePartitionBoundariesFuture = SettableFuture.create();
-        } else {
-          stagePartitionBoundariesFuture = Futures.immediateFuture(kernel.getResultPartitionBoundaries());
-        }
-
-        final File sorterTmpDir = new File(context.tempDir(), "super-sort-" + UUID.randomUUID());
-        FileUtils.mkdirp(sorterTmpDir);
-        if (!sorterTmpDir.isDirectory()) {
-          throw new IOException("Cannot create directory: " + sorterTmpDir);
-        }
-
-        outputChannelsFuture = superSortOutputChannels(
-            workOrder.getStageDefinition(),
-            clusterBy,
-            workerResultAndOutputChannels.getOutputChannels(),
-            stagePartitionBoundariesFuture,
-            shuffleOutputChannelFactory,
-            makeSuperSorterIntermediateOutputChannelFactory(frameContext, stageDef.getStageNumber(), sorterTmpDir),
-            exec,
-            cancellationId,
-            frameContext.memoryParameters(),
-            kernelManipulationQueue,
-            counters.sortProgress()
-        );
-      } else {
-        // No sorting, just combining all outputs into one big partition. Use a muxer to get everything into one file.
-        // Note: even if there is only one output channel, we'll run it through the muxer anyway, to ensure the data
-        // gets written to a file. (httpGetChannelData requires files.)
-        final OutputChannel outputChannel = shuffleOutputChannelFactory.openChannel(0);
-
-        final FrameChannelMuxer muxer =
-            new FrameChannelMuxer(
-                workerResultAndOutputChannels.getOutputChannels()
-                                             .getAllChannels()
-                                             .stream()
-                                             .map(OutputChannel::getReadableChannel)
-                                             .collect(Collectors.toList()),
-                outputChannel.getWritableChannel()
-            );
-
-        //noinspection unchecked, rawtypes
-        outputChannelsFuture = Futures.transform(
-            exec.runFully(muxer, cancellationId),
-            (Function) ignored -> OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly()))
-        );
-
-        stagePartitionBoundariesFuture = null;
-      }
-    } else {
-      stagePartitionBoundariesFuture = null;
-
-      // Retain read-only versions to reduce memory footprint.
-      outputChannelsFuture = Futures.immediateFuture(workerResultAndOutputChannels.getOutputChannels().readOnly());
-    }
-
-    // Output channels and future are all constructed. Sanity check, record them, and set up callbacks.
-    Futures.addCallback(
-        Futures.allAsList(
-            Arrays.asList(
-                workerResultAndOutputChannels.getResultFuture(),
-                Futures.transform(
-                    outputChannelsFuture,
-                    new Function<OutputChannels, OutputChannels>()
-                    {
-                      @Override
-                      public OutputChannels apply(final OutputChannels channels)
-                      {
-                        sanityCheckOutputChannels(channels);
-                        return channels;
-                      }
-                    }
-                )
-            )
-        ),
-        new FutureCallback<List<Object>>()
-        {
-          @Override
-          public void onSuccess(final List<Object> workerResultAndOutputChannelsResolved)
-          {
-            Object resultObject = workerResultAndOutputChannelsResolved.get(0);
-            final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1);
-
-            for (OutputChannel channel : outputChannels.getAllChannels()) {
-              stageOutputs.computeIfAbsent(stageDef.getId(), ignored1 -> new ConcurrentHashMap<>())
-                          .computeIfAbsent(channel.getPartitionNumber(), ignored2 -> channel.getReadableChannel());
-
-            }
-
-            if (durableStageStorageEnabled) {
-              // Once the outputs channels have been resolved and are ready for reading, the worker appends the filename
-              // with a special marker flag and adds it to the
-              DurableStorageOutputChannelFactory durableStorageOutputChannelFactory =
-                  DurableStorageOutputChannelFactory.createStandardImplementation(
-                      task.getControllerTaskId(),
-                      task().getWorkerNumber(),
-                      stageDef.getStageNumber(),
-                      task().getId(),
-                      frameContext.memoryParameters().getStandardFrameSize(),
-                      MSQTasks.makeStorageConnector(context.injector()),
-                      context.tempDir()
-                  );
-              try {
-                durableStorageOutputChannelFactory.createSuccessFile(task.getId());
-              }
-              catch (IOException e) {
-                throw new ISE(
-                    e,
-                    "Unable to create the success file [%s] at the location [%s]",
-                    DurableStorageUtils.SUCCESS_MARKER_FILENAME,
-                    DurableStorageUtils.getSuccessFilePath(
-                        task.getControllerTaskId(),
-                        stageDef.getStageNumber(),
-                        task().getWorkerNumber()
-                    )
-                );
-              }
-            }
-
-            kernelManipulationQueue.add(holder -> holder.getStageKernelMap()
-                                                        .get(stageDef.getId())
-                                                        .setResultsComplete(resultObject));
-          }
-
-          @Override
-          public void onFailure(final Throwable t)
-          {
-            kernelManipulationQueue.add(
-                kernelHolder ->
-                    kernelHolder.getStageKernelMap().get(stageDef.getId()).fail(t)
-            );
-          }
-        }
-    );
-
-    // Return settable result-key-statistics future, so callers can set it and unblock the supersorter if needed.
-    return stageDef.mustGatherResultKeyStatistics()
-           ? (SettableFuture<ClusterByPartitions>) stagePartitionBoundariesFuture
-           : null;
-  }
-
-  private static <FactoryType extends FrameProcessorFactory<I, WorkerClass, T, R>, I, WorkerClass extends FrameProcessor<T>, T, R> ResultAndChannels<R> makeAndRunWorkers(
-      final int workerNumber,
-      final FactoryType processorFactory,
-      final I processorFactoryExtraInfo,
-      final OutputChannelFactory outputChannelFactory,
-      final StageDefinition stageDefinition,
-      final List<InputSlice> inputSlices,
-      final InputSliceReader inputSliceReader,
-      final FrameContext frameContext,
-      final FrameProcessorExecutor exec,
-      final String cancellationId,
-      final int parallelism,
-      final Bouncer processorBouncer,
-      final CounterTracker counters,
-      final MSQWarningReportPublisher warningPublisher
-  ) throws IOException
-  {
-    final ProcessorsAndChannels<WorkerClass, T> processors =
-        processorFactory.makeProcessors(
-            stageDefinition,
-            workerNumber,
-            inputSlices,
-            inputSliceReader,
-            processorFactoryExtraInfo,
-            outputChannelFactory,
-            frameContext,
-            parallelism,
-            counters,
-            e -> warningPublisher.publishException(stageDefinition.getStageNumber(), e)
-        );
-
-    final Sequence<WorkerClass> processorSequence = processors.processors();
-
-    final int maxOutstandingProcessors;
-
-    if (processors.getOutputChannels().getAllChannels().isEmpty()) {
-      // No output channels: run up to "parallelism" processors at once.
-      maxOutstandingProcessors = Math.max(1, parallelism);
-    } else {
-      // If there are output channels, that acts as a ceiling on the number of processors that can run at once.
-      maxOutstandingProcessors =
-          Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size()));
-    }
-
-    final ListenableFuture<R> workResultFuture = exec.runAllFully(
-        processorSequence,
-        processorFactory.newAccumulatedResult(),
-        processorFactory::accumulateResult,
-        maxOutstandingProcessors,
-        processorBouncer,
-        cancellationId
-    );
-
-    return new ResultAndChannels<>(workResultFuture, processors.getOutputChannels());
-  }
-
-  private static InputSliceReader makeInputSliceReader(
-      final QueryDefinition queryDef,
-      final InputChannels inputChannels,
-      final File temporaryDirectory,
-      final DataSegmentProvider segmentProvider
-  )
-  {
-    return new MapInputSliceReader(
-        ImmutableMap.<Class<? extends InputSlice>, InputSliceReader>builder()
-                    .put(NilInputSlice.class, NilInputSliceReader.INSTANCE)
-                    .put(StageInputSlice.class, new StageInputSliceReader(queryDef.getQueryId(), inputChannels))
-                    .put(ExternalInputSlice.class, new ExternalInputSliceReader(temporaryDirectory))
-                    .put(SegmentsInputSlice.class, new SegmentsInputSliceReader(segmentProvider))
-                    .build()
-    );
-  }
-
-  private static ListenableFuture<OutputChannels> superSortOutputChannels(
-      final StageDefinition stageDefinition,
-      final ClusterBy clusterBy,
-      final OutputChannels processorOutputChannels,
-      final ListenableFuture<ClusterByPartitions> stagePartitionBoundariesFuture,
-      final OutputChannelFactory outputChannelFactory,
-      final OutputChannelFactory intermediateOutputChannelFactory,
-      final FrameProcessorExecutor exec,
-      final String cancellationId,
-      final WorkerMemoryParameters memoryParameters,
-      final BlockingQueue<Consumer<KernelHolder>> kernelManipulationQueue,
-      final SuperSorterProgressTracker superSorterProgressTracker
-  ) throws IOException
-  {
-    if (!stageDefinition.doesShuffle()) {
-      throw new ISE("Output channels do not need shuffling");
-    }
-
-    final List<ReadableFrameChannel> channelsToSuperSort;
-
-    if (processorOutputChannels.getAllChannels().isEmpty()) {
-      // No data coming out of this processor. Report empty statistics, if the kernel is expecting statistics.
-      if (stageDefinition.mustGatherResultKeyStatistics()) {
-        kernelManipulationQueue.add(
-            holder ->
-                holder.getStageKernelMap().get(stageDefinition.getId())
-                      .setResultKeyStatisticsSnapshot(ClusterByStatisticsSnapshot.empty())
-        );
-      }
-
-      // Process one empty channel so the SuperSorter has something to do.
-      final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
-      channel.writable().close();
-      channelsToSuperSort = Collections.singletonList(channel.readable());
-    } else if (stageDefinition.mustGatherResultKeyStatistics()) {
-      channelsToSuperSort = collectKeyStatistics(
-          stageDefinition,
-          clusterBy,
-          processorOutputChannels,
-          exec,
-          memoryParameters.getPartitionStatisticsMaxRetainedBytes(),
-          cancellationId,
-          kernelManipulationQueue
-      );
-    } else {
-      channelsToSuperSort = processorOutputChannels.getAllChannels()
-                                                   .stream()
-                                                   .map(OutputChannel::getReadableChannel)
-                                                   .collect(Collectors.toList());
-    }
-
-    final SuperSorter sorter = new SuperSorter(
-        channelsToSuperSort,
-        stageDefinition.getFrameReader(),
-        clusterBy,
-        stagePartitionBoundariesFuture,
-        exec,
-        outputChannelFactory,
-        intermediateOutputChannelFactory,
-        memoryParameters.getSuperSorterMaxActiveProcessors(),
-        memoryParameters.getSuperSorterMaxChannelsPerProcessor(),
-        -1,
-        cancellationId,
-        superSorterProgressTracker
-    );
-
-    return sorter.run();
-  }
-
-  private static List<ReadableFrameChannel> collectKeyStatistics(
-      final StageDefinition stageDefinition,
-      final ClusterBy clusterBy,
-      final OutputChannels processorOutputChannels,
-      final FrameProcessorExecutor exec,
-      final int partitionStatisticsMaxRetainedBytes,
-      final String cancellationId,
-      final BlockingQueue<Consumer<KernelHolder>> kernelManipulationQueue
-  )
-  {
-    final List<ReadableFrameChannel> retVal = new ArrayList<>();
-    final List<KeyStatisticsCollectionProcessor> processors = new ArrayList<>();
-
-    for (final OutputChannel outputChannel : processorOutputChannels.getAllChannels()) {
-      final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
-      retVal.add(channel.readable());
-
-      processors.add(
-          new KeyStatisticsCollectionProcessor(
-              outputChannel.getReadableChannel(),
-              channel.writable(),
-              stageDefinition.getFrameReader(),
-              clusterBy,
-              stageDefinition.createResultKeyStatisticsCollector(partitionStatisticsMaxRetainedBytes)
-          )
-      );
-    }
-
-    final ListenableFuture<ClusterByStatisticsCollector> clusterByStatisticsCollectorFuture =
-        exec.runAllFully(
-            Sequences.simple(processors),
-            stageDefinition.createResultKeyStatisticsCollector(partitionStatisticsMaxRetainedBytes),
-            ClusterByStatisticsCollector::addAll,
-            // Run all processors simultaneously. They are lightweight and this keeps things moving.
-            processors.size(),
-            Bouncer.unlimited(),
-            cancellationId
-        );
-
-    Futures.addCallback(
-        clusterByStatisticsCollectorFuture,
-        new FutureCallback<ClusterByStatisticsCollector>()
-        {
-          @Override
-          public void onSuccess(final ClusterByStatisticsCollector result)
-          {
-            kernelManipulationQueue.add(
-                holder ->
-                    holder.getStageKernelMap().get(stageDefinition.getId())
-                          .setResultKeyStatisticsSnapshot(result.snapshot())
-            );
-          }
-
-          @Override
-          public void onFailure(Throwable t)
-          {
-            kernelManipulationQueue.add(
-                holder -> {
-                  log.noStackTrace()
-                     .warn(t, "Failed to gather clusterBy statistics for stage [%s]", stageDefinition.getId());
-                  holder.getStageKernelMap().get(stageDefinition.getId()).fail(t);
-                }
-            );
-          }
-        }
-    );
-
-    return retVal;
-  }
-
-  private static void sanityCheckOutputChannels(final OutputChannels outputChannels)
-  {
-    // Verify there is exactly one channel per partition.
-    for (int partitionNumber : outputChannels.getPartitionNumbers()) {
-      final List<OutputChannel> outputChannelsForPartition =
-          outputChannels.getChannelsForPartition(partitionNumber);
-
-      Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber);
-      Preconditions.checkState(
-          outputChannelsForPartition.size() == 1,
-          "Expected one channel for partition [%s], but got [%s]",
-          partitionNumber,
-          outputChannelsForPartition.size()
-      );
-    }
-  }
-
   /**
    * Log (at DEBUG level) a string explaining the status of all work assigned to this worker.
    */
@@ -1384,6 +948,817 @@ public class WorkerImpl implements Worker
     }
   }
 
+  /**
+   * Main worker logic for executing a {@link WorkOrder}.
+   */
+  private class RunWorkOrder
+  {
+    private final WorkerStageKernel kernel;
+    private final InputChannelFactory inputChannelFactory;
+    private final CounterTracker counterTracker;
+    private final FrameProcessorExecutor exec;
+    private final String cancellationId;
+    private final int parallelism;
+    private final FrameContext frameContext;
+    private final MSQWarningReportPublisher warningPublisher;
+
+    private InputSliceReader inputSliceReader;
+    private OutputChannelFactory workOutputChannelFactory;
+    private OutputChannelFactory shuffleOutputChannelFactory;
+    private ResultAndChannels<?> workResultAndOutputChannels;
+    private SettableFuture<ClusterByPartitions> stagePartitionBoundariesFuture;
+    private ListenableFuture<OutputChannels> shuffleOutputChannelsFuture;
+
+    public RunWorkOrder(
+        final WorkerStageKernel kernel,
+        final InputChannelFactory inputChannelFactory,
+        final CounterTracker counterTracker,
+        final FrameProcessorExecutor exec,
+        final String cancellationId,
+        final int parallelism,
+        final FrameContext frameContext,
+        final MSQWarningReportPublisher warningPublisher
+    )
+    {
+      this.kernel = kernel;
+      this.inputChannelFactory = inputChannelFactory;
+      this.counterTracker = counterTracker;
+      this.exec = exec;
+      this.cancellationId = cancellationId;
+      this.parallelism = parallelism;
+      this.frameContext = frameContext;
+      this.warningPublisher = warningPublisher;
+    }
+
+    private void start() throws IOException
+    {
+      final WorkOrder workOrder = kernel.getWorkOrder();
+      final StageDefinition stageDef = workOrder.getStageDefinition();
+
+      makeInputSliceReader();
+      makeWorkOutputChannelFactory();
+      makeShuffleOutputChannelFactory();
+      makeAndRunWorkProcessors();
+
+      if (stageDef.doesShuffle()) {
+        makeAndRunShuffleProcessors();
+      } else {
+        // No shuffling: work output _is_ shuffle output. Retain read-only versions to reduce memory footprint.
+        shuffleOutputChannelsFuture =
+            Futures.immediateFuture(workResultAndOutputChannels.getOutputChannels().readOnly());
+      }
+
+      setUpCompletionCallbacks();
+    }
+
+    /**
+     * Settable {@link ClusterByPartitions} future for global sort. Necessary because we don't know ahead of time
+     * what the boundaries will be. The controller decides based on statistics from all workers. Once the controller
+     * decides, its decision is written to this future, which allows sorting on workers to proceed.
+     */
+    @Nullable
+    public SettableFuture<ClusterByPartitions> getStagePartitionBoundariesFuture()
+    {
+      return stagePartitionBoundariesFuture;
+    }
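
    A minimal standalone sketch of the handoff described in the Javadoc above, using Guava's
    SettableFuture directly. This is illustrative only and not part of the commit; the class name
    and boundary values are made up.

    import com.google.common.util.concurrent.SettableFuture;

    import java.util.Arrays;
    import java.util.List;

    public class PartitionBoundariesHandoffSketch
    {
      public static void main(String[] args) throws Exception
      {
        // Worker side: created before the boundaries are known.
        final SettableFuture<List<String>> boundariesFuture = SettableFuture.create();

        // Controller side (a separate process in reality): decides boundaries from gathered
        // key statistics, then unblocks the worker by completing the future.
        new Thread(() -> boundariesFuture.set(Arrays.asList("[start, m)", "[m, end)"))).start();

        // Worker side: the global-sort pipeline blocks here until the controller has decided.
        System.out.println("Partition boundaries: " + boundariesFuture.get());
      }
    }
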
+
+    private void makeInputSliceReader()
+    {
+      if (inputSliceReader != null) {
+        throw new ISE("inputSliceReader already created");
+      }
+
+      final WorkOrder workOrder = kernel.getWorkOrder();
+      final String queryId = workOrder.getQueryDefinition().getQueryId();
+
+      final InputChannels inputChannels =
+          new InputChannelsImpl(
+              workOrder.getQueryDefinition(),
+              InputSlices.allReadablePartitions(workOrder.getInputs()),
+              inputChannelFactory,
+              () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()),
+              exec,
+              cancellationId
+          );
+
+      inputSliceReader = new MapInputSliceReader(
+          ImmutableMap.<Class<? extends InputSlice>, InputSliceReader>builder()
+                      .put(NilInputSlice.class, NilInputSliceReader.INSTANCE)
+                      .put(StageInputSlice.class, new StageInputSliceReader(queryId, inputChannels))
+                      .put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir()))
+                      .put(SegmentsInputSlice.class, new SegmentsInputSliceReader(frameContext.dataSegmentProvider()))
+                      .build()
+      );
+    }
+
+    private void makeWorkOutputChannelFactory()
+    {
+      if (workOutputChannelFactory != null) {
+        throw new ISE("processorOutputChannelFactory already created");
+      }
+
+      final OutputChannelFactory baseOutputChannelFactory;
+
+      if (kernel.getStageDefinition().doesShuffle()) {
+        // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame
+        // size if we're writing to a SuperSorter, since we'll generate fewer temp files if we use larger frames.
+        // Otherwise, use the standard frame size.
+        final int frameSize;
+
+        if (kernel.getStageDefinition().getShuffleSpec().kind().isSort()) {
+          frameSize = frameContext.memoryParameters().getLargeFrameSize();
+        } else {
+          frameSize = frameContext.memoryParameters().getStandardFrameSize();
+        }
+
+        baseOutputChannelFactory = new BlockingQueueOutputChannelFactory(frameSize);
+      } else {
+        // Writing stage output.
+        baseOutputChannelFactory =
+            makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber());
+      }
+
+      workOutputChannelFactory = new CountingOutputChannelFactory(
+          baseOutputChannelFactory,
+          counterTracker.channel(CounterNames.outputChannel())
+      );
+    }
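
    A rough standalone illustration of why larger frames help when feeding a SuperSorter: fewer,
    bigger frames mean fewer spilled temporary files for the same data. Not part of the commit;
    the sizes below are made-up example values, not the real constants from WorkerMemoryParameters.

    public class FrameSizeSketch
    {
      public static void main(String[] args)
      {
        final long bytesProducedByStage = 1_000_000_000L; // hypothetical stage output
        final long standardFrameSize = 1_000_000L;        // illustrative "standard" size
        final long largeFrameSize = 8_000_000L;           // illustrative "large" size

        // More bytes per frame means fewer frames for the same data, so the sorter's
        // early passes touch fewer temporary files.
        System.out.println("frames at standard size: " + ceilDiv(bytesProducedByStage, standardFrameSize));
        System.out.println("frames at large size:    " + ceilDiv(bytesProducedByStage, largeFrameSize));
      }

      private static long ceilDiv(long n, long d)
      {
        return (n + d - 1) / d;
      }
    }
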
+
+    private void makeShuffleOutputChannelFactory()
+    {
+      shuffleOutputChannelFactory =
+          new CountingOutputChannelFactory(
+              makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber()),
+              counterTracker.channel(CounterNames.shuffleChannel())
+          );
+    }
+
+    private <FactoryType extends FrameProcessorFactory<I, WorkerClass, T, R>, I, WorkerClass extends FrameProcessor<T>, T, R> void makeAndRunWorkProcessors()
+        throws IOException
+    {
+      if (workResultAndOutputChannels != null) {
+        throw new ISE("workResultAndOutputChannels already set");
+      }
+
+      @SuppressWarnings("unchecked")
+      final FactoryType processorFactory = (FactoryType) kernel.getStageDefinition().getProcessorFactory();
+
+      @SuppressWarnings("unchecked")
+      final ProcessorsAndChannels<WorkerClass, T> processors =
+          processorFactory.makeProcessors(
+              kernel.getStageDefinition(),
+              kernel.getWorkOrder().getWorkerNumber(),
+              kernel.getWorkOrder().getInputs(),
+              inputSliceReader,
+              (I) kernel.getWorkOrder().getExtraInfo(),
+              workOutputChannelFactory,
+              frameContext,
+              parallelism,
+              counterTracker,
+              e -> warningPublisher.publishException(kernel.getStageDefinition().getStageNumber(), e)
+          );
+
+      final Sequence<WorkerClass> processorSequence = processors.processors();
+
+      final int maxOutstandingProcessors;
+
+      if (processors.getOutputChannels().getAllChannels().isEmpty()) {
+        // No output channels: run up to "parallelism" processors at once.
+        maxOutstandingProcessors = Math.max(1, parallelism);
+      } else {
+        // If there are output channels, that acts as a ceiling on the number of processors that can run at once.
+        maxOutstandingProcessors =
+            Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size()));
+      }
+
+      final ListenableFuture<R> workResultFuture = exec.runAllFully(
+          processorSequence,
+          processorFactory.newAccumulatedResult(),
+          processorFactory::accumulateResult,
+          maxOutstandingProcessors,
+          processorBouncer,
+          cancellationId
+      );
+
+      workResultAndOutputChannels = new ResultAndChannels<>(workResultFuture, processors.getOutputChannels());
+    }
+
+    private void makeAndRunShuffleProcessors()
+    {
+      if (shuffleOutputChannelsFuture != null) {
+        throw new ISE("shuffleOutputChannelsFuture already set");
+      }
+
+      final ShuffleSpec shuffleSpec = kernel.getWorkOrder().getStageDefinition().getShuffleSpec();
+
+      final ShufflePipelineBuilder shufflePipeline = new ShufflePipelineBuilder(
+          kernel,
+          counterTracker,
+          exec,
+          cancellationId,
+          frameContext
+      );
+
+      shufflePipeline.initialize(workResultAndOutputChannels);
+
+      switch (shuffleSpec.kind()) {
+        case MIX:
+          shufflePipeline.mix(shuffleOutputChannelFactory);
+          break;
+
+        case HASH:
+          shufflePipeline.hashPartition(shuffleOutputChannelFactory);
+          break;
+
+        case HASH_LOCAL_SORT:
+          final OutputChannelFactory hashOutputChannelFactory;
+
+          if (shuffleSpec.partitionCount() == 1) {
+            // Single partition; no need to write temporary files.
+            hashOutputChannelFactory =
+                new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getStandardFrameSize());
+          } else {
+            // Multi-partition; write temporary files and then sort each one file-by-file.
+            hashOutputChannelFactory =
+                new FileOutputChannelFactory(
+                    context.tempDir(kernel.getStageDefinition().getStageNumber(), "hash-parts"),
+                    frameContext.memoryParameters().getStandardFrameSize(),
+                    null
+                );
+          }
+
+          shufflePipeline.hashPartition(hashOutputChannelFactory);
+          shufflePipeline.localSort(shuffleOutputChannelFactory);
+          break;
+
+        case GLOBAL_SORT:
+          shufflePipeline.gatherResultKeyStatisticsIfNeeded();
+          shufflePipeline.globalSort(shuffleOutputChannelFactory, makeGlobalSortPartitionBoundariesFuture());
+          break;
+
+        default:
+          throw new UOE("Cannot handle shuffle kind [%s]", shuffleSpec.kind());
+      }
+
+      shuffleOutputChannelsFuture = shufflePipeline.build();
+    }
+
+    private ListenableFuture<ClusterByPartitions> makeGlobalSortPartitionBoundariesFuture()
+    {
+      if (kernel.getStageDefinition().mustGatherResultKeyStatistics()) {
+        if (stagePartitionBoundariesFuture != null) {
+          throw new ISE("Cannot call 'makeGlobalSortPartitionBoundariesFuture' twice");
+        }
+
+        return (stagePartitionBoundariesFuture = SettableFuture.create());
+      } else {
+        return Futures.immediateFuture(kernel.getResultPartitionBoundaries());
+      }
+    }
+
+    private void setUpCompletionCallbacks()
+    {
+      final StageDefinition stageDef = kernel.getStageDefinition();
+
+      Futures.addCallback(
+          Futures.allAsList(
+              Arrays.asList(
+                  workResultAndOutputChannels.getResultFuture(),
+                  shuffleOutputChannelsFuture
+              )
+          ),
+          new FutureCallback<List<Object>>()
+          {
+            @Override
+            public void onSuccess(final List<Object> workerResultAndOutputChannelsResolved)
+            {
+              final Object resultObject = workerResultAndOutputChannelsResolved.get(0);
+              final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1);
+
+              for (OutputChannel channel : outputChannels.getAllChannels()) {
+                try {
+                  stageOutputs.computeIfAbsent(stageDef.getId(), ignored1 -> new ConcurrentHashMap<>())
+                              .computeIfAbsent(channel.getPartitionNumber(), ignored2 -> channel.getReadableChannel());
+                }
+                catch (Exception e) {
+                  kernelManipulationQueue.add(holder -> {
+                    throw new RE(e, "Worker completion callback error for stage [%s]", stageDef.getId());
+                  });
+
+                  // Don't make the "setResultsComplete" call below.
+                  return;
+                }
+              }
+
+              // Once the output channels have been resolved and are ready for reading, write success file, if
+              // using durable storage.
+              writeDurableStorageSuccessFileIfNeeded(stageDef.getStageNumber());
+
+              kernelManipulationQueue.add(holder -> holder.getStageKernelMap()
+                                                          .get(stageDef.getId())
+                                                          .setResultsComplete(resultObject));
+            }
+
+            @Override
+            public void onFailure(final Throwable t)
+            {
+              kernelManipulationQueue.add(
+                  kernelHolder ->
+                      kernelHolder.getStageKernelMap().get(stageDef.getId()).fail(t)
+              );
+            }
+          }
+      );
+    }
+
+    /**
+     * Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled.
+     */
+    private void writeDurableStorageSuccessFileIfNeeded(final int stageNumber)
+    {
+      if (!durableStageStorageEnabled) {
+        return;
+      }
+
+      DurableStorageOutputChannelFactory durableStorageOutputChannelFactory =
+          DurableStorageOutputChannelFactory.createStandardImplementation(
+              task.getControllerTaskId(),
+              task().getWorkerNumber(),
+              stageNumber,
+              task().getId(),
+              frameContext.memoryParameters().getStandardFrameSize(),
+              MSQTasks.makeStorageConnector(context.injector()),
+              context.tempDir()
+          );
+      try {
+        durableStorageOutputChannelFactory.createSuccessFile(task.getId());
+      }
+      catch (IOException e) {
+        throw new ISE(
+            e,
+            "Unable to create the success file [%s] at the location [%s]",
+            DurableStorageUtils.SUCCESS_MARKER_FILENAME,
+            DurableStorageUtils.getSuccessFilePath(
+                task.getControllerTaskId(),
+                stageNumber,
+                task().getWorkerNumber()
+            )
+        );
+      }
+    }
+  }
+
+  /**
+   * Helper for {@link RunWorkOrder#makeAndRunShuffleProcessors()}. Builds a {@link FrameProcessor} pipeline to
+   * handle the shuffle.
+   */
+  private class ShufflePipelineBuilder
+  {
+    private final WorkerStageKernel kernel;
+    private final CounterTracker counterTracker;
+    private final FrameProcessorExecutor exec;
+    private final String cancellationId;
+    private final FrameContext frameContext;
+
+    // Current state of the pipeline. It's a future to allow pipeline construction to be deferred if necessary.
+    private ListenableFuture<ResultAndChannels<?>> pipelineFuture;
+
+    public ShufflePipelineBuilder(
+        final WorkerStageKernel kernel,
+        final CounterTracker counterTracker,
+        final FrameProcessorExecutor exec,
+        final String cancellationId,
+        final FrameContext frameContext
+    )
+    {
+      this.kernel = kernel;
+      this.counterTracker = counterTracker;
+      this.exec = exec;
+      this.cancellationId = cancellationId;
+      this.frameContext = frameContext;
+    }
+
+    /**
+     * Start the pipeline with the outputs of the main processor.
+     */
+    public void initialize(final ResultAndChannels<?> resultAndChannels)
+    {
+      if (pipelineFuture != null) {
+        throw new ISE("already initialized");
+      }
+
+      pipelineFuture = Futures.immediateFuture(resultAndChannels);
+    }
+
+    /**
+     * Add {@link FrameChannelMixer}, which mixes all current outputs into a single channel from the provided factory.
+     */
+    public void mix(final OutputChannelFactory outputChannelFactory)
+    {
+      // No sorting or statistics gathering, just combining all outputs into one big partition. Use a mixer to get
+      // everything into one file. Note: even if there is only one output channel, we'll run it through the mixer
+      // anyway, to ensure the data gets written to a file. (httpGetChannelData requires files.)
+
+      push(
+          resultAndChannels -> {
+            final OutputChannel outputChannel = outputChannelFactory.openChannel(0);
+
+            final FrameChannelMixer mixer =
+                new FrameChannelMixer(
+                    resultAndChannels.getOutputChannels().getAllReadableChannels(),
+                    outputChannel.getWritableChannel()
+                );
+
+            return new ResultAndChannels<>(
+                exec.runFully(mixer, cancellationId),
+                OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly()))
+            );
+          }
+      );
+    }
+
+    /**
+     * Add {@link KeyStatisticsCollectionProcessor} if {@link StageDefinition#mustGatherResultKeyStatistics()}.
+     */
+    public void gatherResultKeyStatisticsIfNeeded()
+    {
+      push(
+          resultAndChannels -> {
+            final StageDefinition stageDefinition = kernel.getStageDefinition();
+            final OutputChannels channels = resultAndChannels.getOutputChannels();
+
+            if (channels.getAllChannels().isEmpty()) {
+              // No data coming out of this processor. Report empty statistics, if the kernel is expecting statistics.
+              if (stageDefinition.mustGatherResultKeyStatistics()) {
+                kernelManipulationQueue.add(
+                    holder ->
+                        holder.getStageKernelMap().get(stageDefinition.getId())
+                              .setResultKeyStatisticsSnapshot(ClusterByStatisticsSnapshot.empty())
+                );
+              }
+
+              // Generate one empty channel so the SuperSorter has something to do.
+              final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
+              channel.writable().close();
+
+              final OutputChannel outputChannel = OutputChannel.readOnly(
+                  channel.readable(),
+                  FrameWithPartition.NO_PARTITION
+              );
+
+              return new ResultAndChannels<>(
+                  Futures.immediateFuture(null),
+                  OutputChannels.wrap(Collections.singletonList(outputChannel))
+              );
+            } else if (stageDefinition.mustGatherResultKeyStatistics()) {
+              return gatherResultKeyStatistics(channels);
+            } else {
+              return resultAndChannels;
+            }
+          }
+      );
+    }
+
+    /**
+     * Add a {@link SuperSorter} using {@link StageDefinition#getSortKey()} and partition boundaries
+     * from {@code partitionBoundariesFuture}.
+     */
+    public void globalSort(
+        final OutputChannelFactory outputChannelFactory,
+        final ListenableFuture<ClusterByPartitions> partitionBoundariesFuture
+    )
+    {
+      pushAsync(
+          resultAndChannels -> {
+            final StageDefinition stageDefinition = kernel.getStageDefinition();
+
+            final File sorterTmpDir = context.tempDir(stageDefinition.getStageNumber(), "super-sort");
+            FileUtils.mkdirp(sorterTmpDir);
+            if (!sorterTmpDir.isDirectory()) {
+              throw new IOException("Cannot create directory: " + sorterTmpDir);
+            }
+
+            final WorkerMemoryParameters memoryParameters = frameContext.memoryParameters();
+            final SuperSorter sorter = new SuperSorter(
+                resultAndChannels.getOutputChannels().getAllReadableChannels(),
+                stageDefinition.getFrameReader(),
+                stageDefinition.getSortKey(),
+                partitionBoundariesFuture,
+                exec,
+                outputChannelFactory,
+                makeSuperSorterIntermediateOutputChannelFactory(
+                    frameContext,
+                    stageDefinition.getStageNumber(),
+                    sorterTmpDir
+                ),
+                memoryParameters.getSuperSorterMaxActiveProcessors(),
+                memoryParameters.getSuperSorterMaxChannelsPerProcessor(),
+                -1,
+                cancellationId,
+                counterTracker.sortProgress()
+            );
+
+            return FutureUtils.transform(
+                sorter.run(),
+                sortedChannels -> new ResultAndChannels<>(Futures.immediateFuture(null), sortedChannels)
+            );
+          }
+      );
+    }
+
+    /**
+     * Add a {@link FrameChannelHashPartitioner} using {@link StageDefinition#getSortKey()}.
+     */
+    public void hashPartition(final OutputChannelFactory outputChannelFactory)
+    {
+      pushAsync(
+          resultAndChannels -> {
+            final ShuffleSpec shuffleSpec = kernel.getStageDefinition().getShuffleSpec();
+            final int partitions = shuffleSpec.partitionCount();
+
+            final List<OutputChannel> outputChannels = new ArrayList<>();
+
+            for (int i = 0; i < partitions; i++) {
+              outputChannels.add(outputChannelFactory.openChannel(i));
+            }
+
+            final FrameChannelHashPartitioner partitioner = new FrameChannelHashPartitioner(
+                resultAndChannels.getOutputChannels().getAllReadableChannels(),
+                outputChannels.stream().map(OutputChannel::getWritableChannel).collect(Collectors.toList()),
+                kernel.getStageDefinition().getFrameReader(),
+                kernel.getStageDefinition().getClusterBy().getColumns().size(),
+                FrameWriters.makeFrameWriterFactory(
+                    FrameType.ROW_BASED,
+                    new ArenaMemoryAllocatorFactory(frameContext.memoryParameters().getStandardFrameSize()),
+                    kernel.getStageDefinition().getSignature(),
+                    kernel.getStageDefinition().getSortKey()
+                )
+            );
+
+            final ListenableFuture<Long> partitionerFuture = exec.runFully(partitioner, cancellationId);
+
+            final ResultAndChannels<Long> retVal =
+                new ResultAndChannels<>(partitionerFuture, OutputChannels.wrap(outputChannels));
+
+            if (retVal.getOutputChannels().areReadableChannelsReady()) {
+              return Futures.immediateFuture(retVal);
+            } else {
+              return FutureUtils.transform(partitionerFuture, ignored -> retVal);
+            }
+          }
+      );
+    }
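
    A standalone sketch of the core idea behind hash partitioning, stripped of the frame
    machinery: hash each row's key columns and route the row to one of partitionCount outputs.
    Illustrative only, not part of the commit.

    import java.util.ArrayList;
    import java.util.List;

    public class HashPartitionSketch
    {
      public static void main(String[] args)
      {
        final int partitionCount = 4;
        final List<List<String>> partitions = new ArrayList<>();
        for (int i = 0; i < partitionCount; i++) {
          partitions.add(new ArrayList<>());
        }

        for (final String key : new String[]{"alfa", "bravo", "charlie", "delta", "echo"}) {
          // Math.floorMod keeps the partition index non-negative even for negative hash codes.
          final int partition = Math.floorMod(key.hashCode(), partitionCount);
          partitions.get(partition).add(key);
        }

        for (int i = 0; i < partitionCount; i++) {
          System.out.println("partition " + i + ": " + partitions.get(i));
        }
      }
    }
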
+
+    /**
+     * Add a sequence of {@link SuperSorter}, operating on each current output channel in order, one at a time.
+     */
+    public void localSort(final OutputChannelFactory outputChannelFactory)
+    {
+      pushAsync(
+          resultAndChannels -> {
+            final StageDefinition stageDefinition = kernel.getStageDefinition();
+            final OutputChannels channels = resultAndChannels.getOutputChannels();
+            final List<ListenableFuture<OutputChannel>> sortedChannelFutures = new ArrayList<>();
+
+            ListenableFuture<OutputChannel> nextFuture = Futures.immediateFuture(null);
+
+            for (final OutputChannel channel : channels.getAllChannels()) {
+              final File sorterTmpDir = context.tempDir(
+                  stageDefinition.getStageNumber(),
+                  StringUtils.format("hash-parts-super-sort-%06d", channel.getPartitionNumber())
+              );
+
+              FileUtils.mkdirp(sorterTmpDir);
+
+              // SuperSorter will try to write to output partition zero; we remap it to the correct partition number.
+              final OutputChannelFactory partitionOverrideOutputChannelFactory = new OutputChannelFactory()
+              {
+                @Override
+                public OutputChannel openChannel(int expectedZero) throws IOException
+                {
+                  if (expectedZero != 0) {
+                    throw new ISE("Unexpected part [%s]", expectedZero);
+                  }
+
+                  return outputChannelFactory.openChannel(channel.getPartitionNumber());
+                }
+
+                @Override
+                public PartitionedOutputChannel openPartitionedChannel(String name, boolean deleteAfterRead)
+                {
+                  throw new UnsupportedOperationException();
+                }
+
+                @Override
+                public OutputChannel openNilChannel(int expectedZero)
+                {
+                  if (expectedZero != 0) {
+                    throw new ISE("Unexpected part [%s]", expectedZero);
+                  }
+
+                  return outputChannelFactory.openNilChannel(channel.getPartitionNumber());
+                }
+              };
+
+              // Chain futures so we only sort one partition at a time.
+              nextFuture = Futures.transform(
+                  nextFuture,
+                  (AsyncFunction<OutputChannel, OutputChannel>) ignored -> {
+                    final SuperSorter sorter = new SuperSorter(
+                        Collections.singletonList(channel.getReadableChannel()),
+                        stageDefinition.getFrameReader(),
+                        stageDefinition.getSortKey(),
+                        Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()),
+                        exec,
+                        partitionOverrideOutputChannelFactory,
+                        makeSuperSorterIntermediateOutputChannelFactory(
+                            frameContext,
+                            stageDefinition.getStageNumber(),
+                            sorterTmpDir
+                        ),
+                        1,
+                        2,
+                        -1,
+                        cancellationId,
+
+                        // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters.
+                        // There's a single SuperSorterProgressTrackerCounter per worker, but workers that do local
+                        // sorting have a SuperSorter per partition.
+                        new SuperSorterProgressTracker()
+                    );
+
+                    return FutureUtils.transform(sorter.run(), r -> Iterables.getOnlyElement(r.getAllChannels()));
+                  }
+              );
+
+              sortedChannelFutures.add(nextFuture);
+            }
+
+            return FutureUtils.transform(
+                Futures.allAsList(sortedChannelFutures),
+                sortedChannels -> new ResultAndChannels<>(
+                    Futures.immediateFuture(null),
+                    OutputChannels.wrap(sortedChannels)
+                )
+            );
+          }
+      );
+    }
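
    A standalone sketch of the chaining pattern used above, which keeps only one sort in flight
    at a time. Not part of the commit; it uses CompletableFuture purely for brevity, and the
    names are illustrative.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.CompletableFuture;

    public class SequentialSortSketch
    {
      public static void main(String[] args)
      {
        final List<CompletableFuture<String>> results = new ArrayList<>();
        CompletableFuture<String> next = CompletableFuture.completedFuture(null);

        for (int partition = 0; partition < 3; partition++) {
          final int p = partition;
          // Each stage starts only after the previous one has finished.
          next = next.thenApply(ignored -> {
            System.out.println("sorting partition " + p);
            return "sorted-" + p;
          });
          results.add(next);
        }

        // Wait for the whole chain, then inspect the individual results.
        CompletableFuture.allOf(results.toArray(new CompletableFuture[0])).join();
        results.forEach(f -> System.out.println(f.join()));
      }
    }
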
+
+    /**
+     * Return the (future) output channels for this pipeline.
+     */
+    public ListenableFuture<OutputChannels> build()
+    {
+      if (pipelineFuture == null) {
+        throw new ISE("Not initialized");
+      }
+
+      return Futures.transform(
+          pipelineFuture,
+          (AsyncFunction<ResultAndChannels<?>, OutputChannels>) resultAndChannels ->
+              Futures.transform(
+                  resultAndChannels.getResultFuture(),
+                  (Function<Object, OutputChannels>) input -> {
+                    sanityCheckOutputChannels(resultAndChannels.getOutputChannels());
+                    return resultAndChannels.getOutputChannels();
+                  }
+              )
+      );
+    }
+
+    /**
+     * Adds {@link KeyStatisticsCollectionProcessor}. Called by {@link #gatherResultKeyStatisticsIfNeeded()}.
+     */
+    private ResultAndChannels<?> gatherResultKeyStatistics(final OutputChannels channels)
+    {
+      final StageDefinition stageDefinition = kernel.getStageDefinition();
+      final List<OutputChannel> retVal = new ArrayList<>();
+      final List<KeyStatisticsCollectionProcessor> processors = new ArrayList<>();
+
+      for (final OutputChannel outputChannel : channels.getAllChannels()) {
+        final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
+        retVal.add(OutputChannel.readOnly(channel.readable(), outputChannel.getPartitionNumber()));
+
+        processors.add(
+            new KeyStatisticsCollectionProcessor(
+                outputChannel.getReadableChannel(),
+                channel.writable(),
+                stageDefinition.getFrameReader(),
+                stageDefinition.getClusterBy(),
+                stageDefinition.createResultKeyStatisticsCollector(
+                    frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes()
+                )
+            )
+        );
+      }
+
+      final ListenableFuture<ClusterByStatisticsCollector> clusterByStatisticsCollectorFuture =
+          exec.runAllFully(
+              Sequences.simple(processors),
+              stageDefinition.createResultKeyStatisticsCollector(
+                  frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes()
+              ),
+              ClusterByStatisticsCollector::addAll,
+              // Run all processors simultaneously. They are lightweight and this keeps things moving.
+              processors.size(),
+              Bouncer.unlimited(),
+              cancellationId
+          );
+
+      Futures.addCallback(
+          clusterByStatisticsCollectorFuture,
+          new FutureCallback<ClusterByStatisticsCollector>()
+          {
+            @Override
+            public void onSuccess(final ClusterByStatisticsCollector result)
+            {
+              kernelManipulationQueue.add(
+                  holder ->
+                      holder.getStageKernelMap().get(stageDefinition.getId())
+                            .setResultKeyStatisticsSnapshot(result.snapshot())
+              );
+            }
+
+            @Override
+            public void onFailure(Throwable t)
+            {
+              kernelManipulationQueue.add(
+                  holder -> {
+                    log.noStackTrace()
+                       .warn(t, "Failed to gather clusterBy statistics for stage [%s]", stageDefinition.getId());
+                    holder.getStageKernelMap().get(stageDefinition.getId()).fail(t);
+                  }
+              );
+            }
+          }
+      );
+
+      return new ResultAndChannels<>(
+          clusterByStatisticsCollectorFuture,
+          OutputChannels.wrap(retVal)
+      );
+    }
+
+    /**
+     * Update the {@link #pipelineFuture}.
+     */
+    private void push(final ExceptionalFunction<ResultAndChannels<?>, ResultAndChannels<?>> fn)
+    {
+      pushAsync(
+          channels ->
+              Futures.immediateFuture(fn.apply(channels))
+      );
+    }
+
+    /**
+     * Update the {@link #pipelineFuture} asynchronously.
+     */
+    private void pushAsync(final ExceptionalFunction<ResultAndChannels<?>, ListenableFuture<ResultAndChannels<?>>> fn)
+    {
+      if (pipelineFuture == null) {
+        throw new ISE("Not initialized");
+      }
+
+      pipelineFuture = FutureUtils.transform(
+          Futures.transform(
+              pipelineFuture,
+              new AsyncFunction<ResultAndChannels<?>, ResultAndChannels<?>>()
+              {
+                @Override
+                public ListenableFuture<ResultAndChannels<?>> apply(ResultAndChannels<?> t) throws Exception
+                {
+                  return fn.apply(t);
+                }
+              }
+          ),
+          resultAndChannels -> new ResultAndChannels<>(
+              resultAndChannels.getResultFuture(),
+              resultAndChannels.getOutputChannels().readOnly()
+          )
+      );
+    }
+
+    /**
+     * Verifies there is exactly one channel per partition.
+     */
+    private void sanityCheckOutputChannels(final OutputChannels outputChannels)
+    {
+      for (int partitionNumber : outputChannels.getPartitionNumbers()) {
+        final List<OutputChannel> outputChannelsForPartition =
+            outputChannels.getChannelsForPartition(partitionNumber);
+
+        Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber);
+        Preconditions.checkState(
+            outputChannelsForPartition.size() == 1,
+            "Expected one channel for partition [%s], but got [%s]",
+            partitionNumber,
+            outputChannelsForPartition.size()
+        );
+      }
+    }
+  }
+
   private class KernelHolder
   {
     private boolean done = false;
@@ -1428,4 +1803,9 @@ public class WorkerImpl implements Worker
       return outputChannels;
     }
   }
+
+  private interface ExceptionalFunction<T, R>
+  {
+    R apply(T t) throws Exception;
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
index 4c038f921d..058877f97e 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
@@ -31,6 +31,7 @@ import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
 import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
 import org.apache.druid.msq.input.InputSpecs;
 import org.apache.druid.msq.kernel.QueryDefinition;
+import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
 import org.apache.druid.query.lookup.LookupExtractor;
 import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
@@ -51,7 +52,7 @@ import java.util.Objects;
  * entirely on server configuration; this makes the calculation robust to different queries running simultaneously in
  * the same JVM.
  *
- * Then, we split up the resources for each bundle in two different ways: one assuming it'll be used for a
+ * Within each bundle, we split up memory in two different ways: one assuming it'll be used for a
  * {@link org.apache.druid.frame.processor.SuperSorter}, and one assuming it'll be used for a regular
  * processor. Callers can then use whichever set of allocations makes sense. (We assume no single bundle
  * will be used for both purposes.)
@@ -166,6 +167,7 @@ public class WorkerMemoryParameters
         computeNumWorkersInJvm(injector),
         computeNumProcessorsInJvm(injector),
         0,
+        0,
         totalLookupFootprint
     );
   }
@@ -179,19 +181,27 @@ public class WorkerMemoryParameters
       final int stageNumber
   )
   {
-    final IntSet inputStageNumbers =
-        InputSpecs.getStageNumbers(queryDef.getStageDefinition(stageNumber).getInputSpecs());
+    final StageDefinition stageDef = queryDef.getStageDefinition(stageNumber);
+    final IntSet inputStageNumbers = InputSpecs.getStageNumbers(stageDef.getInputSpecs());
     final int numInputWorkers =
         inputStageNumbers.intStream()
                          .map(inputStageNumber -> queryDef.getStageDefinition(inputStageNumber).getMaxWorkerCount())
                          .sum();
     long totalLookupFootprint = computeTotalLookupFootprint(injector);
 
+    final int numHashOutputPartitions;
+    if (stageDef.doesShuffle() && stageDef.getShuffleSpec().kind().isHash()) {
+      numHashOutputPartitions = stageDef.getShuffleSpec().partitionCount();
+    } else {
+      numHashOutputPartitions = 0;
+    }
+
     return createInstance(
         Runtime.getRuntime().maxMemory(),
         computeNumWorkersInJvm(injector),
         computeNumProcessorsInJvm(injector),
         numInputWorkers,
+        numHashOutputPartitions,
         totalLookupFootprint
     );
   }
@@ -206,15 +216,18 @@ public class WorkerMemoryParameters
    * @param numWorkersInJvm           number of workers that can run concurrently in this JVM. Generally equal to
    *                                  the task capacity.
    * @param numProcessingThreadsInJvm size of the processing thread pool in the JVM.
-   * @param numInputWorkers           number of workers across input stages that need to be merged together.
-   * @param totalLookUpFootprint      estimated size of the lookups loaded by the process.
+   * @param numInputWorkers           total number of workers across all input stages.
+   * @param numHashOutputPartitions   total number of output partitions, if using hash partitioning; zero if not using
+   *                                  hash partitioning.
+   * @param totalLookupFootprint      estimated size of the lookups loaded by the process.
    */
   public static WorkerMemoryParameters createInstance(
       final long maxMemoryInJvm,
       final int numWorkersInJvm,
       final int numProcessingThreadsInJvm,
       final int numInputWorkers,
-      final long totalLookUpFootprint
+      final int numHashOutputPartitions,
+      final long totalLookupFootprint
   )
   {
     Preconditions.checkArgument(maxMemoryInJvm > 0, "Max memory passed: [%s] should be > 0", maxMemoryInJvm);
@@ -226,18 +239,25 @@ public class WorkerMemoryParameters
     );
     Preconditions.checkArgument(numInputWorkers >= 0, "Number of input workers: [%s] should be >=0", numInputWorkers);
     Preconditions.checkArgument(
-        totalLookUpFootprint >= 0,
+        totalLookupFootprint >= 0,
         "Lookup memory footprint: [%s] should be >= 0",
-        totalLookUpFootprint
+        totalLookupFootprint
     );
-    final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookUpFootprint);
+    final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookupFootprint);
     final long workerMemory = memoryPerWorker(usableMemoryInJvm, numWorkersInJvm);
     final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
     final long bundleMemoryForInputChannels = memoryNeededForInputChannels(numInputWorkers);
-    final long bundleMemoryForProcessing = bundleMemory - bundleMemoryForInputChannels;
+    final long bundleMemoryForHashPartitioning = memoryNeededForHashPartitioning(numHashOutputPartitions);
+    final long bundleMemoryForProcessing =
+        bundleMemory - bundleMemoryForInputChannels - bundleMemoryForHashPartitioning;
 
     if (bundleMemoryForProcessing < PROCESSING_MINIMUM_BYTES) {
-      final int maxWorkers = computeMaxWorkers(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
+      final int maxWorkers = computeMaxWorkers(
+          usableMemoryInJvm,
+          numWorkersInJvm,
+          numProcessingThreadsInJvm,
+          numHashOutputPartitions
+      );
 
       if (maxWorkers > 0) {
         throw new MSQException(new TooManyWorkersFault(numInputWorkers, Math.min(Limits.MAX_WORKERS, maxWorkers)));
@@ -250,7 +270,7 @@ public class WorkerMemoryParameters
                         numWorkersInJvm,
                         numProcessingThreadsInJvm,
                         PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels
-                    ), totalLookUpFootprint),
+                    ), totalLookupFootprint),
                 maxMemoryInJvm,
                 usableMemoryInJvm,
                 numWorkersInJvm,
@@ -271,7 +291,7 @@ public class WorkerMemoryParameters
                       numWorkersInJvm,
                       (MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE
                   ),
-                  totalLookUpFootprint
+                  totalLookupFootprint
               ),
               maxMemoryInJvm,
               usableMemoryInJvm,
@@ -393,13 +413,19 @@ public class WorkerMemoryParameters
   static int computeMaxWorkers(
       final long usableMemoryInJvm,
       final int numWorkersInJvm,
-      final int numProcessingThreadsInJvm
+      final int numProcessingThreadsInJvm,
+      final int numHashOutputPartitions
   )
   {
     final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
 
-    // Inverse of memoryNeededForInputChannels.
-    return Math.max(0, Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / STANDARD_FRAME_SIZE - 1));
+    // Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle, while accounting for
+    // memoryNeededForInputChannels + memoryNeededForHashPartitioning.
+    final int isHashing = numHashOutputPartitions > 0 ? 1 : 0;
+    return Math.max(
+        0,
+        Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1)
+    );
   }
 
   /**
@@ -499,17 +525,29 @@ public class WorkerMemoryParameters
     return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1);
   }
 
+  private static long memoryNeededForHashPartitioning(final int numOutputPartitions)
+  {
+    // One standard frame for each processor output.
+    // May be zero, since numOutputPartitions is zero if not using hash partitioning.
+    return (long) STANDARD_FRAME_SIZE * numOutputPartitions;
+  }
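
    A worked example of the bundle split in createInstance(), using made-up numbers; the real
    frame size and bundle size come from the constants and helpers in this class. Processing
    memory is what remains after reserving one standard frame per input worker (plus one) and
    one standard frame per hash output partition.

    public class BundleMemorySketch
    {
      public static void main(String[] args)
      {
        final long standardFrameSize = 1_000_000L;   // illustrative value only
        final long bundleMemory = 100_000_000L;      // illustrative value only
        final int numInputWorkers = 20;
        final int numHashOutputPartitions = 8;

        final long forInputChannels = standardFrameSize * (numInputWorkers + 1);
        final long forHashPartitioning = standardFrameSize * numHashOutputPartitions;
        final long forProcessing = bundleMemory - forInputChannels - forHashPartitioning;

        System.out.println("input channels:    " + forInputChannels);     // 21,000,000
        System.out.println("hash partitioning: " + forHashPartitioning);  // 8,000,000
        System.out.println("processing:        " + forProcessing);        // 71,000,000
      }
    }
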
+
   /**
-   * Amount of heap memory available for our usage. Any computation changes done to this method should also be done in its corresponding method {@link WorkerMemoryParameters#calculateSuggestedMinMemoryFromUsableMemory}
+   * Amount of heap memory available for our usage. Any computation changes done to this method should also be done in
+   * its corresponding method {@link WorkerMemoryParameters#calculateSuggestedMinMemoryFromUsableMemory}
    */
   private static long computeUsableMemoryInJvm(final long maxMemory, final long totalLookupFootprint)
   {
-    // since lookups are essentially in memory hashmap's, the object overhead is trivial hence its subtracted prior to usable memory calculations.
-    return (long) ((maxMemory - totalLookupFootprint) * USABLE_MEMORY_FRACTION);
+    // Always report at least one byte, to simplify the math in createInstance.
+    return Math.max(
+        1,
+        (long) ((maxMemory - totalLookupFootprint) * USABLE_MEMORY_FRACTION)
+    );
   }
 
   /**
-   * Estimate amount of heap memory for the given workload to use in case usable memory is provided. This method is used for better exception messages when {@link NotEnoughMemoryFault} is thrown.
+   * Estimate the total heap memory needed to provide the given amount of usable memory. This method is used
+   * for better exception messages when {@link NotEnoughMemoryFault} is thrown.
    */
   private static long calculateSuggestedMinMemoryFromUsableMemory(long usuableMemeory, final long totalLookupFootprint)
   {
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
index 5a8200c403..c1c6219f5f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
@@ -267,7 +267,6 @@ public class WorkerSketchFetcher implements AutoCloseable
                 ),
                 retryOperation
             );
-
           });
         }
       }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java
index dccd42bad9..ecaa6c9f10 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java
@@ -61,6 +61,7 @@ import org.apache.druid.msq.indexing.error.TooManyClusteredByColumnsFault;
 import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
 import org.apache.druid.msq.indexing.error.TooManyInputFilesFault;
 import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
+import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
 import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
 import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
 import org.apache.druid.msq.indexing.error.UnknownFault;
@@ -78,6 +79,7 @@ import org.apache.druid.msq.input.table.TableInputSpec;
 import org.apache.druid.msq.kernel.NilExtraInfoHolder;
 import org.apache.druid.msq.querykit.InputNumberDataSource;
 import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
 import org.apache.druid.msq.querykit.groupby.GroupByPostShuffleFrameProcessorFactory;
 import org.apache.druid.msq.querykit.groupby.GroupByPreShuffleFrameProcessorFactory;
 import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessorFactory;
@@ -118,6 +120,7 @@ public class MSQIndexingModule implements DruidModule
       TooManyColumnsFault.class,
       TooManyInputFilesFault.class,
       TooManyPartitionsFault.class,
+      TooManyRowsWithSameKeyFault.class,
       TooManyWarningsFault.class,
       TooManyWorkersFault.class,
       TooManyAttemptsForJob.class,
@@ -150,6 +153,7 @@ public class MSQIndexingModule implements DruidModule
         ScanQueryFrameProcessorFactory.class,
         GroupByPreShuffleFrameProcessorFactory.class,
         GroupByPostShuffleFrameProcessorFactory.class,
+        SortMergeJoinFrameProcessorFactory.class,
         OffsetLimitFrameProcessorFactory.class,
         NilExtraInfoHolder.class,
 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/CountingWritableFrameChannel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/CountingWritableFrameChannel.java
index c0df510732..1b4963ac67 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/CountingWritableFrameChannel.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/CountingWritableFrameChannel.java
@@ -63,6 +63,12 @@ public class CountingWritableFrameChannel implements WritableFrameChannel
     baseChannel.close();
   }
 
+  @Override
+  public boolean isClosed()
+  {
+    return baseChannel.isClosed();
+  }
+
   @Override
   public ListenableFuture<?> writabilityFuture()
   {
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/InputChannelsImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/InputChannelsImpl.java
index 1c01d480e3..3414d37a54 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/InputChannelsImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/InputChannelsImpl.java
@@ -22,11 +22,12 @@ package org.apache.druid.msq.indexing;
 import com.google.common.collect.Iterables;
 import org.apache.druid.frame.FrameType;
 import org.apache.druid.frame.allocation.MemoryAllocator;
+import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
 import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameChannelMerger;
-import org.apache.druid.frame.processor.FrameChannelMuxer;
+import org.apache.druid.frame.processor.FrameChannelMixer;
 import org.apache.druid.frame.processor.FrameProcessorExecutor;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.write.FrameWriters;
@@ -90,8 +91,8 @@ public class InputChannelsImpl implements InputChannels
   {
     final StageDefinition stageDef = queryDefinition.getStageDefinition(stagePartition.getStageId());
     final ReadablePartition readablePartition = readablePartitionMap.get(stagePartition);
-    final ClusterBy inputClusterBy = stageDef.getClusterBy();
-    final boolean isSorted = inputClusterBy.getBucketByCount() != inputClusterBy.getColumns().size();
+    final ClusterBy clusterBy = stageDef.getClusterBy();
+    final boolean isSorted = clusterBy.sortable() && (clusterBy.getColumns().size() - clusterBy.getBucketByCount() > 0);
 
     if (isSorted) {
       return openSorted(stageDef, readablePartition);
@@ -129,13 +130,13 @@ public class InputChannelsImpl implements InputChannels
           queueChannel.writable(),
           FrameWriters.makeFrameWriterFactory(
               FrameType.ROW_BASED,
-              allocatorMaker.get(),
+              new SingleMemoryAllocatorFactory(allocatorMaker.get()),
               stageDefinition.getFrameReader().signature(),
 
               // No sortColumns, because FrameChannelMerger generates frames that are sorted all on its own
               Collections.emptyList()
           ),
-          stageDefinition.getClusterBy(),
+          stageDefinition.getSortKey(),
           null,
           -1
       );
@@ -163,7 +164,7 @@ public class InputChannelsImpl implements InputChannels
       return Iterables.getOnlyElement(channels);
     } else {
       final BlockingQueueFrameChannel queueChannel = BlockingQueueFrameChannel.minimal();
-      final FrameChannelMuxer muxer = new FrameChannelMuxer(channels, queueChannel.writable());
+      final FrameChannelMixer muxer = new FrameChannelMixer(channels, queueChannel.writable());
 
       // Discard future, since there is no need to keep it. We aren't interested in its return value. If it fails,
       // downstream processors are notified through fail(e) on in-memory channels. If we need to cancel it, we use
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java
index edc33865cf..1a598547a9 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java
@@ -22,6 +22,8 @@ package org.apache.druid.msq.indexing.error;
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 
 import java.util.Objects;
 
@@ -35,9 +37,14 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
   @JsonCreator
   public BroadcastTablesTooLargeFault(@JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize)
   {
-    super(CODE,
-          "Size of the broadcast tables exceed the memory reserved for them (memory reserved for broadcast tables = %d bytes)",
-          maxBroadcastTablesSize
+    super(
+        CODE,
+        "Size of broadcast tables in JOIN exceeds reserved memory limit "
+        + "(memory reserved for broadcast tables = %d bytes). "
+        + "Increase available memory, or set %s: %s in query context to use a shuffle-based join.",
+        maxBroadcastTablesSize,
+        PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+        JoinAlgorithm.SORT_MERGE.toString()
     );
     this.maxBroadcastTablesSize = maxBroadcastTablesSize;
   }
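
The revised fault message points users at the sqlJoinAlgorithm query-context parameter. Below is a minimal, self-contained sketch of a context map that opts into the shuffle-based (sort-merge) join; the literal key and value are assumptions that should match PlannerContext.CTX_SQL_JOIN_ALGORITHM and the string form of JoinAlgorithm.SORT_MERGE referenced in the message above.

    // Hedged sketch: the string literals are assumed to match the constants named in the fault message.
    import java.util.HashMap;
    import java.util.Map;

    public class SortMergeContextExample
    {
      public static Map<String, Object> sortMergeContext()
      {
        final Map<String, Object> context = new HashMap<>();
        context.put("sqlJoinAlgorithm", "sortMerge");
        return context;
      }
    }
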
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
similarity index 53%
copy from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java
copy to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
index edc33865cf..21fa363af8 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/BroadcastTablesTooLargeFault.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
@@ -23,29 +23,54 @@ import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 
+import java.util.List;
 import java.util.Objects;
 
-@JsonTypeName(BroadcastTablesTooLargeFault.CODE)
-public class BroadcastTablesTooLargeFault extends BaseMSQFault
+@JsonTypeName(TooManyRowsWithSameKeyFault.CODE)
+public class TooManyRowsWithSameKeyFault extends BaseMSQFault
 {
-  static final String CODE = "BroadcastTablesTooLarge";
+  static final String CODE = "TooManyRowsWithSameKey";
 
-  private final long maxBroadcastTablesSize;
+  private final List<Object> key;
+  private final long numBytes;
+  private final long maxBytes;
 
   @JsonCreator
-  public BroadcastTablesTooLargeFault(@JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize)
+  public TooManyRowsWithSameKeyFault(
+      @JsonProperty("key") final List<Object> key,
+      @JsonProperty("numBytes") final long numBytes,
+      @JsonProperty("maxBytes") final long maxBytes
+  )
   {
-    super(CODE,
-          "Size of the broadcast tables exceed the memory reserved for them (memory reserved for broadcast tables = %d bytes)",
-          maxBroadcastTablesSize
+    super(
+        CODE,
+        "Too many rows with the same key during sort-merge join (bytes buffered = %,d; limit = %,d). Key: %s",
+        numBytes,
+        maxBytes,
+        key
     );
-    this.maxBroadcastTablesSize = maxBroadcastTablesSize;
+
+    this.key = key;
+    this.numBytes = numBytes;
+    this.maxBytes = maxBytes;
+  }
+
+  @JsonProperty
+  public List<Object> getKey()
+  {
+    return key;
+  }
+
+  @JsonProperty
+  public long getNumBytes()
+  {
+    return numBytes;
   }
 
   @JsonProperty
-  public long getMaxBroadcastTablesSize()
+  public long getMaxBytes()
   {
-    return maxBroadcastTablesSize;
+    return maxBytes;
   }
 
   @Override
@@ -60,13 +85,13 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
     if (!super.equals(o)) {
       return false;
     }
-    BroadcastTablesTooLargeFault that = (BroadcastTablesTooLargeFault) o;
-    return maxBroadcastTablesSize == that.maxBroadcastTablesSize;
+    TooManyRowsWithSameKeyFault that = (TooManyRowsWithSameKeyFault) o;
+    return numBytes == that.numBytes && maxBytes == that.maxBytes && Objects.equals(key, that.key);
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(super.hashCode(), maxBroadcastTablesSize);
+    return Objects.hash(super.hashCode(), key, numBytes, maxBytes);
   }
 }
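
TooManyRowsWithSameKeyFault caps how many bytes of same-key rows the sort-merge join may buffer. The buffering is inherent to the algorithm: when both sides contain a run of equal keys, one side's run must be held in memory and replayed against the other. A minimal, self-contained sketch of that inner loop over two pre-sorted int arrays (a conceptual illustration only, not the actual SortMergeJoinFrameProcessor):

    import java.util.ArrayList;
    import java.util.List;

    public class SortMergeSketch
    {
      /** Joins two key-sorted arrays; returns matching key pairs. Buffers the right-side run per key. */
      static List<int[]> join(int[] left, int[] right)
      {
        final List<int[]> out = new ArrayList<>();
        int l = 0;
        int r = 0;

        while (l < left.length && r < right.length) {
          if (left[l] < right[r]) {
            l++;
          } else if (left[l] > right[r]) {
            r++;
          } else {
            final int key = left[l];

            // Buffer the run of equal keys on the right side. This buffer is, roughly, what the
            // byte limit behind TooManyRowsWithSameKeyFault constrains in the real processor.
            final List<Integer> rightRun = new ArrayList<>();
            while (r < right.length && right[r] == key) {
              rightRun.add(right[r++]);
            }

            // Replay the buffered run against every equal-key row on the left side.
            while (l < left.length && left[l] == key) {
              for (int buffered : rightRun) {
                out.add(new int[]{key, buffered});
              }
              l++;
            }
          }
        }

        return out;
      }

      public static void main(String[] args)
      {
        // left = [1, 2, 2, 3], right = [2, 2, 4] -> four (2, 2) pairs.
        System.out.println(join(new int[]{1, 2, 2, 3}, new int[]{2, 2, 4}).size());
      }
    }
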
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java
index 98e28b0754..028f1b5bd4 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java
@@ -19,12 +19,20 @@
 
 package org.apache.druid.msq.input;
 
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.druid.frame.channel.ReadableNilFrameChannel;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.counters.CounterTracker;
 import org.apache.druid.msq.input.stage.ReadablePartitions;
 import org.apache.druid.msq.input.stage.StageInputSlice;
+import org.apache.druid.msq.kernel.StagePartition;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.function.Consumer;
 
 public class InputSlices
 {
@@ -33,6 +41,10 @@ public class InputSlices
     // No instantiation.
   }
 
+  /**
+   * Combines all {@link StageInputSlice#getPartitions()} from the input slices that are {@link StageInputSlice}.
+   * Ignores other types of input slices.
+   */
   public static ReadablePartitions allReadablePartitions(final List<InputSlice> slices)
   {
     final List<ReadablePartitions> partitionsList = new ArrayList<>();
@@ -46,6 +58,10 @@ public class InputSlices
     return ReadablePartitions.combine(partitionsList);
   }
 
+  /**
+   * Sum of {@link InputSliceReader#numReadableInputs(InputSlice)} across all input slices that are _not_ present
+   * in "broadcastInputs".
+   */
   public static int getNumNonBroadcastReadableInputs(
       final List<InputSlice> slices,
       final InputSliceReader reader,
@@ -62,4 +78,70 @@ public class InputSlices
 
     return numInputs;
   }
+
+  /**
+   * Calls {@link InputSliceReader#attach} on all "slices", which must all be {@link NilInputSlice} or
+   * {@link StageInputSlice}, and collects like-numbered partitions.
+   *
+   * The returned map is keyed by partition number. Each value is a list of inputs of the
+   * same length as "slices", and in the same order. i.e., the first ReadableInput in each list corresponds to the
+   * first provided {@link InputSlice}.
+   *
+   * "Missing" partitions -- which occur when one slice has no data for a given partition -- are replaced with
+   * {@link ReadableInput} based on {@link ReadableNilFrameChannel}, with no {@link StagePartition}.
+   *
+   * @throws IllegalStateException if any slice is neither {@link NilInputSlice} nor {@link StageInputSlice}
+   */
+  public static Int2ObjectMap<List<ReadableInput>> attachAndCollectPartitions(
+      final List<InputSlice> slices,
+      final InputSliceReader reader,
+      final CounterTracker counters,
+      final Consumer<Throwable> warningPublisher
+  )
+  {
+    // Input number -> ReadableInputs.
+    final List<ReadableInputs> inputsByInputNumber = new ArrayList<>();
+
+    for (final InputSlice slice : slices) {
+      if (slice instanceof NilInputSlice) {
+        inputsByInputNumber.add(null);
+      } else if (slice instanceof StageInputSlice) {
+        final ReadableInputs inputs = reader.attach(inputsByInputNumber.size(), slice, counters, warningPublisher);
+        inputsByInputNumber.add(inputs);
+      } else {
+        throw new ISE("Slice [%s] is not a 'stage' slice", slice);
+      }
+    }
+
+    // Populate the result map.
+    final Int2ObjectMap<List<ReadableInput>> retVal = new Int2ObjectRBTreeMap<>();
+
+    for (int inputNumber = 0; inputNumber < slices.size(); inputNumber++) {
+      for (final ReadableInput input : inputsByInputNumber.get(inputNumber)) {
+        if (input != null) {
+          final int partitionNumber = input.getStagePartition().getPartitionNumber();
+          retVal.computeIfAbsent(partitionNumber, ignored -> Arrays.asList(new ReadableInput[slices.size()]))
+                .set(inputNumber, input);
+        }
+      }
+    }
+
+    // Fill in remaining nulls with inputs backed by ReadableNilFrameChannel.
+    for (Int2ObjectMap.Entry<List<ReadableInput>> entry : retVal.int2ObjectEntrySet()) {
+      for (int inputNumber = 0; inputNumber < entry.getValue().size(); inputNumber++) {
+        if (entry.getValue().get(inputNumber) == null) {
+          entry.getValue().set(
+              inputNumber,
+              ReadableInput.channel(
+                  ReadableNilFrameChannel.INSTANCE,
+                  inputsByInputNumber.get(inputNumber).frameReader(),
+                  null
+              )
+          );
+        }
+      }
+    }
+
+    return retVal;
+  }
 }
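
attachAndCollectPartitions groups like-numbered partitions across input slices, leaving a nil-channel placeholder where a slice has no data for a partition. A standalone sketch of the same collect-and-pad idea over plain collections (the Input type here is a hypothetical stand-in for ReadableInput):

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.TreeMap;

    public class CollectPartitionsSketch
    {
      /** Hypothetical stand-in for ReadableInput: a payload tagged with its partition number. */
      static class Input
      {
        final int partitionNumber;
        final String payload;

        Input(int partitionNumber, String payload)
        {
          this.partitionNumber = partitionNumber;
          this.payload = payload;
        }

        @Override
        public String toString()
        {
          return payload;
        }
      }

      /**
       * Groups inputs by partition number. Each value list has one slot per slice, in slice order;
       * a null slot means "this slice had no data for this partition" and would be replaced with a
       * nil channel in the real method.
       */
      static Map<Integer, List<Input>> collect(final List<List<Input>> slices)
      {
        final Map<Integer, List<Input>> byPartition = new TreeMap<>();

        for (int sliceNumber = 0; sliceNumber < slices.size(); sliceNumber++) {
          for (final Input input : slices.get(sliceNumber)) {
            byPartition.computeIfAbsent(input.partitionNumber, ignored -> Arrays.asList(new Input[slices.size()]))
                       .set(sliceNumber, input);
          }
        }

        return byPartition;
      }

      public static void main(String[] args)
      {
        // Slice 0 has partitions 0 and 1; slice 1 only has partition 1.
        final List<List<Input>> slices = Arrays.asList(
            Arrays.asList(new Input(0, "a"), new Input(1, "b")),
            Arrays.asList(new Input(1, "c"))
        );

        // Prints {0=[a, null], 1=[b, c]}.
        System.out.println(collect(slices));
      }
    }
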
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInput.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInput.java
index ada4b71915..b125dcfe8f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInput.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInput.java
@@ -40,7 +40,7 @@ public class ReadableInput
   private final SegmentWithDescriptor segment;
 
   @Nullable
-  private final ReadableFrameChannel inputChannel;
+  private final ReadableFrameChannel channel;
 
   @Nullable
   private final FrameReader frameReader;
@@ -56,7 +56,7 @@ public class ReadableInput
   )
   {
     this.segment = segment;
-    this.inputChannel = channel;
+    this.channel = channel;
     this.frameReader = frameReader;
     this.stagePartition = stagePartition;
 
@@ -65,48 +65,107 @@ public class ReadableInput
     }
   }
 
+  /**
+   * Create an input associated with a physical segment.
+   *
+   * @param segment the segment
+   */
   public static ReadableInput segment(final SegmentWithDescriptor segment)
   {
-    return new ReadableInput(segment, null, null, null);
+    return new ReadableInput(Preconditions.checkNotNull(segment, "segment"), null, null, null);
   }
 
+  /**
+   * Create an input associated with a channel.
+   *
+   * @param channel        the channel
+   * @param frameReader    reader for the channel
+   * @param stagePartition stage-partition associated with the channel, if meaningful. May be null if this channel
+   *                       does not correspond to any one particular stage-partition.
+   */
   public static ReadableInput channel(
-      final ReadableFrameChannel inputChannel,
+      final ReadableFrameChannel channel,
       final FrameReader frameReader,
-      final StagePartition stagePartition
+      @Nullable final StagePartition stagePartition
   )
   {
-    return new ReadableInput(null, inputChannel, frameReader, stagePartition);
+    return new ReadableInput(
+        null,
+        Preconditions.checkNotNull(channel, "channel"),
+        Preconditions.checkNotNull(frameReader, "frameReader"),
+        stagePartition
+    );
   }
 
+  /**
+   * Whether this input is a segment (from {@link #segment(SegmentWithDescriptor)}).
+   */
   public boolean hasSegment()
   {
     return segment != null;
   }
 
+  /**
+   * Whether this input is a channel (from {@link #channel(ReadableFrameChannel, FrameReader, StagePartition)}).
+   */
   public boolean hasChannel()
   {
-    return inputChannel != null;
+    return channel != null;
   }
 
+  /**
+   * The segment for this input. Only valid if {@link #hasSegment()}.
+   */
   public SegmentWithDescriptor getSegment()
   {
-    return Preconditions.checkNotNull(segment, "segment");
+    checkIsSegment();
+    return segment;
   }
 
+  /**
+   * The channel for this input. Only valid if {@link #hasChannel()}.
+   */
   public ReadableFrameChannel getChannel()
   {
-    return Preconditions.checkNotNull(inputChannel, "channel");
+    checkIsChannel();
+    return channel;
   }
 
+  /**
+   * The frame reader for this input. Only valid if {@link #hasChannel()}.
+   */
   public FrameReader getChannelFrameReader()
   {
-    return Preconditions.checkNotNull(frameReader, "frameReader");
+    checkIsChannel();
+    return frameReader;
   }
 
-  @Nullable
+  /**
+   * The stage-partition of this input. Only valid if {@link #hasChannel()}, and if a stage-partition was provided
+   * during construction. Throws {@link IllegalStateException} if no stage-partition was provided during construction.
+   */
   public StagePartition getStagePartition()
   {
+    checkIsChannel();
+
+    if (stagePartition == null) {
+      throw new ISE("Stage-partition is not set for this channel");
+    }
+
     return stagePartition;
   }
+
+  private void checkIsSegment()
+  {
+    if (!hasSegment()) {
+      throw new ISE("Not a channel input; cannot call this method");
+    }
+  }
+
+  private void checkIsChannel()
+  {
+    if (!hasChannel()) {
+      throw new ISE("Not a channel input; cannot call this method");
+    }
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInputs.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInputs.java
index 449b2b4295..3496ac3f3f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInputs.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/ReadableInputs.java
@@ -62,6 +62,8 @@ public class ReadableInputs implements Iterable<ReadableInput>
 
   /**
    * Returns the {@link ReadableInput} as an Iterator.
+   *
+   * When this instance is channel-based ({@link #isChannelBased()}), inputs are returned in order of partition number.
    */
   @Override
   public Iterator<ReadableInput> iterator()
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java
index 1a94efa1c5..94109bc4a7 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java
@@ -27,6 +27,9 @@ import org.apache.druid.segment.Segment;
 import java.io.Closeable;
 import java.util.Objects;
 
+/**
+ * A holder for a physical segment.
+ */
 public class SegmentWithDescriptor implements Closeable
 {
   private final ResourceHolder<? extends Segment> segmentHolder;
@@ -41,20 +44,33 @@ public class SegmentWithDescriptor implements Closeable
     this.descriptor = Preconditions.checkNotNull(descriptor, "descriptor");
   }
 
+  /**
+   * The physical segment.
+   *
+   * Named "getOrLoad" because the segment may be held by an eager or lazy resource holder (i.e.
+   * {@link org.apache.druid.msq.querykit.LazyResourceHolder}). If the resource holder is lazy, the segment is acquired
+   * as part of the call to this method.
+   */
   public Segment getOrLoadSegment()
   {
     return segmentHolder.get();
   }
 
-  @Override
-  public void close()
+  /**
+   * The segment descriptor associated with this physical segment.
+   */
+  public SegmentDescriptor getDescriptor()
   {
-    segmentHolder.close();
+    return descriptor;
   }
 
-  public SegmentDescriptor getDescriptor()
+  /**
+   * Release resources used by the physical segment.
+   */
+  @Override
+  public void close()
   {
-    return descriptor;
+    segmentHolder.close();
   }
 
   @Override
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MaxCountShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java
similarity index 65%
rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MaxCountShuffleSpec.java
rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java
index f10d3b4ea8..e773fcb87a 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MaxCountShuffleSpec.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java
@@ -27,6 +27,7 @@ import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.java.util.common.Either;
 import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
 
 import javax.annotation.Nullable;
@@ -35,54 +36,71 @@ import java.util.Objects;
 /**
  * Shuffle spec that generates up to a certain number of output partitions. Commonly used for shuffles between stages.
  */
-public class MaxCountShuffleSpec implements ShuffleSpec
+public class GlobalSortMaxCountShuffleSpec implements GlobalSortShuffleSpec
 {
+  public static final String TYPE = "maxCount";
+
   private final ClusterBy clusterBy;
-  private final int partitions;
+  private final int maxPartitions;
   private final boolean aggregate;
 
   @JsonCreator
-  public MaxCountShuffleSpec(
+  public GlobalSortMaxCountShuffleSpec(
       @JsonProperty("clusterBy") final ClusterBy clusterBy,
-      @JsonProperty("partitions") final int partitions,
+      @JsonProperty("partitions") final int maxPartitions,
       @JsonProperty("aggregate") final boolean aggregate
   )
   {
     this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy");
-    this.partitions = partitions;
+    this.maxPartitions = maxPartitions;
     this.aggregate = aggregate;
 
-    if (partitions < 1) {
+    if (maxPartitions < 1) {
       throw new IAE("Partition count must be at least 1");
     }
+
+    if (!clusterBy.sortable()) {
+      throw new IAE("ClusterBy key must be sortable");
+    }
+
+    if (clusterBy.getBucketByCount() > 0) {
+      // Only GlobalSortTargetSizeShuffleSpec supports bucket-by.
+      throw new IAE("Cannot bucket with %s partitioning", TYPE);
+    }
+  }
+
+  @Override
+  public ShuffleKind kind()
+  {
+    return ShuffleKind.GLOBAL_SORT;
   }
 
   @Override
   @JsonProperty("aggregate")
   @JsonInclude(JsonInclude.Include.NON_DEFAULT)
-  public boolean doesAggregateByClusterKey()
+  public boolean doesAggregate()
   {
     return aggregate;
   }
 
   @Override
-  public boolean needsStatistics()
+  public boolean mustGatherResultKeyStatistics()
   {
-    return partitions > 1 || clusterBy.getBucketByCount() > 0;
+    return maxPartitions > 1 || clusterBy.getBucketByCount() > 0;
   }
 
   @Override
-  public Either<Long, ClusterByPartitions> generatePartitions(
+  public Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
       @Nullable final ClusterByStatisticsCollector collector,
       final int maxNumPartitions
   )
   {
-    if (!needsStatistics()) {
+    if (!mustGatherResultKeyStatistics()) {
       return Either.value(ClusterByPartitions.oneUniversalPartition());
-    } else if (partitions > maxNumPartitions) {
-      return Either.error((long) partitions);
+    } else if (maxPartitions > maxNumPartitions) {
+      return Either.error((long) maxPartitions);
     } else {
-      final ClusterByPartitions generatedPartitions = collector.generatePartitionsWithMaxCount(partitions);
+      final ClusterByPartitions generatedPartitions = collector.generatePartitionsWithMaxCount(maxPartitions);
       if (generatedPartitions.size() <= maxNumPartitions) {
         return Either.value(generatedPartitions);
       } else {
@@ -93,15 +111,21 @@ public class MaxCountShuffleSpec implements ShuffleSpec
 
   @Override
   @JsonProperty
-  public ClusterBy getClusterBy()
+  public ClusterBy clusterBy()
   {
     return clusterBy;
   }
 
-  @JsonProperty
-  int getPartitions()
+  @Override
+  public int partitionCount()
+  {
+    throw new ISE("Number of partitions not known for [%s].", kind());
+  }
+
+  @JsonProperty("partitions")
+  public int getMaxPartitions()
   {
-    return partitions;
+    return maxPartitions;
   }
 
   @Override
@@ -113,8 +137,8 @@ public class MaxCountShuffleSpec implements ShuffleSpec
     if (o == null || getClass() != o.getClass()) {
       return false;
     }
-    MaxCountShuffleSpec that = (MaxCountShuffleSpec) o;
-    return partitions == that.partitions
+    GlobalSortMaxCountShuffleSpec that = (GlobalSortMaxCountShuffleSpec) o;
+    return maxPartitions == that.maxPartitions
            && aggregate == that.aggregate
            && Objects.equals(clusterBy, that.clusterBy);
   }
@@ -122,7 +146,7 @@ public class MaxCountShuffleSpec implements ShuffleSpec
   @Override
   public int hashCode()
   {
-    return Objects.hash(clusterBy, partitions, aggregate);
+    return Objects.hash(clusterBy, maxPartitions, aggregate);
   }
 
   @Override
@@ -130,7 +154,7 @@ public class MaxCountShuffleSpec implements ShuffleSpec
   {
     return "MaxCountShuffleSpec{" +
            "clusterBy=" + clusterBy +
-           ", partitions=" + partitions +
+           ", partitions=" + maxPartitions +
            ", aggregate=" + aggregate +
            '}';
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java
similarity index 55%
copy from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java
copy to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java
index 319e85850c..4d608c4fbe 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java
@@ -19,9 +19,6 @@
 
 package org.apache.druid.msq.kernel;
 
-import com.fasterxml.jackson.annotation.JsonSubTypes;
-import com.fasterxml.jackson.annotation.JsonTypeInfo;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.java.util.common.Either;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
@@ -29,43 +26,29 @@ import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
 import javax.annotation.Nullable;
 
 /**
- * Describes how outputs of a stage are shuffled. Property of {@link StageDefinition}.
- *
- * When the output of a stage is shuffled, it is globally sorted and partitioned according to the {@link ClusterBy}.
- * Hash-based (non-sorting) shuffle is not currently implemented.
+ * Additional methods for {@link ShuffleSpec} of kind {@link ShuffleKind#GLOBAL_SORT}.
  */
-@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
-@JsonSubTypes(value = {
-    @JsonSubTypes.Type(name = "maxCount", value = MaxCountShuffleSpec.class),
-    @JsonSubTypes.Type(name = "targetSize", value = TargetSizeShuffleSpec.class)
-})
-public interface ShuffleSpec
+public interface GlobalSortShuffleSpec extends ShuffleSpec
 {
   /**
-   * Clustering key that will determine how data are partitioned during the shuffle.
+   * Whether {@link #generatePartitionsForGlobalSort} needs a nonnull collector in order to do its work.
    */
-  ClusterBy getClusterBy();
-
-  /**
-   * Whether this stage aggregates by the clustering key or not.
-   */
-  boolean doesAggregateByClusterKey();
-
-  /**
-   * Whether {@link #generatePartitions} needs a nonnull collector.
-   */
-  boolean needsStatistics();
+  boolean mustGatherResultKeyStatistics();
 
   /**
    * Generates a set of partitions based on the provided statistics.
    *
-   * @param collector        must be nonnull if {@link #needsStatistics()} is true; may be null otherwise
+   * Only valid if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}. Otherwise, throws {@link IllegalStateException}.
+   *
+   * @param collector        must be nonnull if {@link #mustGatherResultKeyStatistics()} is true; ignored otherwise
    * @param maxNumPartitions maximum number of partitions to generate
    *
    * @return either the partition assignment, or (as an error) a number of partitions, greater than maxNumPartitions,
    * that would be expected to be created
+   *
+   * @throws IllegalStateException if {@link #kind()} is not {@link ShuffleKind#GLOBAL_SORT}.
    */
-  Either<Long, ClusterByPartitions> generatePartitions(
+  Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
       @Nullable ClusterByStatisticsCollector collector,
       int maxNumPartitions
   );
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/TargetSizeShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortTargetSizeShuffleSpec.java
similarity index 80%
rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/TargetSizeShuffleSpec.java
rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortTargetSizeShuffleSpec.java
index 49f4d71868..61ae457d62 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/TargetSizeShuffleSpec.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortTargetSizeShuffleSpec.java
@@ -26,6 +26,8 @@ import com.google.common.base.Preconditions;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.java.util.common.Either;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
 
 import javax.annotation.Nullable;
@@ -36,14 +38,16 @@ import java.util.Objects;
  * to a particular {@link #targetSize}. Commonly used when generating segments, which we want to have a certain number
  * of rows per segment.
  */
-public class TargetSizeShuffleSpec implements ShuffleSpec
+public class GlobalSortTargetSizeShuffleSpec implements GlobalSortShuffleSpec
 {
+  public static final String TYPE = "targetSize";
+
   private final ClusterBy clusterBy;
   private final long targetSize;
   private final boolean aggregate;
 
   @JsonCreator
-  public TargetSizeShuffleSpec(
+  public GlobalSortTargetSizeShuffleSpec(
       @JsonProperty("clusterBy") final ClusterBy clusterBy,
       @JsonProperty("targetSize") final long targetSize,
       @JsonProperty("aggregate") final boolean aggregate
@@ -52,24 +56,40 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
     this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy");
     this.targetSize = targetSize;
     this.aggregate = aggregate;
+
+    if (!clusterBy.sortable()) {
+      throw new IAE("ClusterBy key must be sortable");
+    }
+  }
+
+  @Override
+  public ShuffleKind kind()
+  {
+    return ShuffleKind.GLOBAL_SORT;
   }
 
   @Override
   @JsonProperty("aggregate")
   @JsonInclude(JsonInclude.Include.NON_DEFAULT)
-  public boolean doesAggregateByClusterKey()
+  public boolean doesAggregate()
   {
     return aggregate;
   }
 
   @Override
-  public boolean needsStatistics()
+  public boolean mustGatherResultKeyStatistics()
   {
     return true;
   }
 
   @Override
-  public Either<Long, ClusterByPartitions> generatePartitions(
+  public int partitionCount()
+  {
+    throw new ISE("Number of partitions not known for [%s].", kind());
+  }
+
+  @Override
+  public Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
       @Nullable final ClusterByStatisticsCollector collector,
       final int maxNumPartitions
   )
@@ -90,13 +110,13 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
 
   @Override
   @JsonProperty
-  public ClusterBy getClusterBy()
+  public ClusterBy clusterBy()
   {
     return clusterBy;
   }
 
   @JsonProperty
-  long getTargetSize()
+  long targetSize()
   {
     return targetSize;
   }
@@ -110,7 +130,7 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
     if (o == null || getClass() != o.getClass()) {
       return false;
     }
-    TargetSizeShuffleSpec that = (TargetSizeShuffleSpec) o;
+    GlobalSortTargetSizeShuffleSpec that = (GlobalSortTargetSizeShuffleSpec) o;
     return targetSize == that.targetSize && aggregate == that.aggregate && Objects.equals(clusterBy, that.clusterBy);
   }
 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java
new file mode 100644
index 0000000000..fc453d7663
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.kernel;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.java.util.common.IAE;
+
+public class HashShuffleSpec implements ShuffleSpec
+{
+  public static final String TYPE = "hash";
+
+  private final ClusterBy clusterBy;
+  private final int numPartitions;
+
+  @JsonCreator
+  public HashShuffleSpec(
+      @JsonProperty("clusterBy") final ClusterBy clusterBy,
+      @JsonProperty("partitions") final int numPartitions
+  )
+  {
+    this.clusterBy = clusterBy;
+    this.numPartitions = numPartitions;
+
+    if (clusterBy.getBucketByCount() > 0) {
+      // Only GlobalSortTargetSizeShuffleSpec supports bucket-by.
+      throw new IAE("Cannot bucket with %s partitioning (clusterBy = %s)", TYPE, clusterBy);
+    }
+  }
+
+  @Override
+  public ShuffleKind kind()
+  {
+    return clusterBy.sortable() && !clusterBy.isEmpty() ? ShuffleKind.HASH_LOCAL_SORT : ShuffleKind.HASH;
+  }
+
+  @Override
+  @JsonProperty
+  public ClusterBy clusterBy()
+  {
+    return clusterBy;
+  }
+
+  @Override
+  public boolean doesAggregate()
+  {
+    return false;
+  }
+
+  @Override
+  @JsonProperty("partitions")
+  public int partitionCount()
+  {
+    return numPartitions;
+  }
+}
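
HashShuffleSpec assigns rows to a fixed number of partitions by hashing the clusterBy key. The exact hash function used by the MSQ partitioner is not shown in this file; the sketch below only illustrates the stable modulo-style assignment that makes equal keys land in the same partition.

    import java.util.Arrays;
    import java.util.List;

    public class HashPartitionSketch
    {
      /**
       * Maps a key to a partition in [0, numPartitions). The real partitioner operates on frame rows
       * and may hash differently; this only shows that the assignment is deterministic per key.
       */
      static int partitionOf(final List<Object> key, final int numPartitions)
      {
        return Math.floorMod(key.hashCode(), numPartitions);
      }

      public static void main(String[] args)
      {
        // Equal keys always map to the same partition, regardless of which worker computes the hash,
        // so a downstream stage can find all rows for a key in a single partition.
        System.out.println(partitionOf(Arrays.asList("x", 1), 4));
        System.out.println(partitionOf(Arrays.asList("x", 1), 4));
      }
    }
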
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java
new file mode 100644
index 0000000000..3fe4e2e060
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.kernel;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import org.apache.druid.frame.key.ClusterBy;
+
+/**
+ * Shuffle spec that generates a single, unsorted partition.
+ */
+public class MixShuffleSpec implements ShuffleSpec
+{
+  public static final String TYPE = "mix";
+
+  private static final MixShuffleSpec INSTANCE = new MixShuffleSpec();
+
+  private MixShuffleSpec()
+  {
+  }
+
+  @JsonCreator
+  public static MixShuffleSpec instance()
+  {
+    return INSTANCE;
+  }
+
+  @Override
+  public ShuffleKind kind()
+  {
+    return ShuffleKind.MIX;
+  }
+
+  @Override
+  public ClusterBy clusterBy()
+  {
+    return ClusterBy.none();
+  }
+
+  @Override
+  public boolean doesAggregate()
+  {
+    return false;
+  }
+
+  @Override
+  public int partitionCount()
+  {
+    return 1;
+  }
+
+  @Override
+  public boolean equals(Object obj)
+  {
+    return obj != null && this.getClass().equals(obj.getClass());
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return 0;
+  }
+
+  @Override
+  public String toString()
+  {
+    return "MuxShuffleSpec{}";
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinitionBuilder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinitionBuilder.java
index 166369cc3b..29fc52fe89 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinitionBuilder.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinitionBuilder.java
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.kernel;
 
 import com.google.common.base.Preconditions;
+import org.apache.druid.java.util.common.ISE;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -77,6 +78,17 @@ public class QueryDefinitionBuilder
     return stageBuilders.stream().mapToInt(StageDefinitionBuilder::getStageNumber).max().orElse(-1) + 1;
   }
 
+  public StageDefinitionBuilder getStageBuilder(final int stageNumber)
+  {
+    for (final StageDefinitionBuilder stageBuilder : stageBuilders) {
+      if (stageBuilder.getStageNumber() == stageNumber) {
+        return stageBuilder;
+      }
+    }
+
+    throw new ISE("No such stage [%s]", stageNumber);
+  }
+
   public QueryDefinition build()
   {
     final List<StageDefinition> stageDefinitions =
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java
new file mode 100644
index 0000000000..ac3bb99273
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.kernel;
+
+public enum ShuffleKind
+{
+  /**
+   * Put all data in a single partition, with no sorting and no statistics gathering.
+   */
+  MIX(false, false),
+
+  /**
+   * Partition using hash codes, with no sorting.
+   *
+   * This kind of shuffle supports pipelining: producer and consumer stages can run at the same time.
+   */
+  HASH(true, false),
+
+  /**
+   * Partition using hash codes, with each partition internally sorted.
+   *
+   * Each worker partitions its outputs according to the hash code of the cluster key, and does a local sort of its
+   * own outputs.
+   *
+   * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before
+   * consumer stages can run.
+   */
+  HASH_LOCAL_SORT(true, true),
+
+  /**
+   * Partition using a distributed global sort.
+   *
+   * First, each worker reads its input fully and feeds statistics into a
+   * {@link org.apache.druid.msq.statistics.ClusterByStatisticsCollector}. The controller merges those statistics,
+   * generating final {@link org.apache.druid.frame.key.ClusterByPartitions}. Then, workers fully sort and partition
+   * their outputs along those lines.
+   *
+   * Consumers (workers in the next stage downstream) do an N-way merge of the already-sorted and already-partitioned
+   * output files from each worker.
+   *
+   * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before
+   * consumer stages can run.
+   */
+  GLOBAL_SORT(false, true);
+
+  private final boolean hash;
+  private final boolean sort;
+
+  ShuffleKind(boolean hash, boolean sort)
+  {
+    this.hash = hash;
+    this.sort = sort;
+  }
+
+  /**
+   * Whether this shuffle does hash-partitioning.
+   */
+  public boolean isHash()
+  {
+    return hash;
+  }
+
+  /**
+   * Whether this shuffle sorts within partitions. (If true, it may, or may not, also sort globally.)
+   */
+  public boolean isSort()
+  {
+    return sort;
+  }
+}
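
A short sketch of how a caller might use the flags above: per the javadocs, the sorting kinds (HASH_LOCAL_SORT, GLOBAL_SORT) cannot be pipelined, while HASH explicitly supports pipelining, so isSort() is the relevant signal. This assumes the ShuffleKind enum above is on the classpath.

    import org.apache.druid.msq.kernel.ShuffleKind;

    public class ShufflePipeliningSketch
    {
      /** Sorting kinds require producer stages to finish before consumer stages start. */
      static boolean producersMustFinishFirst(ShuffleKind kind)
      {
        return kind.isSort();
      }

      public static void main(String[] args)
      {
        for (ShuffleKind kind : ShuffleKind.values()) {
          System.out.println(kind + " -> producers must finish first: " + producersMustFinishFirst(kind));
        }
      }
    }
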
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java
index 319e85850c..fe5a129304 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java
@@ -22,11 +22,6 @@ package org.apache.druid.msq.kernel;
 import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.ClusterByPartitions;
-import org.apache.druid.java.util.common.Either;
-import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
-
-import javax.annotation.Nullable;
 
 /**
  * Describes how outputs of a stage are shuffled. Property of {@link StageDefinition}.
@@ -36,37 +31,46 @@ import javax.annotation.Nullable;
  */
 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
 @JsonSubTypes(value = {
-    @JsonSubTypes.Type(name = "maxCount", value = MaxCountShuffleSpec.class),
-    @JsonSubTypes.Type(name = "targetSize", value = TargetSizeShuffleSpec.class)
+    @JsonSubTypes.Type(name = MixShuffleSpec.TYPE, value = MixShuffleSpec.class),
+    @JsonSubTypes.Type(name = HashShuffleSpec.TYPE, value = HashShuffleSpec.class),
+    @JsonSubTypes.Type(name = GlobalSortMaxCountShuffleSpec.TYPE, value = GlobalSortMaxCountShuffleSpec.class),
+    @JsonSubTypes.Type(name = GlobalSortTargetSizeShuffleSpec.TYPE, value = GlobalSortTargetSizeShuffleSpec.class)
 })
 public interface ShuffleSpec
 {
   /**
-   * Clustering key that will determine how data are partitioned during the shuffle.
+   * The nature of this shuffle: hash vs. range-based partitioning; whether the data are sorted or not.
+   *
+   * If this method returns {@link ShuffleKind#GLOBAL_SORT}, then this spec is also an instance of
+   * {@link GlobalSortShuffleSpec}, and additional methods are available.
    */
-  ClusterBy getClusterBy();
+  ShuffleKind kind();
 
   /**
-   * Whether this stage aggregates by the clustering key or not.
+   * Partitioning key for the shuffle.
+   *
+   * If {@link #kind()} is {@link ShuffleKind#HASH}, data are partitioned using a hash of this key, but not sorted.
+   *
+   * If {@link #kind()} is {@link ShuffleKind#HASH_LOCAL_SORT}, data are partitioned using a hash of this key, and
+   * sorted within each partition.
+   *
+   * If {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}, data are partitioned using ranges of this key, and are
+   * sorted within each partition; therefore, the data are also globally sorted.
    */
-  boolean doesAggregateByClusterKey();
+  ClusterBy clusterBy();
 
   /**
-   * Whether {@link #generatePartitions} needs a nonnull collector.
+   * Whether this stage aggregates by the {@link #clusterBy()} key.
    */
-  boolean needsStatistics();
+  boolean doesAggregate();
 
   /**
-   * Generates a set of partitions based on the provided statistics.
+   * Number of partitions, if known.
    *
-   * @param collector        must be nonnull if {@link #needsStatistics()} is true; may be null otherwise
-   * @param maxNumPartitions maximum number of partitions to generate
+   * Partition count is always known if {@link #kind()} is {@link ShuffleKind#MIX}, {@link ShuffleKind#HASH}, or
+   * {@link ShuffleKind#HASH_LOCAL_SORT}. It is not known if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}.
    *
-   * @return either the partition assignment, or (as an error) a number of partitions, greater than maxNumPartitions,
-   * that would be expected to be created
+   * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT}
    */
-  Either<Long, ClusterByPartitions> generatePartitions(
-      @Nullable ClusterByStatisticsCollector collector,
-      int maxNumPartitions
-  );
+  int partitionCount();
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
index 083bc167df..9892eae824 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
@@ -27,9 +27,16 @@ import com.google.common.base.Suppliers;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.ints.IntSets;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.allocation.MemoryAllocator;
+import org.apache.druid.frame.allocation.MemoryAllocatorFactory;
+import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
 import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.java.util.common.Either;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
@@ -41,9 +48,9 @@ import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
 import org.apache.druid.segment.column.RowSignature;
 
 import javax.annotation.Nullable;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.Set;
 import java.util.function.Supplier;
 
@@ -64,7 +71,7 @@ import java.util.function.Supplier;
  * Each stage has a {@link ShuffleSpec} describing the shuffle that occurs as part of the stage. The shuffle spec is
  * optional: if none is provided, then the {@link FrameProcessorFactory} directly writes to output partitions. If a
  * shuffle spec is provided, then the {@link FrameProcessorFactory} is expected to sort each output frame individually
- * according to {@link ShuffleSpec#getClusterBy()}. The execution system handles the rest, including sorting data across
+ * according to {@link ShuffleSpec#clusterBy()}. The execution system handles the rest, including sorting data across
  * frames and producing the appropriate output partitions.
  * <p>
  * The rarely-used parameter {@link #getShuffleCheckHasMultipleValues()} controls whether the execution system
@@ -128,7 +135,7 @@ public class StageDefinition
     this.maxInputBytesPerWorker = maxInputBytesPerWorker == null ?
                                   Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER : maxInputBytesPerWorker;
 
-    if (shuffleSpec != null && shuffleSpec.needsStatistics() && shuffleSpec.getClusterBy().getColumns().isEmpty()) {
+    if (mustGatherResultKeyStatistics() && shuffleSpec.clusterBy().getColumns().isEmpty()) {
       throw new IAE("Cannot shuffle with spec [%s] and nil clusterBy", shuffleSpec);
     }
 
@@ -157,7 +164,7 @@ public class StageDefinition
         .broadcastInputs(stageDef.getBroadcastInputNumbers())
         .processorFactory(stageDef.getProcessorFactory())
         .signature(stageDef.getSignature())
-        .shuffleSpec(stageDef.getShuffleSpec().orElse(null))
+        .shuffleSpec(stageDef.doesShuffle() ? stageDef.getShuffleSpec() : null)
         .maxWorkerCount(stageDef.getMaxWorkerCount())
         .shuffleCheckHasMultipleValues(stageDef.getShuffleCheckHasMultipleValues());
   }
@@ -212,16 +219,25 @@ public class StageDefinition
 
   public boolean doesSortDuringShuffle()
   {
-    if (shuffleSpec == null) {
+    if (shuffleSpec == null || shuffleSpec.clusterBy().isEmpty()) {
       return false;
     } else {
-      return !shuffleSpec.getClusterBy().getColumns().isEmpty() || shuffleSpec.needsStatistics();
+      return shuffleSpec.clusterBy().sortable();
     }
   }
 
-  public Optional<ShuffleSpec> getShuffleSpec()
+  /**
+   * Returns the {@link ShuffleSpec} for this stage, if {@link #doesShuffle()}.
+   *
+   * @throws IllegalStateException if this stage does not shuffle
+   */
+  public ShuffleSpec getShuffleSpec()
   {
-    return Optional.ofNullable(shuffleSpec);
+    if (shuffleSpec == null) {
+      throw new IllegalStateException("Stage does not shuffle");
+    }
+
+    return shuffleSpec;
   }
 
   /**
@@ -229,7 +245,25 @@ public class StageDefinition
    */
   public ClusterBy getClusterBy()
   {
-    return shuffleSpec != null ? shuffleSpec.getClusterBy() : ClusterBy.none();
+    if (shuffleSpec != null) {
+      return shuffleSpec.clusterBy();
+    } else {
+      return ClusterBy.none();
+    }
+  }
+
+  /**
+   * Returns the key used for sorting each individual partition, or an empty list if partitions are unsorted.
+   */
+  public List<KeyColumn> getSortKey()
+  {
+    final ClusterBy clusterBy = getClusterBy();
+
+    if (clusterBy.sortable()) {
+      return clusterBy.getColumns();
+    } else {
+      return Collections.emptyList();
+    }
   }
 
   @Nullable
@@ -285,40 +319,77 @@ public class StageDefinition
    */
   public boolean mustGatherResultKeyStatistics()
   {
-    return shuffleSpec != null && shuffleSpec.needsStatistics();
+    return shuffleSpec != null
+           && shuffleSpec.kind() == ShuffleKind.GLOBAL_SORT
+           && ((GlobalSortShuffleSpec) shuffleSpec).mustGatherResultKeyStatistics();
   }
 
-  public Either<Long, ClusterByPartitions> generatePartitionsForShuffle(
+  public Either<Long, ClusterByPartitions> generatePartitionBoundariesForShuffle(
       @Nullable ClusterByStatisticsCollector collector
   )
   {
     if (shuffleSpec == null) {
       throw new ISE("No shuffle for stage[%d]", getStageNumber());
+    } else if (shuffleSpec.kind() != ShuffleKind.GLOBAL_SORT) {
+      throw new ISE(
+          "Shuffle of kind [%s] cannot generate partition boundaries for stage[%d]",
+          shuffleSpec.kind(),
+          getStageNumber()
+      );
     } else if (mustGatherResultKeyStatistics() && collector == null) {
       throw new ISE("Statistics required, but not gathered for stage[%d]", getStageNumber());
     } else if (!mustGatherResultKeyStatistics() && collector != null) {
       throw new ISE("Statistics gathered, but not required for stage[%d]", getStageNumber());
     } else {
-      return shuffleSpec.generatePartitions(collector, MAX_PARTITIONS);
+      return ((GlobalSortShuffleSpec) shuffleSpec).generatePartitionsForGlobalSort(collector, MAX_PARTITIONS);
     }
   }
 
   public ClusterByStatisticsCollector createResultKeyStatisticsCollector(final int maxRetainedBytes)
   {
     if (!mustGatherResultKeyStatistics()) {
-      throw new ISE("No statistics needed");
+      throw new ISE("No statistics needed for stage[%d]", getStageNumber());
     }
 
     return ClusterByStatisticsCollectorImpl.create(
-        shuffleSpec.getClusterBy(),
+        shuffleSpec.clusterBy(),
         signature,
         maxRetainedBytes,
         PARTITION_STATS_MAX_BUCKETS,
-        shuffleSpec.doesAggregateByClusterKey(),
+        shuffleSpec.doesAggregate(),
         shuffleCheckHasMultipleValues
     );
   }
 
+  /**
+   * Create the {@link FrameWriterFactory} that must be used by {@link #getProcessorFactory()}.
+   *
+   * Calls {@link MemoryAllocatorFactory#newAllocator()} for each frame.
+   */
+  public FrameWriterFactory createFrameWriterFactory(final MemoryAllocatorFactory memoryAllocatorFactory)
+  {
+    return FrameWriters.makeFrameWriterFactory(
+        FrameType.ROW_BASED,
+        memoryAllocatorFactory,
+        signature,
+
+        // Main processor does not sort when there is a hash going on, even if isSort = true. This is because
+        // FrameChannelHashPartitioner is expected to be attached to the processor and do the sorting. We don't
+        // want to double-sort.
+        doesShuffle() && !shuffleSpec.kind().isHash() ? getClusterBy().getColumns() : Collections.emptyList()
+    );
+  }
+
+  /**
+   * Create the {@link FrameWriterFactory} that must be used by {@link #getProcessorFactory()}.
+   *
+   * Re-uses the same {@link MemoryAllocator} for each frame.
+   */
+  public FrameWriterFactory createFrameWriterFactory(final MemoryAllocator allocator)
+  {
+    return createFrameWriterFactory(new SingleMemoryAllocatorFactory(allocator));
+  }
+
   public FrameReader getFrameReader()
   {
     return frameReader.get();
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinitionBuilder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinitionBuilder.java
index 384b78fd33..3ff9594123 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinitionBuilder.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinitionBuilder.java
@@ -118,6 +118,11 @@ public class StageDefinitionBuilder
     return stageNumber;
   }
 
+  public RowSignature getSignature()
+  {
+    return signature;
+  }
+
   public StageDefinition build(final String queryId)
   {
     return new StageDefinition(
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
index 39b81b9081..1de56c529b 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
@@ -43,6 +43,9 @@ import org.apache.druid.msq.input.InputSpecSlicer;
 import org.apache.druid.msq.input.stage.ReadablePartition;
 import org.apache.druid.msq.input.stage.ReadablePartitions;
 import org.apache.druid.msq.input.stage.StageInputSlice;
+import org.apache.druid.msq.kernel.GlobalSortShuffleSpec;
+import org.apache.druid.msq.kernel.ShuffleKind;
+import org.apache.druid.msq.kernel.ShuffleSpec;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
@@ -556,7 +559,8 @@ class ControllerStageTracker
           if (workers.isEmpty()) {
             // generate partition boundaries since all work is finished for the time chunk
             ClusterByStatisticsCollector collector = timeChunkToCollector.get(tc);
-            Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(collector);
+            Either<Long, ClusterByPartitions> countOrPartitions =
+                stageDef.generatePartitionBoundariesForShuffle(collector);
             totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
             if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
               failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
@@ -689,8 +693,8 @@ class ControllerStageTracker
         );
       }
       if (resultPartitions == null) {
-        Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(timeChunkToCollector.get(
-            STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE));
+        final ClusterByStatisticsCollector collector = timeChunkToCollector.get(STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE);
+        Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionBoundariesForShuffle(collector);
         totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
         if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
           failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
@@ -840,9 +844,10 @@ class ControllerStageTracker
   }
 
   /**
-   * Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} without using key statistics.
-   * <p>
-   * If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method should not be called.
+   * Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} (if doing a global sort) without
+   * using key statistics. Called by the constructor.
+   *
+   * If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method must not be called.
    */
   private void generateResultPartitionsAndBoundariesWithoutKeyStatistics()
   {
@@ -856,24 +861,31 @@ class ControllerStageTracker
     final int stageNumber = stageDef.getStageNumber();
 
     if (stageDef.doesShuffle()) {
-      if (stageDef.mustGatherResultKeyStatistics() && !allPartialKeyInformationFetched()) {
-        throw new ISE("Cannot generate result partitions without all worker key statistics");
-      }
+      final ShuffleSpec shuffleSpec = stageDef.getShuffleSpec();
 
-      final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
-          stageDef.generatePartitionsForShuffle(null);
+      if (shuffleSpec.kind() == ShuffleKind.GLOBAL_SORT) {
+        if (((GlobalSortShuffleSpec) shuffleSpec).mustGatherResultKeyStatistics()
+            && !allPartialKeyInformationFetched()) {
+          throw new ISE("Cannot generate result partitions without all worker key statistics");
+        }
 
-      if (maybeResultPartitionBoundaries.isError()) {
-        failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
-        return;
-      }
+        final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
+            stageDef.generatePartitionBoundariesForShuffle(null);
 
-      resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
-      resultPartitions = ReadablePartitions.striped(
-          stageNumber,
-          workerCount,
-          resultPartitionBoundaries.size()
-      );
+        if (maybeResultPartitionBoundaries.isError()) {
+          failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
+          return;
+        }
+
+        resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
+        resultPartitions = ReadablePartitions.striped(
+            stageNumber,
+            workerCount,
+            resultPartitionBoundaries.size()
+        );
+      } else {
+        resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
+      }
     } else {
       // No reshuffling: retain partitioning from nonbroadcast inputs.
       final Int2IntSortedMap partitionToWorkerMap = new Int2IntAVLTreeMap();
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
index 5ff3398054..632b8a8106 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
@@ -25,6 +25,7 @@ import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.msq.kernel.ShuffleKind;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.WorkOrder;
@@ -70,11 +71,12 @@ public class WorkerStageKernel
     this.workOrder = workOrder;
 
     if (workOrder.getStageDefinition().doesShuffle()
+        && workOrder.getStageDefinition().getShuffleSpec().kind() == ShuffleKind.GLOBAL_SORT
         && !workOrder.getStageDefinition().mustGatherResultKeyStatistics()) {
       // Use valueOrThrow instead of a nicer error collection mechanism, because we really don't expect the
       // MAX_PARTITIONS to be exceeded here. It would involve having a shuffleSpec that was statically configured
       // to use a huge number of partitions.
-      resultPartitionBoundaries = workOrder.getStageDefinition().generatePartitionsForShuffle(null).valueOrThrow();
+      resultPartitionBoundaries = workOrder.getStageDefinition().generatePartitionBoundariesForShuffle(null).valueOrThrow();
     }
   }
 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
index ef1275355c..ea913d184d 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java
@@ -23,15 +23,14 @@ import com.google.common.collect.Iterators;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.ReadableConcatFrameChannel;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.OutputChannel;
 import org.apache.druid.frame.processor.OutputChannelFactory;
 import org.apache.druid.frame.processor.OutputChannels;
+import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.guava.Sequence;
@@ -48,7 +47,6 @@ import org.apache.druid.msq.input.stage.StageInputSlice;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.kernel.ProcessorsAndChannels;
 import org.apache.druid.msq.kernel.StageDefinition;
-import org.apache.druid.segment.column.RowSignature;
 
 import javax.annotation.Nullable;
 import java.io.IOException;
@@ -104,7 +102,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFa
       outstandingProcessors = Math.min(totalProcessors, maxOutstandingProcessors);
     }
 
-    final AtomicReference<Queue<MemoryAllocator>> allocatorQueueRef =
+    final AtomicReference<Queue<FrameWriterFactory>> frameWriterFactoryQueueRef =
         new AtomicReference<>(new ArrayDeque<>(outstandingProcessors));
     final AtomicReference<Queue<WritableFrameChannel>> channelQueueRef =
         new AtomicReference<>(new ArrayDeque<>(outstandingProcessors));
@@ -114,7 +112,9 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFa
       final OutputChannel outputChannel = outputChannelFactory.openChannel(0 /* Partition number doesn't matter */);
       outputChannels.add(outputChannel);
       channelQueueRef.get().add(outputChannel.getWritableChannel());
-      allocatorQueueRef.get().add(outputChannel.getFrameMemoryAllocator());
+      frameWriterFactoryQueueRef.get().add(
+          stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator())
+      );
     }
 
     // Read all base inputs in separate processors, one per processor.
@@ -147,9 +147,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFa
                     }
                   }
               ),
-              makeLazyResourceHolder(allocatorQueueRef, ignored -> {}),
-              stageDefinition.getSignature(),
-              stageDefinition.getClusterBy(),
+              makeLazyResourceHolder(frameWriterFactoryQueueRef, ignored -> {}),
               frameContext
           );
         }
@@ -257,9 +255,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFa
       ReadableInput baseInput,
       Int2ObjectMap<ReadableInput> sideChannels,
       ResourceHolder<WritableFrameChannel> outputChannelSupplier,
-      ResourceHolder<MemoryAllocator> allocatorSupplier,
-      RowSignature signature,
-      ClusterBy clusterBy,
+      ResourceHolder<FrameWriterFactory> frameWriterFactory,
       FrameContext providerThingy
   );
 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
index 6868fef56f..c174c4ab82 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
@@ -23,11 +23,15 @@ import com.fasterxml.jackson.core.JsonFactory;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.ints.IntSets;
 import org.apache.druid.data.input.impl.InlineInputSource;
 import org.apache.druid.data.input.impl.JsonInputFormat;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.KeyColumn;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.UOE;
@@ -36,11 +40,16 @@ import org.apache.druid.msq.input.NilInputSource;
 import org.apache.druid.msq.input.external.ExternalInputSpec;
 import org.apache.druid.msq.input.stage.StageInputSpec;
 import org.apache.druid.msq.input.table.TableInputSpec;
+import org.apache.druid.msq.kernel.HashShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.kernel.StageDefinitionBuilder;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.QueryContext;
 import org.apache.druid.query.QueryDataSource;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.filter.DimFilter;
@@ -52,6 +61,8 @@ import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.sql.calcite.external.ExternalDataSource;
 import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.joda.time.Interval;
 
 import javax.annotation.Nullable;
@@ -63,9 +74,6 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
-/**
- * Used by {@link QueryKit} implementations to produce {@link InputSpec} from native {@link DataSource}.
- */
 public class DataSourcePlan
 {
   /**
@@ -108,6 +116,7 @@ public class DataSourcePlan
   public static DataSourcePlan forDataSource(
       final QueryKit queryKit,
       final String queryId,
+      final QueryContext queryContext,
       final DataSource dataSource,
       final QuerySegmentSpec querySegmentSpec,
       @Nullable DimFilter filter,
@@ -135,15 +144,35 @@ public class DataSourcePlan
           broadcast
       );
     } else if (dataSource instanceof JoinDataSource) {
-      return forJoin(
-          queryKit,
-          queryId,
-          (JoinDataSource) dataSource,
-          querySegmentSpec,
-          maxWorkerCount,
-          minStageNumber,
-          broadcast
-      );
+      final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+
+      switch (joinAlgorithm) {
+        case BROADCAST:
+          return forBroadcastHashJoin(
+              queryKit,
+              queryId,
+              queryContext,
+              (JoinDataSource) dataSource,
+              querySegmentSpec,
+              maxWorkerCount,
+              minStageNumber,
+              broadcast
+          );
+
+        case SORT_MERGE:
+          return forSortMergeJoin(
+              queryKit,
+              queryId,
+              (JoinDataSource) dataSource,
+              querySegmentSpec,
+              maxWorkerCount,
+              minStageNumber,
+              broadcast
+          );
+
+        default:
+          throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm);
+      }
     } else {
       throw new UOE("Cannot handle dataSource [%s]", dataSource);
     }
@@ -263,7 +292,7 @@ public class DataSourcePlan
         // outermost query, and setting it for the subquery makes us erroneously add bucketing where it doesn't belong.
         dataSource.getQuery().withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY),
         queryKit,
-        ShuffleSpecFactories.subQueryWithMaxWorkerCount(maxWorkerCount),
+        ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount),
         maxWorkerCount,
         minStageNumber
     );
@@ -278,9 +307,13 @@ public class DataSourcePlan
     );
   }
 
-  private static DataSourcePlan forJoin(
+  /**
+   * Build a plan for broadcast hash-join.
+   */
+  private static DataSourcePlan forBroadcastHashJoin(
       final QueryKit queryKit,
       final String queryId,
+      final QueryContext queryContext,
       final JoinDataSource dataSource,
       final QuerySegmentSpec querySegmentSpec,
       final int maxWorkerCount,
@@ -294,11 +327,13 @@ public class DataSourcePlan
     final DataSourcePlan basePlan = forDataSource(
         queryKit,
         queryId,
+        queryContext,
         analysis.getBaseDataSource(),
         querySegmentSpec,
         null, // Don't push query filters down through a join: this needs some work to ensure pruning works properly.
         maxWorkerCount,
         Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
+
         broadcast
     );
 
@@ -312,6 +347,7 @@ public class DataSourcePlan
       final DataSourcePlan clausePlan = forDataSource(
           queryKit,
           queryId,
+          queryContext,
           clause.getDataSource(),
           new MultipleIntervalSegmentSpec(Intervals.ONLY_ETERNITY),
           null, // Don't push query filters down through a join: this needs some work to ensure pruning works properly.
@@ -341,6 +377,117 @@ public class DataSourcePlan
     return new DataSourcePlan(newDataSource, inputSpecs, broadcastInputs, subQueryDefBuilder);
   }
 
+  /**
+   * Build a plan for sort-merge join.
+   */
+  private static DataSourcePlan forSortMergeJoin(
+      final QueryKit queryKit,
+      final String queryId,
+      final JoinDataSource dataSource,
+      final QuerySegmentSpec querySegmentSpec,
+      final int maxWorkerCount,
+      final int minStageNumber,
+      final boolean broadcast
+  )
+  {
+    checkQuerySegmentSpecIsEternity(dataSource, querySegmentSpec);
+    SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis());
+
+    // Partition by keys given by the join condition.
+    final List<List<KeyColumn>> partitionKeys = SortMergeJoinFrameProcessorFactory.toKeyColumns(
+        SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis())
+    );
+
+    final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder();
+
+    // Plan the left input.
+    // We're confident that we can cast dataSource.getLeft() to QueryDataSource, because DruidJoinQueryRel creates
+    // subqueries when the join algorithm is sortMerge.
+    final DataSourcePlan leftPlan = forQuery(
+        queryKit,
+        queryId,
+        (QueryDataSource) dataSource.getLeft(),
+        maxWorkerCount,
+        Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
+        false
+    );
+    leftPlan.getSubQueryDefBuilder().ifPresent(subQueryDefBuilder::addAll);
+
+    // Plan the right input.
+    // We're confident that we can cast dataSource.getRight() to QueryDataSource, because DruidJoinQueryRel creates
+    // subqueries when the join algorithm is sortMerge.
+    final DataSourcePlan rightPlan = forQuery(
+        queryKit,
+        queryId,
+        (QueryDataSource) dataSource.getRight(),
+        maxWorkerCount,
+        Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
+        false
+    );
+    rightPlan.getSubQueryDefBuilder().ifPresent(subQueryDefBuilder::addAll);
+
+    // Build up the left stage.
+    final StageDefinitionBuilder leftBuilder = subQueryDefBuilder.getStageBuilder(
+        ((StageInputSpec) Iterables.getOnlyElement(leftPlan.getInputSpecs())).getStageNumber()
+    );
+
+    final List<KeyColumn> leftPartitionKey = partitionKeys.get(0);
+    leftBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(leftPartitionKey, 0), maxWorkerCount));
+    leftBuilder.signature(QueryKitUtils.sortableSignature(leftBuilder.getSignature(), leftPartitionKey));
+
+    // Build up the right stage.
+    final StageDefinitionBuilder rightBuilder = subQueryDefBuilder.getStageBuilder(
+        ((StageInputSpec) Iterables.getOnlyElement(rightPlan.getInputSpecs())).getStageNumber()
+    );
+
+    final List<KeyColumn> rightPartitionKey = partitionKeys.get(1);
+    rightBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(rightPartitionKey, 0), maxWorkerCount));
+    rightBuilder.signature(QueryKitUtils.sortableSignature(rightBuilder.getSignature(), rightPartitionKey));
+
+    // Compute join signature.
+    final RowSignature.Builder joinSignatureBuilder = RowSignature.builder();
+
+    for (String leftColumn : leftBuilder.getSignature().getColumnNames()) {
+      joinSignatureBuilder.add(leftColumn, leftBuilder.getSignature().getColumnType(leftColumn).orElse(null));
+    }
+
+    for (String rightColumn : rightBuilder.getSignature().getColumnNames()) {
+      joinSignatureBuilder.add(
+          dataSource.getRightPrefix() + rightColumn,
+          rightBuilder.getSignature().getColumnType(rightColumn).orElse(null)
+      );
+    }
+
+    // Build up the join stage.
+    final int stageNumber = Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber());
+
+    subQueryDefBuilder.add(
+        StageDefinition.builder(stageNumber)
+                       .inputs(
+                           ImmutableList.of(
+                               Iterables.getOnlyElement(leftPlan.getInputSpecs()),
+                               Iterables.getOnlyElement(rightPlan.getInputSpecs())
+                           )
+                       )
+                       .maxWorkerCount(maxWorkerCount)
+                       .signature(joinSignatureBuilder.build())
+                       .processorFactory(
+                           new SortMergeJoinFrameProcessorFactory(
+                               dataSource.getRightPrefix(),
+                               dataSource.getConditionAnalysis(),
+                               dataSource.getJoinType()
+                           )
+                       )
+    );
+
+    return new DataSourcePlan(
+        new InputNumberDataSource(0),
+        Collections.singletonList(new StageInputSpec(stageNumber)),
+        broadcast ? IntOpenHashSet.of(0) : IntSets.emptySet(),
+        subQueryDefBuilder
+    );
+  }
+
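[Editor's note, not part of the patch] To make the join-signature rule in forSortMergeJoin above concrete: left-side columns keep their names, while right-side columns are prefixed with the JoinDataSource's right prefix. A minimal, hypothetical sketch using Druid's RowSignature API; the column names and the "j0." prefix are invented for illustration:

    import org.apache.druid.segment.column.ColumnType;
    import org.apache.druid.segment.column.RowSignature;

    public class JoinSignatureExample
    {
      public static void main(String[] args)
      {
        // Left stage signature: k (string), v (long). Right stage signature: k (string), w (double).
        final RowSignature left = RowSignature.builder()
                                              .add("k", ColumnType.STRING)
                                              .add("v", ColumnType.LONG)
                                              .build();
        final RowSignature right = RowSignature.builder()
                                               .add("k", ColumnType.STRING)
                                               .add("w", ColumnType.DOUBLE)
                                               .build();
        final String rightPrefix = "j0."; // hypothetical; the real prefix comes from JoinDataSource#getRightPrefix

        // Same combination rule as forSortMergeJoin: left columns as-is, right columns prefixed.
        final RowSignature.Builder joined = RowSignature.builder();
        for (String column : left.getColumnNames()) {
          joined.add(column, left.getColumnType(column).orElse(null));
        }
        for (String column : right.getColumnNames()) {
          joined.add(rightPrefix + column, right.getColumnType(column).orElse(null));
        }

        // Prints: [k, v, j0.k, j0.w]
        System.out.println(joined.build().getColumnNames());
      }
    }
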
   private static DataSource shiftInputNumbers(final DataSource dataSource, final int shift)
   {
     if (shift < 0) {
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
index 1f863a8c73..bcd3a30df6 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
@@ -23,7 +23,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.calcite.sql.dialect.CalciteSqlDialect;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
@@ -107,8 +108,8 @@ public class QueryKitUtils
     if (Granularities.ALL.equals(segmentGranularity)) {
       return clusterBy;
     } else {
-      final List<SortColumn> newColumns = new ArrayList<>(clusterBy.getColumns().size() + 1);
-      newColumns.add(new SortColumn(QueryKitUtils.SEGMENT_GRANULARITY_COLUMN, false));
+      final List<KeyColumn> newColumns = new ArrayList<>(clusterBy.getColumns().size() + 1);
+      newColumns.add(new KeyColumn(QueryKitUtils.SEGMENT_GRANULARITY_COLUMN, KeyOrder.ASCENDING));
       newColumns.addAll(clusterBy.getColumns());
       return new ClusterBy(newColumns, 1);
     }
@@ -153,12 +154,12 @@ public class QueryKitUtils
    */
   public static RowSignature sortableSignature(
       final RowSignature signature,
-      final List<SortColumn> clusterByColumns
+      final List<KeyColumn> clusterByColumns
   )
   {
     final RowSignature.Builder builder = RowSignature.builder();
 
-    for (final SortColumn columnName : clusterByColumns) {
+    for (final KeyColumn columnName : clusterByColumns) {
       final Optional<ColumnType> columnType = signature.getColumnType(columnName.columnName());
       if (!columnType.isPresent()) {
         throw new IAE("Column [%s] not present in signature", columnName);
@@ -168,7 +169,7 @@ public class QueryKitUtils
     }
 
     final Set<String> clusterByColumnNames =
-        clusterByColumns.stream().map(SortColumn::columnName).collect(Collectors.toSet());
+        clusterByColumns.stream().map(KeyColumn::columnName).collect(Collectors.toSet());
 
     for (int i = 0; i < signature.size(); i++) {
       final String columnName = signature.getColumnName(i);
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java
index 848567b2fc..971aa9b7e0 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java
@@ -19,7 +19,8 @@
 
 package org.apache.druid.msq.querykit;
 
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.MixShuffleSpec;
 
 /**
  * Static factory methods for common implementations of {@link ShuffleSpecFactory}.
@@ -32,20 +33,24 @@ public class ShuffleSpecFactories
   }
 
   /**
-   * Factory that produces a single output partition.
+   * Factory that produces a single output partition, which may or may not be sorted.
    */
   public static ShuffleSpecFactory singlePartition()
   {
-    return (clusterBy, aggregate) ->
-        new MaxCountShuffleSpec(clusterBy, 1, aggregate);
+    return (clusterBy, aggregate) -> {
+      if (clusterBy.sortable() && !clusterBy.isEmpty()) {
+        return new GlobalSortMaxCountShuffleSpec(clusterBy, 1, aggregate);
+      } else {
+        return MixShuffleSpec.instance();
+      }
+    };
   }
 
   /**
    * Factory that produces a particular number of output partitions.
    */
-  public static ShuffleSpecFactory subQueryWithMaxWorkerCount(final int maxWorkerCount)
+  public static ShuffleSpecFactory globalSortWithMaxPartitionCount(final int partitions)
   {
-    return (clusterBy, aggregate) ->
-        new MaxCountShuffleSpec(clusterBy, maxWorkerCount, aggregate);
+    return (clusterBy, aggregate) -> new GlobalSortMaxCountShuffleSpec(clusterBy, partitions, aggregate);
   }
 }
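[Editor's note, not part of the patch] A usage sketch for the two factories above; the key name "dim" is invented, and the concrete ShuffleSpec implementations returned follow from the code shown in the diff:

    import com.google.common.collect.ImmutableList;
    import org.apache.druid.frame.key.ClusterBy;
    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;
    import org.apache.druid.msq.kernel.ShuffleSpec;
    import org.apache.druid.msq.querykit.ShuffleSpecFactories;

    public class ShuffleSpecFactoriesExample
    {
      public static void main(String[] args)
      {
        // Sortable, non-empty key: singlePartition() produces a global sort into one partition.
        final ClusterBy sortedKey =
            new ClusterBy(ImmutableList.of(new KeyColumn("dim", KeyOrder.ASCENDING)), 0);
        final ShuffleSpec sortToOne = ShuffleSpecFactories.singlePartition().build(sortedKey, false);

        // Empty key: nothing to sort, so singlePartition() falls back to the mix shuffle.
        final ClusterBy emptyKey = new ClusterBy(ImmutableList.of(), 0);
        final ShuffleSpec mix = ShuffleSpecFactories.singlePartition().build(emptyKey, false);

        // A fixed cap on globally sorted partitions, as used for subquery outputs.
        final ShuffleSpec sortToFour =
            ShuffleSpecFactories.globalSortWithMaxPartitionCount(4).build(sortedKey, false);
      }
    }
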
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
index 6b06276145..b9c0f1a0d2 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java
@@ -29,7 +29,7 @@ public interface ShuffleSpecFactory
 {
   /**
    * Build a {@link ShuffleSpec} for given {@link ClusterBy}. The {@code aggregate} flag is used to populate
-   * {@link ShuffleSpec#doesAggregateByClusterKey()}.
+   * {@link ShuffleSpec#doesAggregate()}.
    */
   ShuffleSpec build(ClusterBy clusterBy, boolean aggregate);
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessor.java
index a885ee1faa..aa2e054ec9 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessor.java
@@ -21,8 +21,6 @@ package org.apache.druid.msq.querykit.common;
 
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.HeapMemoryAllocator;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
@@ -33,7 +31,6 @@ import org.apache.druid.frame.processor.ReturnOrAwait;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.segment.Cursor;
 
@@ -47,8 +44,10 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
   private final ReadableFrameChannel inputChannel;
   private final WritableFrameChannel outputChannel;
   private final FrameReader frameReader;
+  private final FrameWriterFactory frameWriterFactory;
   private final long offset;
   private final long limit;
+  private final boolean inputSignatureMatchesOutputSignature;
 
   long rowsProcessedSoFar = 0L;
 
@@ -56,6 +55,7 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
       ReadableFrameChannel inputChannel,
       WritableFrameChannel outputChannel,
       FrameReader frameReader,
+      FrameWriterFactory frameWriterFactory,
       long offset,
       long limit
   )
@@ -63,8 +63,10 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
     this.inputChannel = inputChannel;
     this.outputChannel = outputChannel;
     this.frameReader = frameReader;
+    this.frameWriterFactory = frameWriterFactory;
     this.offset = offset;
     this.limit = limit;
+    this.inputSignatureMatchesOutputSignature = frameReader.signature().equals(frameWriterFactory.signature());
 
     if (offset < 0 || limit < 0) {
       throw new ISE("Offset and limit must be nonnegative");
@@ -130,31 +132,25 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
       // Offset is past the end of the frame; skip it.
       rowsProcessedSoFar += frame.numRows();
       return null;
-    } else if (startRow == 0 && endRow == frame.numRows()) {
+    } else if (startRow == 0
+               && endRow == frame.numRows()
+               && inputSignatureMatchesOutputSignature
+               && frameWriterFactory.frameType().equals(frame.type())) {
+      // Want the whole frame; emit it as-is.
       rowsProcessedSoFar += frame.numRows();
       return frame;
     }
 
     final Cursor cursor = FrameProcessors.makeCursor(frame, frameReader);
 
-    // Using an unlimited memory allocator to make sure that atleast a single frame can always be generated
-    final HeapMemoryAllocator unlimitedAllocator = HeapMemoryAllocator.unlimited();
-
     long rowsProcessedSoFarInFrame = 0;
 
-    final FrameWriterFactory frameWriterFactory = FrameWriters.makeFrameWriterFactory(
-        FrameType.ROW_BASED,
-        unlimitedAllocator,
-        frameReader.signature(),
-        Collections.emptyList()
-    );
-
     try (final FrameWriter frameWriter = frameWriterFactory.newFrameWriter(cursor.getColumnSelectorFactory())) {
       while (!cursor.isDone() && rowsProcessedSoFarInFrame < endRow) {
         if (rowsProcessedSoFarInFrame >= startRow && !frameWriter.addSelection()) {
           // Don't retry; it can't work because the allocator is unlimited anyway.
           // Also, I don't think this line can be reached, because the allocator is unlimited.
-          throw new FrameRowTooLargeException(unlimitedAllocator.capacity());
+          throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
         }
 
         cursor.advance();
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessorFactory.java
index ecf03a08c3..134c9ddd4e 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitFrameProcessorFactory.java
@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
+import org.apache.druid.frame.allocation.HeapMemoryAllocator;
 import org.apache.druid.frame.channel.ReadableConcatFrameChannel;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.OutputChannel;
@@ -122,10 +123,12 @@ public class OffsetLimitFrameProcessorFactory extends BaseFrameProcessorFactory
       }
 
       // Note: OffsetLimitFrameProcessor does not use allocator from the outputChannel; it uses unlimited instead.
+      // This ensures that a single, limited output frame can always be generated from an input frame.
       return new OffsetLimitFrameProcessor(
           ReadableConcatFrameChannel.open(Iterators.transform(readableInputs.iterator(), ReadableInput::getChannel)),
           outputChannel.getWritableChannel(),
           readableInputs.frameReader(),
+          stageDefinition.createFrameWriterFactory(HeapMemoryAllocator.unlimited()),
           offset,
           // Limit processor will add limit + offset at various points; must avoid overflow
           limit == null ? Long.MAX_VALUE - offset : limit
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
new file mode 100644
index 0000000000..56e54842c9
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
@@ -0,0 +1,1075 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.common;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
+import com.google.common.collect.ImmutableList;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.channel.FrameWithPartition;
+import org.apache.druid.frame.channel.ReadableFrameChannel;
+import org.apache.druid.frame.channel.WritableFrameChannel;
+import org.apache.druid.frame.key.FrameComparisonWidget;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.frame.key.RowKeyReader;
+import org.apache.druid.frame.processor.FrameProcessor;
+import org.apache.druid.frame.processor.FrameProcessors;
+import org.apache.druid.frame.processor.FrameRowTooLargeException;
+import org.apache.druid.frame.processor.ReturnOrAwait;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.segment.FrameCursor;
+import org.apache.druid.frame.write.FrameWriter;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.exec.Limits;
+import org.apache.druid.msq.indexing.error.MSQException;
+import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
+import org.apache.druid.msq.input.ReadableInput;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.query.filter.ValueMatcher;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.Cursor;
+import org.apache.druid.segment.DimensionSelector;
+import org.apache.druid.segment.DimensionSelectorUtils;
+import org.apache.druid.segment.IdLookup;
+import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.data.IndexedInts;
+import org.apache.druid.segment.data.ZeroIndexedInts;
+import org.apache.druid.segment.join.JoinPrefixUtils;
+import org.apache.druid.segment.join.JoinType;
+
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Processor for a sort-merge join of two inputs.
+ *
+ * Prerequisites:
+ *
+ * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}.
+ *
+ * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition}
+ * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}.
+ *
+ * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same
+ * key, and that key can be used to check the provided condition. Validated by
+ * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}.
+ *
+ * Algorithm:
+ *
+ * 1) Read current key from each side of the join.
+ *
+ * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type.
+ *
+ * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the
+ * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()}
+ * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in
+ * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the
+ * left side.
+ *
+ * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire
+ * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows.
+ *
+ * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and
+ * continue the algorithm.
+ */
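[Editor's note, not part of the patch] The steps above can be modeled, very loosely, with plain sorted arrays in place of frames, trackers, and marks. The following self-contained sketch is illustrative only: an INNER join over int keys showing the mark / complete-set / rewind idea from steps 3-5. It is not the processor's code.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class SortMergeSketch
    {
      /**
       * INNER join of two ascending-sorted int arrays; returns matched [leftKey, rightKey] pairs.
       * Duplicate keys on both sides produce the full cross product for that key, mirroring the
       * "complete set" / "rewind to mark" behavior described above.
       */
      static List<int[]> innerJoin(int[] left, int[] right)
      {
        final List<int[]> out = new ArrayList<>();
        int l = 0;
        int r = 0;

        while (l < left.length && r < right.length) {
          if (left[l] < right[r]) {
            l++; // No match: skip the earlier key (a LEFT/FULL join would emit it instead).
          } else if (left[l] > right[r]) {
            r++; // No match: skip the earlier key (a RIGHT/FULL join would emit it instead).
          } else {
            // Match: the right-side run [mark, rEnd) is the "complete set" for this key.
            final int key = left[l];
            final int mark = r;
            int rEnd = r;
            while (rEnd < right.length && right[rEnd] == key) {
              rEnd++;
            }
            // Replay the complete set for every left-side row with the same key ("rewind to mark").
            while (l < left.length && left[l] == key) {
              for (int i = mark; i < rEnd; i++) {
                out.add(new int[]{left[l], right[i]});
              }
              l++;
            }
            r = rEnd;
          }
        }
        return out;
      }

      public static void main(String[] args)
      {
        // Expected output: [1, 1], [3, 3], [3, 3] (key 3 appears twice on the right).
        for (int[] pair : innerJoin(new int[]{1, 2, 3}, new int[]{1, 3, 3, 4})) {
          System.out.println(Arrays.toString(pair));
        }
      }
    }
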
+public class SortMergeJoinFrameProcessor implements FrameProcessor<Long>
+{
+  private static final int LEFT = 0;
+  private static final int RIGHT = 1;
+
+  /**
+   * Input channels for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}.
+   */
+  private final List<ReadableFrameChannel> inputChannels;
+
+  /**
+   * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}.
+   */
+  private final List<Tracker> trackers;
+
+  private final WritableFrameChannel outputChannel;
+  private final FrameWriterFactory frameWriterFactory;
+  private final String rightPrefix;
+  private final JoinType joinType;
+  private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory();
+  private FrameWriter frameWriter = null;
+
+  // Used by runIncrementally to defer certain logic to the next run.
+  private Runnable nextIterationRunnable = null;
+
+  // Used by runIncrementally to remember which tracker has the complete set for the current key.
+  private int trackerWithCompleteSetForCurrentKey = -1;
+
+  SortMergeJoinFrameProcessor(
+      ReadableInput left,
+      ReadableInput right,
+      WritableFrameChannel outputChannel,
+      FrameWriterFactory frameWriterFactory,
+      String rightPrefix,
+      List<List<KeyColumn>> keyColumns,
+      JoinType joinType
+  )
+  {
+    this.inputChannels = ImmutableList.of(left.getChannel(), right.getChannel());
+    this.outputChannel = outputChannel;
+    this.frameWriterFactory = frameWriterFactory;
+    this.rightPrefix = rightPrefix;
+    this.joinType = joinType;
+    this.trackers = ImmutableList.of(
+        new Tracker(left, keyColumns.get(LEFT)),
+        new Tracker(right, keyColumns.get(RIGHT))
+    );
+  }
+
+  @Override
+  public List<ReadableFrameChannel> inputChannels()
+  {
+    return inputChannels;
+  }
+
+  @Override
+  public List<WritableFrameChannel> outputChannels()
+  {
+    return Collections.singletonList(outputChannel);
+  }
+
+  @Override
+  public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException
+  {
+    // Fetch enough frames such that each tracker has one readable row.
+    for (int i = 0; i < inputChannels.size(); i++) {
+      final Tracker tracker = trackers.get(i);
+      if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) {
+        return nextAwait();
+      }
+    }
+
+    // Initialize new output frame, if needed.
+    startNewFrameIfNeeded();
+
+    while (!allTrackersAreAtEnd()
+           && !trackers.get(LEFT).needsMoreData()
+           && !trackers.get(RIGHT).needsMoreData()) {
+      // Algorithm can proceed: not all trackers are at the end of their streams, and no tracker needs more data to
+      // read the current cursor or move it forward.
+      if (nextIterationRunnable != null) {
+        final Runnable tmp = nextIterationRunnable;
+        nextIterationRunnable = null;
+        tmp.run();
+      }
+
+      final int markCmp = compareMarks();
+
+      // Two rows match if the keys compare equal _and_ neither key has a null component. (x JOIN y ON x.a = y.a does
+      // not match rows where "x.a" is null.)
+      final boolean match = markCmp == 0 && trackers.get(LEFT).hasCompletelyNonNullMark();
+
+      // If marked keys are equal on both sides ("match"), at least one side must have a complete set of rows
+      // for the marked key.
+      if (match && trackerWithCompleteSetForCurrentKey < 0) {
+        for (int i = 0; i < inputChannels.size(); i++) {
+          final Tracker tracker = trackers.get(i);
+
+          // Fetch up to one frame from each tracker, to check if that tracker has a complete set.
+          // Can't fetch more than one frame, because channels are only guaranteed to have one frame per run.
+          if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) {
+            trackerWithCompleteSetForCurrentKey = i;
+            break;
+          }
+        }
+
+        if (trackerWithCompleteSetForCurrentKey < 0) {
+          // Algorithm cannot proceed; fetch more frames on the next run.
+          return nextAwait();
+        }
+      }
+
+      if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) {
+        // Emit row, if there's room in the current frameWriter.
+        joinColumnSelectorFactory.cmp = markCmp;
+        joinColumnSelectorFactory.match = match;
+
+        if (!frameWriter.addSelection()) {
+          if (frameWriter.getNumRows() > 0) {
+            // Out of space in the current frame. Run again without moving cursors.
+            flushCurrentFrame();
+            return ReturnOrAwait.runAgain();
+          } else {
+            throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+          }
+        }
+      }
+
+      // Advance one or both trackers.
+      if (match) {
+        // Matching keys. First advance the tracker with the complete set.
+        final Tracker tracker = trackers.get(trackerWithCompleteSetForCurrentKey);
+        final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT);
+
+        tracker.advance();
+        if (!tracker.isCurrentSameKeyAsMark()) {
+          // Reached end of complete set. Advance the other tracker.
+          otherTracker.advance();
+
+          // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks
+          // of both, as appropriate.
+          onNextIteration(() -> {
+            if (otherTracker.isCurrentSameKeyAsMark()) {
+              otherTracker.markCurrent(); // Set mark to enable cleanup of old frames.
+              tracker.rewindToMark();
+            } else {
+              // Reached end of the other side too. Advance marks on both trackers.
+              tracker.markCurrent();
+              otherTracker.markCurrent();
+              trackerWithCompleteSetForCurrentKey = -1;
+            }
+          });
+        }
+      } else {
+        final int trackerToAdvance;
+
+        if (markCmp < 0) {
+          trackerToAdvance = LEFT;
+        } else if (markCmp > 0) {
+          trackerToAdvance = RIGHT;
+        } else {
+          // Key is null on both sides. Note that there is a preference for running through the left side first
+          // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent
+          // for the benefit of the logic in "shouldEmitColumnValue".
+          trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT;
+        }
+
+        final Tracker tracker = trackers.get(trackerToAdvance);
+
+        tracker.advance();
+
+        // On next iteration (when we're sure to have data), update mark if the key changed.
+        onNextIteration(() -> {
+          if (!tracker.isCurrentSameKeyAsMark()) {
+            tracker.markCurrent();
+            trackerWithCompleteSetForCurrentKey = -1;
+          }
+        });
+      }
+    }
+
+    if (allTrackersAreAtEnd()) {
+      flushCurrentFrame();
+      return ReturnOrAwait.returnObject(0L);
+    } else {
+      // Keep reading.
+      return nextAwait();
+    }
+  }
+
+  @Override
+  public void cleanup() throws IOException
+  {
+    FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, () -> trackers.forEach(Tracker::clear));
+  }
+
+  /**
+   * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that need more data and have not yet hit their
+   * buffered-bytes limit, {@link Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}.
+   *
+   * If all channels have hit their limit, throws {@link MSQException} with {@link TooManyRowsWithSameKeyFault}.
+   */
+  private ReturnOrAwait<Long> nextAwait()
+  {
+    final IntSet awaitSet = new IntOpenHashSet();
+    int trackerAtLimit = -1;
+
+    for (int i = 0; i < inputChannels.size(); i++) {
+      final Tracker tracker = trackers.get(i);
+      if (tracker.needsMoreData()) {
+        if (tracker.totalBytesBuffered() < Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) {
+          awaitSet.add(i);
+        } else if (trackerAtLimit < 0) {
+          trackerAtLimit = i;
+        }
+      }
+    }
+
+    if (awaitSet.isEmpty() && trackerAtLimit >= 0) {
+      // All trackers that need more data are at their max buffered bytes limit. Generate a nice exception.
+      final Tracker tracker = trackers.get(trackerAtLimit);
+      if (tracker.totalBytesBuffered() > Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) {
+        // Generate a nice exception.
+        throw new MSQException(
+            new TooManyRowsWithSameKeyFault(
+                tracker.readMarkKey(),
+                tracker.totalBytesBuffered(),
+                Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN
+            )
+        );
+      }
+    }
+
+    return ReturnOrAwait.awaitAll(awaitSet);
+  }
+
+  /**
+   * Whether all trackers return true from {@link Tracker#isAtEnd()}.
+   */
+  private boolean allTrackersAreAtEnd()
+  {
+    for (Tracker tracker : trackers) {
+      if (!tracker.isAtEnd()) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  /**
+   * Compares the marked rows of the two {@link #trackers}.
+   *
+   * @throws IllegalStateException if either tracker does not have a marked row and is not completely done
+   */
+  private int compareMarks()
+  {
+    final Tracker leftTracker = trackers.get(LEFT);
+    final Tracker rightTracker = trackers.get(RIGHT);
+
+    Preconditions.checkState(leftTracker.hasMark() || leftTracker.isAtEnd(), "left.hasMark || left.isAtEnd");
+    Preconditions.checkState(rightTracker.hasMark() || rightTracker.isAtEnd(), "right.hasMark || right.isAtEnd");
+
+    if (!leftTracker.hasMark()) {
+      return rightTracker.markFrame < 0 ? 0 : 1;
+    } else if (!rightTracker.hasMark()) {
+      return -1;
+    } else {
+      final FrameHolder leftHolder = leftTracker.holders.get(leftTracker.markFrame);
+      final FrameHolder rightHolder = rightTracker.holders.get(rightTracker.markFrame);
+      return leftHolder.comparisonWidget.compare(
+          leftTracker.markRow,
+          rightHolder.comparisonWidget,
+          rightTracker.markRow
+      );
+    }
+  }
+
+  /**
+   * Pushes a frame from the indicated channel into the appropriate tracker. Returns true if a frame was pushed
+   * or if the channel is finished.
+   */
+  private boolean pushNextFrame(final int channelNumber)
+  {
+    final ReadableFrameChannel channel = inputChannels.get(channelNumber);
+    final Tracker tracker = trackers.get(channelNumber);
+
+    if (!channel.isFinished() && !channel.canRead()) {
+      return false;
+    } else if (channel.isFinished()) {
+      tracker.push(null);
+      return true;
+    } else {
+      final Frame frame = channel.read();
+
+      if (frame.numRows() == 0) {
+        // Skip, read next.
+        return false;
+      } else {
+        tracker.push(frame);
+        return true;
+      }
+    }
+  }
+
+  private void onNextIteration(final Runnable runnable)
+  {
+    if (nextIterationRunnable != null) {
+      throw new ISE("nextIterationRunnable already set");
+    } else {
+      nextIterationRunnable = runnable;
+    }
+  }
+
+  private void startNewFrameIfNeeded()
+  {
+    if (frameWriter == null) {
+      frameWriter = frameWriterFactory.newFrameWriter(joinColumnSelectorFactory);
+    }
+  }
+
+  private void flushCurrentFrame() throws IOException
+  {
+    if (frameWriter != null) {
+      if (frameWriter.getNumRows() > 0) {
+        final Frame frame = Frame.wrap(frameWriter.toByteArray());
+        frameWriter.close();
+        frameWriter = null;
+        outputChannel.write(new FrameWithPartition(frame, FrameWithPartition.NO_PARTITION));
+      }
+    }
+  }
+
+  /**
+   * Tracks the current set of rows that have the same key from a sequence of frames.
+   *
+   * markFrame and markRow are set when we encounter a new key, which enables rewinding and re-reading data with the
+   * same key.
+   */
+  private static class Tracker
+  {
+    /**
+     * Frame holders for the current frame, as well as immediately prior frames that share the same marked key.
+     * Prior frames are cleared on each call to {@link #markCurrent()}.
+     */
+    private final List<FrameHolder> holders = new ArrayList<>();
+    private final ReadableInput input;
+    private final List<KeyColumn> keyColumns;
+
+    // markFrame and markRow are the first frame and row with the current key.
+    private int markFrame = -1;
+    private int markRow = -1;
+
+    // currentFrame is the frame containing the current cursor row.
+    private int currentFrame = -1;
+
+    // done indicates that no more data is available in the channel.
+    private boolean done;
+
+    public Tracker(ReadableInput input, List<KeyColumn> keyColumns)
+    {
+      this.input = input;
+      this.keyColumns = keyColumns;
+    }
+
+    /**
+     * Adds a holder for a frame. If this is the first frame, sets the current cursor position and mark to the first
+     * row of the frame. Otherwise, the cursor position and mark are not changed.
+     *
+     * Pushing a null frame indicates no more frames are coming.
+     *
+     * @param frame frame, or null indicating no more frames are coming
+     */
+    public void push(final Frame frame)
+    {
+      if (frame == null) {
+        done = true;
+        return;
+      }
+
+      if (done) {
+        throw new ISE("Cannot push frames when already done");
+      }
+
+      final boolean atEndOfPushedData = isAtEndOfPushedData();
+      final FrameReader frameReader = input.getChannelFrameReader();
+      final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader);
+      final FrameComparisonWidget comparisonWidget =
+          frameReader.makeComparisonWidget(frame, keyColumns);
+
+      final RowSignature.Builder keySignatureBuilder = RowSignature.builder();
+      for (final KeyColumn keyColumn : keyColumns) {
+        keySignatureBuilder.add(
+            keyColumn.columnName(),
+            frameReader.signature().getColumnType(keyColumn.columnName()).orElse(null)
+        );
+      }
+
+      holders.add(
+          new FrameHolder(
+              frame,
+              RowKeyReader.create(keySignatureBuilder.build()),
+              cursor,
+              comparisonWidget
+          )
+      );
+
+      if (atEndOfPushedData) {
+        // Move currentFrame so it points at the next row, which we now have, instead of an "isDone" cursor.
+        currentFrame = currentFrame < 0 ? 0 : currentFrame + 1;
+      }
+
+      if (markFrame < 0) {
+        // Cleared mark means we want the current row to be marked.
+        markFrame = currentFrame;
+        markRow = 0;
+      }
+    }
+
+    /**
+     * Number of bytes currently buffered in {@link #holders}.
+     */
+    public long totalBytesBuffered()
+    {
+      long bytes = 0;
+      for (final FrameHolder holder : holders) {
+        bytes += holder.frame.numBytes();
+      }
+      return bytes;
+    }
+
+    /**
+     * Cursor containing the current row.
+     */
+    @Nullable
+    public FrameCursor currentCursor()
+    {
+      if (currentFrame < 0) {
+        return null;
+      } else {
+        return holders.get(currentFrame).cursor;
+      }
+    }
+
+    /**
+     * Advances the current row (the current row of {@link #currentFrame}). After calling this method,
+     * {@link #isAtEndOfPushedData()} may start returning true.
+     */
+    public void advance()
+    {
+      assert !isAtEndOfPushedData();
+
+      final FrameHolder currentHolder = holders.get(currentFrame);
+
+      currentHolder.cursor.advance();
+
+      if (currentHolder.cursor.isDone() && currentFrame + 1 < holders.size()) {
+        currentFrame++;
+        holders.get(currentFrame).cursor.reset();
+      }
+    }
+
+    /**
+     * Whether this tracker has a marked row.
+     */
+    public boolean hasMark()
+    {
+      return markFrame >= 0;
+    }
+
+    /**
+     * Whether this tracker has a marked row that is completely nonnull.
+     */
+    public boolean hasCompletelyNonNullMark()
+    {
+      return hasMark() && !holders.get(markFrame).comparisonWidget.isPartiallyNullKey(markRow);
+    }
+
+    /**
+     * Reads the current marked key.
+     */
+    @Nullable
+    public List<Object> readMarkKey()
+    {
+      if (!hasMark()) {
+        return null;
+      }
+
+      final FrameHolder markHolder = holders.get(markFrame);
+      final RowKey markKey = markHolder.comparisonWidget.readKey(markRow);
+      return markHolder.keyReader.read(markKey);
+    }
+
+    /**
+     * Rewind to the mark row: the first one with the current key.
+     *
+     * @throws IllegalStateException if there is no marked row
+     */
+    public void rewindToMark()
+    {
+      if (markFrame < 0) {
+        throw new ISE("No mark");
+      }
+
+      currentFrame = markFrame;
+      holders.get(currentFrame).cursor.setCurrentRow(markRow);
+    }
+
+    /**
+     * Set the mark row to the current row. Used when data from the old mark to the current row is no longer needed.
+     */
+    public void markCurrent()
+    {
+      if (isAtEndOfPushedData()) {
+        clear();
+      } else {
+        // Remove unnecessary holders, now that the mark has moved on.
+        while (currentFrame > 0) {
+          if (currentFrame == holders.size() - 1) {
+            final FrameHolder lastHolder = holders.get(currentFrame);
+            holders.clear();
+            holders.add(lastHolder);
+            currentFrame = 0;
+          } else {
+            holders.remove(0);
+            currentFrame--;
+          }
+        }
+
+        markFrame = 0;
+        markRow = holders.get(currentFrame).cursor.getCurrentRow();
+      }
+    }
+
+    /**
+     * Whether the current cursor is past the end of the last frame for which we have data.
+     */
+    public boolean isAtEndOfPushedData()
+    {
+      return currentFrame < 0 || (currentFrame == holders.size() - 1 && holders.get(currentFrame).cursor.isDone());
+    }
+
+    /**
+     * Whether the current cursor is past the end of all data that will ever be pushed.
+     */
+    public boolean isAtEnd()
+    {
+      return done && isAtEndOfPushedData();
+    }
+
+    /**
+     * Whether this tracker needs more data in order to read the current cursor location or move it forward.
+     */
+    public boolean needsMoreData()
+    {
+      return !done && isAtEndOfPushedData();
+    }
+
+    /**
+     * Whether this tracker contains all rows for the marked key.
+     *
+     * @throws IllegalStateException if there is no marked key
+     */
+    public boolean hasCompleteSetForMark()
+    {
+      if (markFrame < 0) {
+        throw new ISE("No mark");
+      }
+
+      if (done) {
+        return true;
+      }
+
+      final FrameHolder lastHolder = holders.get(holders.size() - 1);
+      return !isSameKeyAsMark(lastHolder, lastHolder.frame.numRows() - 1);
+    }
+
+    /**
+     * Whether the current position (the current row of the {@link #currentFrame}) compares equally to the mark row.
+     * If {@link #isAtEnd()}, returns true iff there is no mark row.
+     */
+    public boolean isCurrentSameKeyAsMark()
+    {
+      if (isAtEnd()) {
+        return markFrame < 0;
+      } else {
+        assert !isAtEndOfPushedData();
+        final FrameHolder headHolder = holders.get(currentFrame);
+        return isSameKeyAsMark(headHolder, headHolder.cursor.getCurrentRow());
+      }
+    }
+
+    /**
+     * Clears the current mark and all buffered frames. Does not change {@link #done}.
+     */
+    public void clear()
+    {
+      holders.clear();
+      markFrame = -1;
+      markRow = -1;
+      currentFrame = -1;
+    }
+
+    /**
+     * Whether the provided frame and row compares equally to the mark row. The provided row must be at, or after,
+     * the mark row.
+     */
+    private boolean isSameKeyAsMark(final FrameHolder holder, final int row)
+    {
+      if (markFrame < 0) {
+        throw new ISE("No marked frame");
+      }
+      if (row < 0 || row >= holder.frame.numRows()) {
+        throw new ISE("Row [%d] out of bounds", row);
+      }
+
+      final FrameHolder markHolder = holders.get(markFrame);
+      final int cmp = markHolder.comparisonWidget.compare(markRow, holder.comparisonWidget, row);
+
+      if (cmp > 0) {
+        // The provided row is at, or after, the marked row.
+        // Therefore, cmp > 0 may indicate that input was provided out of order.
+        throw new ISE("Row compares higher than mark; out-of-order input?");
+      }
+
+      return cmp == 0;
+    }
+  }
+
+  /**
+   * Selector for joined rows. This is used as an input to {@link #frameWriter}.
+   */
+  private class JoinColumnSelectorFactory implements ColumnSelectorFactory
+  {
+    /**
+     * Current key comparison between left- and right-hand side.
+     */
+    private int cmp;
+
+    /**
+     * Whether there is a match between the left- and right-hand side. Not equivalent to {@code cmp == 0} in
+     * the case where the key on both sides is null.
+     */
+    private boolean match;
+
+    @Override
+    public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
+    {
+      if (dimensionSpec.getExtractionFn() != null || dimensionSpec.mustDecorate()) {
+        // Not supported; but that's okay, because these features aren't needed when reading from this
+        // ColumnSelectorFactory. It is handed to a FrameWriter, which always uses DefaultDimensionSpec.
+        throw new UnsupportedOperationException();
+      }
+
+      final int channel = getChannelNumber(dimensionSpec.getDimension());
+      final ColumnCapabilities columnCapabilities = getColumnCapabilities(dimensionSpec.getDimension());
+
+      if (columnCapabilities == null) {
+        // Not an output column.
+        return DimensionSelector.constant(null);
+      } else {
+        return new JoinDimensionSelector(channel, getInputColumnName(dimensionSpec.getDimension()));
+      }
+    }
+
+    @Override
+    public ColumnValueSelector<?> makeColumnValueSelector(String columnName)
+    {
+      final int channel = getChannelNumber(columnName);
+      final ColumnCapabilities columnCapabilities = getColumnCapabilities(columnName);
+
+      if (columnCapabilities == null) {
+        // Not an output column.
+        return NilColumnValueSelector.instance();
+      } else {
+        return new JoinColumnValueSelector(channel, getInputColumnName(columnName));
+      }
+    }
+
+    @Nullable
+    @Override
+    public ColumnCapabilities getColumnCapabilities(String column)
+    {
+      return frameWriterFactory.signature().getColumnCapabilities(column);
+    }
+
+    /**
+     * Channel number for a possibly-prefixed column name.
+     */
+    private int getChannelNumber(final String column)
+    {
+      if (JoinPrefixUtils.isPrefixedBy(column, rightPrefix)) {
+        return RIGHT;
+      } else {
+        return LEFT;
+      }
+    }
+
+    /**
+     * Unprefixed column name for a possibly-prefixed column name.
+     */
+    private String getInputColumnName(final String column)
+    {
+      if (JoinPrefixUtils.isPrefixedBy(column, rightPrefix)) {
+        return JoinPrefixUtils.unprefix(column, rightPrefix);
+      } else {
+        return column;
+      }
+    }
+
+    /**
+     * Whether columns for the given channel are to be emitted with the current row.
+     */
+    private boolean shouldEmitColumnValue(final int channel)
+    {
+      // Asymmetry between left and right is necessary to properly handle the FULL OUTER case where there are null keys.
+      // In this case, we run through the left-hand side first, then the right-hand side.
+      return !trackers.get(channel).isAtEndOfPushedData()
+             && (match
+                 || (channel == LEFT && joinType.isLefty() && cmp <= 0)
+                 || (channel == RIGHT && joinType.isRighty()
+                     && ((joinType.isLefty() && cmp > 0) || (!joinType.isLefty() && cmp >= 0))));
+    }
+
+    private class JoinDimensionSelector implements DimensionSelector
+    {
+      private final int channel;
+      private final String columnName;
+
+      public JoinDimensionSelector(int channel, String columnName)
+      {
+        this.channel = channel;
+        this.columnName = columnName;
+      }
+
+      private Cursor currentCursor;
+      private DimensionSelector currentSelector;
+
+      @Nullable
+      @Override
+      public Object getObject()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getObject();
+        } else {
+          return null;
+        }
+      }
+
+      @Override
+      public IndexedInts getRow()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getRow();
+        } else {
+          return ZeroIndexedInts.instance();
+        }
+      }
+
+      @Nullable
+      @Override
+      public String lookupName(int id)
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.lookupName(id);
+        } else {
+          return null;
+        }
+      }
+
+      @Nullable
+      @Override
+      public ByteBuffer lookupNameUtf8(int id)
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.lookupNameUtf8(id);
+        } else {
+          return null;
+        }
+      }
+
+      @Override
+      public boolean supportsLookupNameUtf8()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.supportsLookupNameUtf8();
+        } else {
+          return true;
+        }
+      }
+
+      @Override
+      public int getValueCardinality()
+      {
+        return CARDINALITY_UNKNOWN;
+      }
+
+      @Override
+      public boolean nameLookupPossibleInAdvance()
+      {
+        return false;
+      }
+
+      @Nullable
+      @Override
+      public IdLookup idLookup()
+      {
+        return null;
+      }
+
+      @Override
+      public Class<?> classOfObject()
+      {
+        return Object.class;
+      }
+
+      @Override
+      public ValueMatcher makeValueMatcher(@Nullable String value)
+      {
+        return DimensionSelectorUtils.makeValueMatcherGeneric(this, value);
+      }
+
+      @Override
+      public ValueMatcher makeValueMatcher(Predicate<String> predicate)
+      {
+        return DimensionSelectorUtils.makeValueMatcherGeneric(this, predicate);
+      }
+
+      @Override
+      public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+      {
+        // Not needed: TopN engine won't run on this.
+        throw new UnsupportedOperationException();
+      }
+
+      private void refreshCursor()
+      {
+        final FrameCursor headCursor = trackers.get(channel).currentCursor();
+
+        //noinspection ObjectEquality
+        if (currentCursor != headCursor) {
+          if (headCursor == null) {
+            currentCursor = null;
+            currentSelector = null;
+          } else {
+            currentCursor = headCursor;
+            currentSelector = headCursor.getColumnSelectorFactory()
+                                        .makeDimensionSelector(DefaultDimensionSpec.of(columnName));
+          }
+        }
+      }
+    }
+
+    private class JoinColumnValueSelector implements ColumnValueSelector<Object>
+    {
+      private final int channel;
+      private final String columnName;
+
+      private Cursor currentCursor;
+      private ColumnValueSelector<?> currentSelector;
+
+      public JoinColumnValueSelector(int channel, String columnName)
+      {
+        this.channel = channel;
+        this.columnName = columnName;
+      }
+
+      @Override
+      public long getLong()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getLong();
+        } else {
+          return 0;
+        }
+      }
+
+      @Override
+      public double getDouble()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getDouble();
+        } else {
+          return 0;
+        }
+      }
+
+      @Override
+      public float getFloat()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getFloat();
+        } else {
+          return 0;
+        }
+      }
+
+      @Nullable
+      @Override
+      public Object getObject()
+      {
+        refreshCursor();
+        if (shouldEmitColumnValue(channel)) {
+          return currentSelector.getObject();
+        } else {
+          return null;
+        }
+      }
+
+      @Override
+      public boolean isNull()
+      {
+        refreshCursor();
+        return !shouldEmitColumnValue(channel) || currentSelector.isNull();
+      }
+
+      @Override
+      public Class<?> classOfObject()
+      {
+        return Object.class;
+      }
+
+      @Override
+      public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+      {
+        // Not needed: TopN engine won't run on this.
+        throw new UnsupportedOperationException();
+      }
+
+      private void refreshCursor()
+      {
+        final FrameCursor headCursor = trackers.get(channel).currentCursor();
+
+        //noinspection ObjectEquality
+        if (currentCursor != headCursor) {
+          if (headCursor == null) {
+            currentCursor = null;
+            currentSelector = null;
+          } else {
+            currentCursor = headCursor;
+            currentSelector = headCursor.getColumnSelectorFactory().makeColumnValueSelector(columnName);
+          }
+        }
+      }
+    }
+  }
+
+  private static class FrameHolder
+  {
+    private final Frame frame;
+    private final RowKeyReader keyReader;
+    private final FrameCursor cursor;
+    private final FrameComparisonWidget comparisonWidget;
+
+    public FrameHolder(Frame frame, RowKeyReader keyReader, FrameCursor cursor, FrameComparisonWidget comparisonWidget)
+    {
+      this.frame = frame;
+      this.keyReader = keyReader;
+      this.cursor = cursor;
+      this.comparisonWidget = comparisonWidget;
+    }
+  }
+}
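
The emission rule in shouldEmitColumnValue above is compact but carries most of the outer-join behavior. The following standalone sketch is not part of the patch and uses a hypothetical MiniJoinType in place of Druid's JoinType; it reproduces the same boolean expression so it can be read, or exercised, in isolation:

    // Standalone illustration of the emission rule used by shouldEmitColumnValue.
    // MiniJoinType is a hypothetical stand-in for org.apache.druid.segment.join.JoinType.
    enum MiniJoinType
    {
      INNER(false, false),
      LEFT(true, false),
      RIGHT(false, true),
      FULL(true, true);

      final boolean lefty;   // emits unmatched left-hand rows
      final boolean righty;  // emits unmatched right-hand rows

      MiniJoinType(boolean lefty, boolean righty)
      {
        this.lefty = lefty;
        this.righty = righty;
      }
    }

    final class EmissionRule
    {
      private static final int LEFT = 0;
      private static final int RIGHT = 1;

      /**
       * Mirrors shouldEmitColumnValue: "match" is true when the join keys match, "cmp" is the
       * comparison of the current left-hand key against the current right-hand key, and "atEnd"
       * says whether the given side has no more buffered rows.
       */
      static boolean shouldEmit(MiniJoinType joinType, int channel, boolean match, int cmp, boolean atEnd)
      {
        return !atEnd
               && (match
                   || (channel == LEFT && joinType.lefty && cmp <= 0)
                   || (channel == RIGHT && joinType.righty
                       && ((joinType.lefty && cmp > 0) || (!joinType.lefty && cmp >= 0))));
      }

      public static void main(String[] args)
      {
        // FULL OUTER, unmatched left row whose key sorts before the current right key (cmp < 0):
        // only the left-hand columns are emitted; the right-hand columns come out as nulls.
        System.out.println(shouldEmit(MiniJoinType.FULL, LEFT, false, -1, false));   // true
        System.out.println(shouldEmit(MiniJoinType.FULL, RIGHT, false, -1, false));  // false
      }
    }

Under this rule, unmatched FULL OUTER rows are emitted from the left-hand channel while the left key compares lower or equal, and from the right-hand channel only once the comparison flips, which appears to match the asymmetry called out in the comment above.
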
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
new file mode 100644
index 0000000000..9aa5063092
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.common;
+
+import com.fasterxml.jackson.annotation.JacksonInject;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonTypeName;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
+import org.apache.druid.frame.processor.FrameProcessor;
+import org.apache.druid.frame.processor.OutputChannel;
+import org.apache.druid.frame.processor.OutputChannelFactory;
+import org.apache.druid.frame.processor.OutputChannels;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.msq.counters.CounterTracker;
+import org.apache.druid.msq.input.InputSlice;
+import org.apache.druid.msq.input.InputSliceReader;
+import org.apache.druid.msq.input.InputSlices;
+import org.apache.druid.msq.input.ReadableInput;
+import org.apache.druid.msq.input.stage.StageInputSlice;
+import org.apache.druid.msq.kernel.FrameContext;
+import org.apache.druid.msq.kernel.ProcessorsAndChannels;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.querykit.BaseFrameProcessorFactory;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.join.Equality;
+import org.apache.druid.segment.join.JoinConditionAnalysis;
+import org.apache.druid.segment.join.JoinType;
+
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+
+/**
+ * Factory for {@link SortMergeJoinFrameProcessor}, which does a sort-merge join of two inputs.
+ */
+@JsonTypeName("sortMergeJoin")
+public class SortMergeJoinFrameProcessorFactory extends BaseFrameProcessorFactory
+{
+  private static final int LEFT = 0;
+  private static final int RIGHT = 1;
+
+  private final String rightPrefix;
+  private final JoinConditionAnalysis condition;
+  private final JoinType joinType;
+
+  public SortMergeJoinFrameProcessorFactory(
+      final String rightPrefix,
+      final JoinConditionAnalysis condition,
+      final JoinType joinType
+  )
+  {
+    this.rightPrefix = Preconditions.checkNotNull(rightPrefix, "rightPrefix");
+    this.condition = validateCondition(Preconditions.checkNotNull(condition, "condition"));
+    this.joinType = Preconditions.checkNotNull(joinType, "joinType");
+  }
+
+  @JsonCreator
+  public static SortMergeJoinFrameProcessorFactory create(
+      @JsonProperty("rightPrefix") String rightPrefix,
+      @JsonProperty("condition") String condition,
+      @JsonProperty("joinType") JoinType joinType,
+      @JacksonInject ExprMacroTable macroTable
+  )
+  {
+    return new SortMergeJoinFrameProcessorFactory(
+        StringUtils.nullToEmptyNonDruidDataString(rightPrefix),
+        JoinConditionAnalysis.forExpression(
+            Preconditions.checkNotNull(condition, "condition"),
+            StringUtils.nullToEmptyNonDruidDataString(rightPrefix),
+            macroTable
+        ),
+        joinType
+    );
+  }
+
+  @JsonProperty
+  public String getRightPrefix()
+  {
+    return rightPrefix;
+  }
+
+  @JsonProperty
+  public String getCondition()
+  {
+    return condition.getOriginalExpression();
+  }
+
+  @JsonProperty
+  public JoinType getJoinType()
+  {
+    return joinType;
+  }
+
+  @Override
+  public ProcessorsAndChannels<FrameProcessor<Long>, Long> makeProcessors(
+      StageDefinition stageDefinition,
+      int workerNumber,
+      List<InputSlice> inputSlices,
+      InputSliceReader inputSliceReader,
+      @Nullable Object extra,
+      OutputChannelFactory outputChannelFactory,
+      FrameContext frameContext,
+      int maxOutstandingProcessors,
+      CounterTracker counters,
+      Consumer<Throwable> warningPublisher
+  ) throws IOException
+  {
+    if (inputSlices.size() != 2 || !inputSlices.stream().allMatch(slice -> slice instanceof StageInputSlice)) {
+      // Can't hit this unless there was some bug in QueryKit.
+      throw new ISE("Expected two stage inputs");
+    }
+
+    // Compute key columns.
+    final List<List<KeyColumn>> keyColumns = toKeyColumns(condition);
+
+    // Stitch up the inputs and validate each input channel signature.
+    // If validateInputFrameSignatures fails, it's a precondition violation: this class somehow got bad inputs.
+    final Int2ObjectMap<List<ReadableInput>> inputsByPartition = validateInputFrameSignatures(
+        InputSlices.attachAndCollectPartitions(
+            inputSlices,
+            inputSliceReader,
+            counters,
+            warningPublisher
+        ),
+        keyColumns
+    );
+
+    if (inputsByPartition.isEmpty()) {
+      return new ProcessorsAndChannels<>(Sequences.empty(), OutputChannels.none());
+    }
+
+    // Create output channels.
+    final Int2ObjectMap<OutputChannel> outputChannels = new Int2ObjectAVLTreeMap<>();
+    for (int partitionNumber : inputsByPartition.keySet()) {
+      outputChannels.put(partitionNumber, outputChannelFactory.openChannel(partitionNumber));
+    }
+
+    // Create processors.
+    final Iterable<FrameProcessor<Long>> processors = Iterables.transform(
+        inputsByPartition.int2ObjectEntrySet(),
+        entry -> {
+          final int partitionNumber = entry.getIntKey();
+          final List<ReadableInput> readableInputs = entry.getValue();
+          final OutputChannel outputChannel = outputChannels.get(partitionNumber);
+
+          return new SortMergeJoinFrameProcessor(
+              readableInputs.get(LEFT),
+              readableInputs.get(RIGHT),
+              outputChannel.getWritableChannel(),
+              stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
+              rightPrefix,
+              keyColumns,
+              joinType
+          );
+        }
+    );
+
+    return new ProcessorsAndChannels<>(
+        Sequences.simple(processors),
+        OutputChannels.wrap(ImmutableList.copyOf(outputChannels.values()))
+    );
+  }
+
+  /**
+   * Extracts key columns from a {@link JoinConditionAnalysis}. The returned list has two elements: 0 is the
+   * left-hand side, 1 is the right-hand side. Each sub-list has one element for each equi-condition.
+   *
+   * The condition must have been validated by {@link #validateCondition(JoinConditionAnalysis)}.
+   */
+  public static List<List<KeyColumn>> toKeyColumns(final JoinConditionAnalysis condition)
+  {
+    final List<List<KeyColumn>> retVal = new ArrayList<>();
+    retVal.add(new ArrayList<>()); // Left-side key columns
+    retVal.add(new ArrayList<>()); // Right-side key columns
+
+    for (final Equality equiCondition : condition.getEquiConditions()) {
+      final String leftColumn = Preconditions.checkNotNull(
+          equiCondition.getLeftExpr().getBindingIfIdentifier(),
+          "leftExpr#getBindingIfIdentifier"
+      );
+
+      retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING));
+      retVal.get(1).add(new KeyColumn(equiCondition.getRightColumn(), KeyOrder.ASCENDING));
+    }
+
+    return retVal;
+  }
+
+  /**
+   * Validates that a join condition can be handled by this processor. Returns the condition if it can be handled.
+   * Throws {@link IllegalArgumentException} if the condition cannot be handled.
+   */
+  public static JoinConditionAnalysis validateCondition(final JoinConditionAnalysis condition)
+  {
+    if (condition.isAlwaysTrue()) {
+      return condition;
+    }
+
+    if (condition.isAlwaysFalse()) {
+      throw new IAE("Cannot handle constant condition: %s", condition.getOriginalExpression());
+    }
+
+    if (condition.getNonEquiConditions().size() > 0) {
+      throw new IAE("Cannot handle non-equijoin condition: %s", condition.getOriginalExpression());
+    }
+
+    if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier())) {
+      throw new IAE(
+          "Cannot handle equality condition involving left-hand expression: %s",
+          condition.getOriginalExpression()
+      );
+    }
+
+    return condition;
+  }
+
+  /**
+   * Validates that all signatures from {@link InputSlices#attachAndCollectPartitions} are prefixed by the
+   * provided {@code keyColumns}.
+   */
+  private static Int2ObjectMap<List<ReadableInput>> validateInputFrameSignatures(
+      final Int2ObjectMap<List<ReadableInput>> inputsByPartition,
+      final List<List<KeyColumn>> keyColumns
+  )
+  {
+    for (List<ReadableInput> readableInputs : inputsByPartition.values()) {
+      for (int i = 0; i < readableInputs.size(); i++) {
+        final ReadableInput readableInput = readableInputs.get(i);
+        Preconditions.checkState(readableInput.hasChannel(), "readableInput[%s].hasChannel", i);
+
+        final RowSignature signature = readableInput.getChannelFrameReader().signature();
+        for (int j = 0; j < keyColumns.get(i).size(); j++) {
+          final String columnName = keyColumns.get(i).get(j).columnName();
+          Preconditions.checkState(
+              columnName.equals(signature.getColumnName(j)),
+              "readableInput[%s] column[%s] has expected name[%s]",
+              i,
+              j,
+              columnName
+          );
+        }
+      }
+    }
+
+    return inputsByPartition;
+  }
+}
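
To make the contract of toKeyColumns and validateCondition concrete, here is a hypothetical usage sketch, not part of the patch. It assumes ExprMacroTable.nil() provides a macro-free expression table and that Equality#getRightColumn yields the unprefixed right-side column name:

    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.math.expr.ExprMacroTable;
    import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
    import org.apache.druid.segment.join.JoinConditionAnalysis;

    import java.util.List;

    public class SortMergeKeyColumnsExample
    {
      public static void main(String[] args)
      {
        // Equi-condition "x = j0.y" with rightPrefix "j0.".
        final JoinConditionAnalysis condition =
            JoinConditionAnalysis.forExpression("x == \"j0.y\"", "j0.", ExprMacroTable.nil());

        // Accepted: a pure equi-join whose left-hand expression is a plain identifier.
        SortMergeJoinFrameProcessorFactory.validateCondition(condition);

        final List<List<KeyColumn>> keyColumns = SortMergeJoinFrameProcessorFactory.toKeyColumns(condition);
        System.out.println(keyColumns.get(0)); // left side: a single ascending key on "x"
        System.out.println(keyColumns.get(1)); // right side: a single ascending key on "y"
      }
    }

A condition such as "x < j0.y" or a constant FALSE would instead throw IAE from validateCondition, per the checks above.
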
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessor.java
index 207fe53de0..7a7a00a7dc 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessor.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessor.java
@@ -22,12 +22,9 @@ package org.apache.druid.msq.querykit.groupby;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.FrameProcessors;
 import org.apache.druid.frame.processor.FrameRowTooLargeException;
@@ -35,7 +32,6 @@ import org.apache.druid.frame.processor.ReturnOrAwait;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.msq.querykit.QueryKitUtils;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.PostAggregator;
@@ -68,10 +64,8 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
   private final GroupByQuery query;
   private final ReadableFrameChannel inputChannel;
   private final WritableFrameChannel outputChannel;
-  private final MemoryAllocator allocator;
+  private final FrameWriterFactory frameWriterFactory;
   private final FrameReader frameReader;
-  private final RowSignature resultSignature;
-  private final ClusterBy clusterBy;
   private final ColumnSelectorFactory columnSelectorFactoryForFrameWriter;
   private final Comparator<ResultRow> compareFn;
   private final BinaryOperator<ResultRow> mergeFn;
@@ -90,10 +84,8 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
       final GroupByStrategySelector strategySelector,
       final ReadableFrameChannel inputChannel,
       final WritableFrameChannel outputChannel,
+      final FrameWriterFactory frameWriterFactory,
       final FrameReader frameReader,
-      final RowSignature resultSignature,
-      final ClusterBy clusterBy,
-      final MemoryAllocator allocator,
       final ObjectMapper jsonMapper
   )
   {
@@ -101,9 +93,7 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
     this.inputChannel = inputChannel;
     this.outputChannel = outputChannel;
     this.frameReader = frameReader;
-    this.resultSignature = resultSignature;
-    this.clusterBy = clusterBy;
-    this.allocator = allocator;
+    this.frameWriterFactory = frameWriterFactory;
     this.compareFn = strategySelector.strategize(query).createResultComparator(query);
     this.mergeFn = strategySelector.strategize(query).createMergeFn(query);
     this.finalizeFn = makeFinalizeFn(query);
@@ -249,10 +239,10 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
         outputRow = null;
         return true;
       } else {
-        throw new FrameRowTooLargeException(allocator.capacity());
+        throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
       }
     } else {
-      throw new FrameRowTooLargeException(allocator.capacity());
+      throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
     }
   }
 
@@ -269,8 +259,6 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
   private void setUpFrameWriterIfNeeded()
   {
     if (frameWriter == null) {
-      final FrameWriterFactory frameWriterFactory =
-          FrameWriters.makeFrameWriterFactory(FrameType.ROW_BASED, allocator, resultSignature, clusterBy.getColumns());
       frameWriter = frameWriterFactory.newFrameWriter(columnSelectorFactoryForFrameWriter);
     }
   }
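
The hunks above replace the allocator, result signature, and cluster-by that GroupByPostShuffleFrameProcessor used to carry with a single injected FrameWriterFactory; the frame writer is still created lazily, and the capacity reported with FrameRowTooLargeException now comes from the factory. A minimal sketch of that pattern, using hypothetical Writer and WriterFactory stand-ins rather than Druid's classes:

    // Sketch only, not part of the patch: the processor is handed one factory and builds its
    // writer lazily from it.
    interface Writer
    {
      boolean addSelection();
    }

    interface WriterFactory
    {
      Writer newWriter();

      long allocatorCapacity();
    }

    class PostShuffleSketch
    {
      private final WriterFactory writerFactory;
      private Writer writer;

      PostShuffleSketch(WriterFactory writerFactory)
      {
        this.writerFactory = writerFactory;
      }

      void setUpWriterIfNeeded()
      {
        if (writer == null) {
          // No allocator, signature, or sort key needed here anymore; the factory carries all of it.
          writer = writerFactory.newWriter();
        }
      }

      void writeCurrentRow()
      {
        setUpWriterIfNeeded();
        if (!writer.addSelection()) {
          // Capacity for the error message also comes from the factory.
          throw new IllegalStateException("Row too large for a frame of " + writerFactory.allocatorCapacity() + " bytes");
        }
      }
    }
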
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessorFactory.java
index ffb8bacf5e..73206da166 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleFrameProcessorFactory.java
@@ -24,8 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
-import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
-import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
 import org.apache.druid.frame.processor.FrameProcessor;
 import org.apache.druid.frame.processor.OutputChannel;
 import org.apache.druid.frame.processor.OutputChannelFactory;
@@ -86,8 +86,8 @@ public class GroupByPostShuffleFrameProcessorFactory extends BaseFrameProcessorF
     // Expecting a single input slice from some prior stage.
     final StageInputSlice slice = (StageInputSlice) Iterables.getOnlyElement(inputSlices);
     final GroupByStrategySelector strategySelector = frameContext.groupByStrategySelector();
+    final Int2ObjectSortedMap<OutputChannel> outputChannels = new Int2ObjectAVLTreeMap<>();
 
-    final Int2ObjectMap<OutputChannel> outputChannels = new Int2ObjectOpenHashMap<>();
     for (final ReadablePartition partition : slice.getPartitions()) {
       outputChannels.computeIfAbsent(
           partition.getPartitionNumber(),
@@ -115,10 +115,8 @@ public class GroupByPostShuffleFrameProcessorFactory extends BaseFrameProcessorF
               strategySelector,
               readableInput.getChannel(),
               outputChannel.getWritableChannel(),
+              stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
               readableInput.getChannelFrameReader(),
-              stageDefinition.getSignature(),
-              stageDefinition.getClusterBy(),
-              outputChannel.getFrameMemoryAllocator(),
               frameContext.jsonMapper()
           );
         }
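
GroupByPostShuffleFrameProcessorFactory now collects its output channels in an Int2ObjectAVLTreeMap rather than an Int2ObjectOpenHashMap, presumably so that iteration follows ascending partition number. A small fastutil demonstration of that ordering, not part of the patch:

    import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
    import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;

    public class SortedChannelsExample
    {
      public static void main(String[] args)
      {
        final Int2ObjectSortedMap<String> channels = new Int2ObjectAVLTreeMap<>();
        channels.put(2, "partition-2");
        channels.put(0, "partition-0");
        channels.put(1, "partition-1");

        // Prints partition-0, partition-1, partition-2 regardless of insertion order.
        channels.values().forEach(System.out::println);
      }
    }
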
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
index 63ad3cf890..285e75eaa5 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessorFactory.java
@@ -25,19 +25,15 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
-import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
 import org.apache.druid.msq.querykit.LazyResourceHolder;
 import org.apache.druid.query.groupby.GroupByQuery;
-import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
 
 @JsonTypeName("groupByPreShuffle")
@@ -62,9 +58,7 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcess
       final ReadableInput baseInput,
       final Int2ObjectMap<ReadableInput> sideChannels,
       final ResourceHolder<WritableFrameChannel> outputChannelHolder,
-      final ResourceHolder<MemoryAllocator> allocatorHolder,
-      final RowSignature signature,
-      final ClusterBy clusterBy,
+      final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
       final FrameContext frameContext
   )
   {
@@ -75,15 +69,7 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcess
         frameContext.groupByStrategySelector(),
         new JoinableFactoryWrapper(frameContext.joinableFactory()),
         outputChannelHolder,
-        new LazyResourceHolder<>(() -> Pair.of(
-            FrameWriters.makeFrameWriterFactory(
-                FrameType.ROW_BASED,
-                allocatorHolder.get(),
-                signature,
-                clusterBy.getColumns()
-            ),
-            allocatorHolder
-        )),
+        new LazyResourceHolder<>(() -> Pair.of(frameWriterFactoryHolder.get(), frameWriterFactoryHolder)),
         frameContext.memoryParameters().getBroadcastJoinMemory()
     );
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
index 402d2dfa3d..dea498a53b 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
@@ -22,12 +22,13 @@ package org.apache.druid.msq.querykit.groupby;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.base.Preconditions;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.msq.input.stage.StageInputSpec;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.MixShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.StageDefinition;
@@ -80,6 +81,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
     final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
         queryKit,
         queryId,
+        originalQuery.context(),
         originalQuery.getDataSource(),
         originalQuery.getQuerySegmentSpec(),
         originalQuery.getFilter(),
@@ -118,7 +120,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
       shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.singlePartition();
       shuffleSpecFactoryPostAggregation = ShuffleSpecFactories.singlePartition();
     } else if (doOrderBy) {
-      shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.subQueryWithMaxWorkerCount(maxWorkerCount);
+      shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount);
       shuffleSpecFactoryPostAggregation = doLimitOrOffset
                                           ? ShuffleSpecFactories.singlePartition()
                                           : resultShuffleSpecFactory;
@@ -162,7 +164,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
                          .inputs(new StageInputSpec(firstStageNumber + 1))
                          .signature(resultSignature)
                          .maxWorkerCount(1)
-                         .shuffleSpec(new MaxCountShuffleSpec(ClusterBy.none(), 1, false))
+                         .shuffleSpec(MixShuffleSpec.instance())
                          .processorFactory(
                              new OffsetLimitFrameProcessorFactory(
                                  limitSpec.getOffset(),
@@ -221,10 +223,10 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
    */
   static ClusterBy computeIntermediateClusterBy(final GroupByQuery query)
   {
-    final List<SortColumn> columns = new ArrayList<>();
+    final List<KeyColumn> columns = new ArrayList<>();
 
     for (final DimensionSpec dimension : query.getDimensions()) {
-      columns.add(new SortColumn(dimension.getOutputName(), false));
+      columns.add(new KeyColumn(dimension.getOutputName(), KeyOrder.ASCENDING));
     }
 
     // Note: ignoring time because we assume granularity = all.
@@ -240,13 +242,15 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
       final DefaultLimitSpec defaultLimitSpec = (DefaultLimitSpec) query.getLimitSpec();
 
       if (!defaultLimitSpec.getColumns().isEmpty()) {
-        final List<SortColumn> clusterByColumns = new ArrayList<>();
+        final List<KeyColumn> clusterByColumns = new ArrayList<>();
 
         for (final OrderByColumnSpec orderBy : defaultLimitSpec.getColumns()) {
           clusterByColumns.add(
-              new SortColumn(
+              new KeyColumn(
                   orderBy.getDimension(),
                   orderBy.getDirection() == OrderByColumnSpec.Direction.DESCENDING
+                  ? KeyOrder.DESCENDING
+                  : KeyOrder.ASCENDING
               )
           );
         }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
index bda53af696..bc8ec9608e 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorFactory.java
@@ -25,19 +25,15 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.base.Preconditions;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.processor.FrameProcessor;
-import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.FrameContext;
 import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
 import org.apache.druid.msq.querykit.LazyResourceHolder;
 import org.apache.druid.query.scan.ScanQuery;
-import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.join.JoinableFactoryWrapper;
 
 import javax.annotation.Nullable;
@@ -78,9 +74,7 @@ public class ScanQueryFrameProcessorFactory extends BaseLeafFrameProcessorFactor
       ReadableInput baseInput,
       Int2ObjectMap<ReadableInput> sideChannels,
       ResourceHolder<WritableFrameChannel> outputChannelHolder,
-      ResourceHolder<MemoryAllocator> allocatorHolder,
-      RowSignature signature,
-      ClusterBy clusterBy,
+      ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
       FrameContext frameContext
   )
   {
@@ -90,15 +84,7 @@ public class ScanQueryFrameProcessorFactory extends BaseLeafFrameProcessorFactor
         sideChannels,
         new JoinableFactoryWrapper(frameContext.joinableFactory()),
         outputChannelHolder,
-        new LazyResourceHolder<>(() -> Pair.of(
-            FrameWriters.makeFrameWriterFactory(
-                FrameType.ROW_BASED,
-                allocatorHolder.get(),
-                signature,
-                clusterBy.getColumns()
-            ),
-            allocatorHolder
-        )),
+        new LazyResourceHolder<>(() -> Pair.of(frameWriterFactoryHolder.get(), frameWriterFactoryHolder)),
         runningCountForLimit,
         frameContext.memoryParameters().getBroadcastJoinMemory(),
         frameContext.jsonMapper()
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
index 9e44f152eb..4b2668febc 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
@@ -22,10 +22,11 @@ package org.apache.druid.msq.querykit.scan;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.msq.input.stage.StageInputSpec;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.MixShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.ShuffleSpec;
@@ -70,8 +71,8 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
 
   /**
    * We ignore the resultShuffleSpecFactory in case:
-   *  1. There is no cluster by
-   *  2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
+   * 1. There is no cluster by
+   * 2. There is an offset, which means everything gets funneled into a single partition, hence we use MixShuffleSpec

    */
   // No ordering, but there is a limit or an offset. These work by funneling everything through a single partition.
   // So there is no point in forcing any particular partitioning. Since everything is funneled into a single
@@ -90,6 +91,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
     final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
         queryKit,
         queryId,
+        originalQuery.context(),
         originalQuery.getDataSource(),
         originalQuery.getQuerySegmentSpec(),
         originalQuery.getFilter(),
@@ -112,26 +114,26 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
     //  1. There is no cluster by
     //  2. There is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
     if (queryToRun.getOrderBys().isEmpty() && hasLimitOrOffset) {
-      shuffleSpec = new MaxCountShuffleSpec(ClusterBy.none(), 1, false);
+      shuffleSpec = MixShuffleSpec.instance();
       signatureToUse = scanSignature;
     } else {
       final RowSignature.Builder signatureBuilder = RowSignature.builder().addAll(scanSignature);
       final Granularity segmentGranularity =
           QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext());
-      final List<SortColumn> clusterByColumns = new ArrayList<>();
+      final List<KeyColumn> clusterByColumns = new ArrayList<>();
 
       // Add regular orderBys.
       for (final ScanQuery.OrderBy orderBy : queryToRun.getOrderBys()) {
         clusterByColumns.add(
-            new SortColumn(
+            new KeyColumn(
                 orderBy.getColumnName(),
-                orderBy.getOrder() == ScanQuery.Order.DESCENDING
+                orderBy.getOrder() == ScanQuery.Order.DESCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING
             )
         );
       }
 
       // Add partition boosting column.
-      clusterByColumns.add(new SortColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, false));
+      clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
       signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);
 
       final ClusterBy clusterBy =
@@ -159,7 +161,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
                          .inputs(new StageInputSpec(firstStageNumber))
                          .signature(signatureToUse)
                          .maxWorkerCount(1)
-                         .shuffleSpec(new MaxCountShuffleSpec(ClusterBy.none(), 1, false))
+                         .shuffleSpec(MixShuffleSpec.instance())
                          .processorFactory(
                              new OffsetLimitFrameProcessorFactory(
                                  queryToRun.getScanRowsOffset(),
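
For the ScanQueryKit changes above, an illustrative sketch (not part of the patch) of the key-column list now built for a query with "ORDER BY countryName DESC"; the column name is hypothetical:

    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;
    import org.apache.druid.msq.querykit.QueryKitUtils;
    import org.apache.druid.query.scan.ScanQuery;

    import java.util.ArrayList;
    import java.util.List;

    public class ScanClusterByExample
    {
      public static void main(String[] args)
      {
        final List<KeyColumn> clusterByColumns = new ArrayList<>();

        // Each ORDER BY column keeps its direction.
        final ScanQuery.Order order = ScanQuery.Order.DESCENDING;
        clusterByColumns.add(
            new KeyColumn(
                "countryName",
                order == ScanQuery.Order.DESCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING
            )
        );

        // The partition boost column is always appended ascending.
        clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));

        System.out.println(clusterByColumns);
      }
    }
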
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageInputChannelFactory.java
index bc69cdd4b7..1f6c1e44fa 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageInputChannelFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageInputChannelFactory.java
@@ -23,7 +23,6 @@ import com.google.common.base.Preconditions;
 import org.apache.commons.io.IOUtils;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel;
-import org.apache.druid.frame.processor.DurableStorageOutputChannelFactory;
 import org.apache.druid.frame.util.DurableStorageUtils;
 import org.apache.druid.java.util.common.IOE;
 import org.apache.druid.java.util.common.ISE;
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactory.java
similarity index 97%
rename from processing/src/main/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactory.java
rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactory.java
index b691710ded..4505be5846 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactory.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.druid.frame.processor;
+package org.apache.druid.msq.shuffle;
 
 import com.google.common.base.Preconditions;
 import com.google.common.base.Suppliers;
@@ -32,6 +32,9 @@ import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameFileChannel;
 import org.apache.druid.frame.file.FrameFileFooter;
 import org.apache.druid.frame.file.FrameFileWriter;
+import org.apache.druid.frame.processor.OutputChannel;
+import org.apache.druid.frame.processor.OutputChannelFactory;
+import org.apache.druid.frame.processor.PartitionedOutputChannel;
 import org.apache.druid.frame.util.DurableStorageUtils;
 import org.apache.druid.java.util.common.FileUtils;
 import org.apache.druid.java.util.common.ISE;
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java
index 6485f3ab70..d2f017b04d 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java
@@ -26,7 +26,6 @@ import org.apache.druid.msq.indexing.error.MSQWarnings;
 import org.apache.druid.query.QueryContexts;
 
 import javax.annotation.Nullable;
-
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
index eafc03edf1..e2e13e3a59 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
@@ -62,7 +62,6 @@ import org.apache.druid.sql.calcite.table.RowSignatures;
 import org.joda.time.Interval;
 
 import javax.annotation.Nullable;
-
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java
index e62e5d5d10..15c0c0aa79 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java
@@ -108,6 +108,7 @@ public class MSQTaskSqlEngine implements SqlEngine
   {
     switch (feature) {
       case ALLOW_BINDABLE_PLAN:
+      case ALLOW_BROADCAST_RIGHTY_JOIN:
       case TIMESERIES_QUERY:
       case TOPN_QUERY:
       case TIME_BOUNDARY_QUERY:
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index 548235b46b..0fca9a7d98 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -38,7 +38,6 @@ import org.apache.druid.msq.test.MSQTestBase;
 import org.apache.druid.msq.test.MSQTestFileUtils;
 import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.QueryDataSource;
-import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
 import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@@ -62,6 +61,8 @@ import org.apache.druid.sql.SqlPlanningException;
 import org.apache.druid.sql.calcite.expression.DruidExpression;
 import org.apache.druid.sql.calcite.external.ExternalDataSource;
 import org.apache.druid.sql.calcite.filtration.Filtration;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
 import org.apache.druid.sql.calcite.util.CalciteTests;
 import org.hamcrest.CoreMatchers;
@@ -425,8 +426,25 @@ public class MSQSelectTest extends MSQTestBase
   }
 
   @Test
-  public void testJoin()
+  public void testBroadcastJoin()
+  {
+    testJoin(JoinAlgorithm.BROADCAST);
+  }
+
+  @Test
+  public void testSortMergeJoin()
   {
+    testJoin(JoinAlgorithm.SORT_MERGE);
+  }
+
+  private void testJoin(final JoinAlgorithm joinAlgorithm)
+  {
+    final Map<String, Object> queryContext =
+        ImmutableMap.<String, Object>builder()
+                    .putAll(context)
+                    .put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, joinAlgorithm.toString())
+                    .build();
+
     final RowSignature resultSignature = RowSignature.builder()
                                                      .add("dim2", ColumnType.STRING)
                                                      .add("EXPR$1", ColumnType.DOUBLE)
@@ -460,7 +478,7 @@ public class MSQSelectTest extends MSQTestBase
                                     .columns("dim2", "m1", "m2")
                                     .context(
                                         defaultScanQueryContext(
-                                            context,
+                                            queryContext,
                                             RowSignature.builder()
                                                         .add("dim2", ColumnType.STRING)
                                                         .add("m1", ColumnType.FLOAT)
@@ -470,6 +488,7 @@ public class MSQSelectTest extends MSQTestBase
                                     )
                                     .limit(10)
                                     .build()
+                                    .withOverriddenContext(queryContext)
                             ),
                             new QueryDataSource(
                                 newScanQueryBuilder()
@@ -479,11 +498,12 @@ public class MSQSelectTest extends MSQTestBase
                                     .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
                                     .context(
                                         defaultScanQueryContext(
-                                            context,
+                                            queryContext,
                                             RowSignature.builder().add("m1", ColumnType.FLOAT).build()
                                         )
                                     )
                                     .build()
+                                    .withOverriddenContext(queryContext)
                             ),
                             "j0.",
                             equalsCondition(
@@ -525,10 +545,9 @@ public class MSQSelectTest extends MSQTestBase
                                     new FieldAccessPostAggregator(null, "a0:count")
                                 )
                             )
-
                         )
                     )
-                    .setContext(context)
+                    .setContext(queryContext)
                     .build();
 
     testSelectQuery()
@@ -542,164 +561,23 @@ public class MSQSelectTest extends MSQTestBase
         .setExpectedMSQSpec(
             MSQSpec.builder()
                    .query(query)
-                   .columnMappings(new ColumnMappings(ImmutableList.of(
-                       new ColumnMapping("d0", "dim2"),
-                       new ColumnMapping("a0", "EXPR$1")
-                   )))
-                   .tuningConfig(MSQTuningConfig.defaultConfig())
-                   .build()
-        )
-        .setExpectedRowSignature(resultSignature)
-        .setExpectedResultRows(expectedResults)
-        .setQueryContext(context)
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().totalFiles(1),
-            0, 0, "input0"
-        )
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().rows(6).frames(1),
-            0, 0, "output"
-        )
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().rows(6).frames(1),
-            0, 0, "shuffle"
-        )
-        .verifyResults();
-  }
-
-  @Test
-  public void testBroadcastJoin()
-  {
-    final RowSignature resultSignature = RowSignature.builder()
-                                                     .add("dim2", ColumnType.STRING)
-                                                     .add("EXPR$1", ColumnType.DOUBLE)
-                                                     .build();
-
-    final ImmutableList<Object[]> expectedResults;
-
-    if (NullHandling.sqlCompatible()) {
-      expectedResults = ImmutableList.of(
-          new Object[]{null, 4.0},
-          new Object[]{"", 3.0},
-          new Object[]{"a", 2.5},
-          new Object[]{"abc", 5.0}
-      );
-    } else {
-      expectedResults = ImmutableList.of(
-          new Object[]{null, 3.6666666666666665},
-          new Object[]{"a", 2.5},
-          new Object[]{"abc", 5.0}
-      );
-    }
-
-    final GroupByQuery query =
-        GroupByQuery.builder()
-                    .setDataSource(
-                        join(
-                            new TableDataSource(CalciteTests.DATASOURCE1),
-                            new QueryDataSource(
-                                newScanQueryBuilder()
-                                    .dataSource(CalciteTests.DATASOURCE1)
-                                    .intervals(querySegmentSpec(Filtration.eternity()))
-                                    .columns("dim2", "m1", "m2")
-                                    .context(
-                                        defaultScanQueryContext(
-                                            context,
-                                            RowSignature.builder()
-                                                        .add("dim2", ColumnType.STRING)
-                                                        .add("m1", ColumnType.FLOAT)
-                                                        .add("m2", ColumnType.DOUBLE)
-                                                        .build()
-                                        )
-                                    )
-                                    .limit(10)
-                                    .build()
-                            ),
-                            "j0.",
-                            equalsCondition(
-                                DruidExpression.ofColumn(ColumnType.FLOAT, "m1"),
-                                DruidExpression.ofColumn(ColumnType.FLOAT, "j0.m1")
-                            ),
-                            JoinType.INNER
-                        )
-                    )
-                    .setInterval(querySegmentSpec(Filtration.eternity()))
-                    .setDimensions(new DefaultDimensionSpec("j0.dim2", "d0", ColumnType.STRING))
-                    .setGranularity(Granularities.ALL)
-                    .setAggregatorSpecs(
-                        useDefault
-                        ? aggregators(
-                            new DoubleSumAggregatorFactory("a0:sum", "j0.m2"),
-                            new CountAggregatorFactory("a0:count")
-                        )
-                        : aggregators(
-                            new DoubleSumAggregatorFactory("a0:sum", "j0.m2"),
-                            new FilteredAggregatorFactory(
-                                new CountAggregatorFactory("a0:count"),
-                                not(selector("j0.m2", null, null)),
-
-                                // Not sure why the name is only set in SQL-compatible null mode. Seems strange.
-                                // May be due to JSON serialization: name is set on the serialized aggregator even
-                                // if it was originally created with no name.
-                                NullHandling.sqlCompatible() ? "a0:count" : null
-                            )
-                        )
-                    )
-                    .setPostAggregatorSpecs(
-                        ImmutableList.of(
-                            new ArithmeticPostAggregator(
-                                "a0",
-                                "quotient",
-                                ImmutableList.of(
-                                    new FieldAccessPostAggregator(null, "a0:sum"),
-                                    new FieldAccessPostAggregator(null, "a0:count")
-                                )
-                            )
-
-                        )
-                    )
-                    .setContext(context)
-                    .build();
-
-    testSelectQuery()
-        .setSql(
-            "SELECT t1.dim2, AVG(t1.m2) FROM "
-            + "foo "
-            + "INNER JOIN (SELECT * FROM foo LIMIT 10) AS t1 "
-            + "ON t1.m1 = foo.m1 "
-            + "GROUP BY t1.dim2"
-        )
-        .setExpectedMSQSpec(
-            MSQSpec.builder()
-                   .query(query)
-                   .columnMappings(new ColumnMappings(ImmutableList.of(
-                       new ColumnMapping("d0", "dim2"),
-                       new ColumnMapping("a0", "EXPR$1")
-                   )))
+                   .columnMappings(
+                       new ColumnMappings(
+                           ImmutableList.of(
+                               new ColumnMapping("d0", "dim2"),
+                               new ColumnMapping("a0", "EXPR$1")
+                           )
+                       )
+                   )
                    .tuningConfig(MSQTuningConfig.defaultConfig())
                    .build()
         )
         .setExpectedRowSignature(resultSignature)
         .setExpectedResultRows(expectedResults)
-        .setQueryContext(context)
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().totalFiles(1),
-            0, 0, "input0"
-        )
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().rows(6).frames(1),
-            0, 0, "output"
-        )
-        .setExpectedCountersForStageWorkerChannel(
-            CounterSnapshotMatcher
-                .with().rows(6).frames(1),
-            0, 0, "shuffle"
-        )
+        .setQueryContext(queryContext)
+        .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().totalFiles(1), 0, 0, "input0")
+        .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().rows(6).frames(1), 0, 0, "output")
+        .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().rows(6).frames(1), 0, 0, "shuffle")
         .verifyResults();
   }
 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
index a5ec9e48ca..f8f025ccf2 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
@@ -32,19 +32,19 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_oneWorkerInJvm_alone()
   {
-    Assert.assertEquals(parameters(1, 41, 224_785_000, 100_650_000, 75_000_000), compute(1_000_000_000, 1, 1, 1, 0));
-    Assert.assertEquals(parameters(2, 13, 149_410_000, 66_900_000, 75_000_000), compute(1_000_000_000, 1, 2, 1, 0));
-    Assert.assertEquals(parameters(4, 3, 89_110_000, 39_900_000, 75_000_000), compute(1_000_000_000, 1, 4, 1, 0));
-    Assert.assertEquals(parameters(3, 2, 48_910_000, 21_900_000, 75_000_000), compute(1_000_000_000, 1, 8, 1, 0));
-    Assert.assertEquals(parameters(2, 2, 33_448_460, 14_976_922, 75_000_000), compute(1_000_000_000, 1, 12, 1, 0));
+    Assert.assertEquals(params(1, 41, 224_785_000, 100_650_000, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0));
+    Assert.assertEquals(params(2, 13, 149_410_000, 66_900_000, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0));
+    Assert.assertEquals(params(4, 3, 89_110_000, 39_900_000, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0));
+    Assert.assertEquals(params(3, 2, 48_910_000, 21_900_000, 75_000_000), create(1_000_000_000, 1, 8, 1, 0, 0));
+    Assert.assertEquals(params(2, 2, 33_448_460, 14_976_922, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
-        () -> compute(1_000_000_000, 1, 32, 1, 0)
+        () -> create(1_000_000_000, 1, 32, 1, 0, 0)
     );
     Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault());
 
-    final MSQFault fault = Assert.assertThrows(MSQException.class, () -> compute(1_000_000_000, 2, 32, 1, 0))
+    final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0))
                                  .getFault();
 
     Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault);
@@ -54,12 +54,12 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_oneWorkerInJvm_twoHundredWorkersInCluster()
   {
-    Assert.assertEquals(parameters(1, 83, 317_580_000, 142_200_000, 150_000_000), compute(2_000_000_000, 1, 1, 200, 0));
-    Assert.assertEquals(parameters(2, 27, 166_830_000, 74_700_000, 150_000_000), compute(2_000_000_000, 1, 2, 200, 0));
+    Assert.assertEquals(params(1, 83, 317_580_000, 142_200_000, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0));
+    Assert.assertEquals(params(2, 27, 166_830_000, 74_700_000, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
-        () -> compute(1_000_000_000, 1, 4, 200, 0)
+        () -> create(1_000_000_000, 1, 4, 200, 0, 0)
     );
 
     Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault());
@@ -68,39 +68,69 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_fourWorkersInJvm_twoHundredWorkersInCluster()
   {
-    Assert.assertEquals(
-        parameters(1, 150, 679_380_000, 304_200_000, 168_750_000),
-        compute(9_000_000_000L, 4, 1, 200, 0)
+    Assert.assertEquals(params(1, 150, 679_380_000, 304_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0));
+    Assert.assertEquals(params(2, 62, 543_705_000, 243_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0));
+    Assert.assertEquals(params(4, 22, 374_111_250, 167_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0));
+    Assert.assertEquals(params(4, 14, 204_517_500, 91_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0));
+    Assert.assertEquals(params(4, 8, 68_842_500, 30_825_000, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0));
+
+    final MSQException e = Assert.assertThrows(
+        MSQException.class,
+        () -> create(8_000_000_000L, 4, 32, 200, 0, 0)
     );
-    Assert.assertEquals(
-        parameters(2, 62, 543_705_000, 243_450_000, 168_750_000),
-        compute(9_000_000_000L, 4, 2, 200, 0)
+
+    Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault());
+
+    // Make sure 124 actually works, and 125 doesn't. (Verify the error message above.)
+    Assert.assertEquals(params(4, 3, 16_750_000, 7_500_000, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0));
+
+    final MSQException e2 = Assert.assertThrows(
+        MSQException.class,
+        () -> create(8_000_000_000L, 4, 32, 125, 0, 0)
     );
+
+    Assert.assertEquals(new TooManyWorkersFault(125, 124), e2.getFault());
+  }
+
+  @Test
+  public void test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions()
+  {
     Assert.assertEquals(
-        parameters(4, 22, 374_111_250, 167_512_500, 168_750_000),
-        compute(9_000_000_000L, 4, 4, 200, 0)
-    );
-    Assert.assertEquals(parameters(4, 14, 204_517_500, 91_575_000, 168_750_000), compute(9_000_000_000L, 4, 8, 200, 0));
-    Assert.assertEquals(parameters(4, 8, 68_842_500, 30_825_000, 168_750_000), compute(9_000_000_000L, 4, 16, 200, 0));
+        params(1, 150, 545_380_000, 244_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0));
+    Assert.assertEquals(
+        params(2, 62, 409_705_000, 183_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0));
+    Assert.assertEquals(
+        params(4, 22, 240_111_250, 107_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0));
+    Assert.assertEquals(
+        params(4, 14, 70_517_500, 31_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
-        () -> compute(8_000_000_000L, 4, 32, 200, 0)
+        () -> create(9_000_000_000L, 4, 16, 200, 200, 0)
     );
 
-    Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault());
+    Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault());
+
+    // Make sure 138 actually works, and 139 doesn't. (Verify the error message above.)
+    Assert.assertEquals(params(4, 8, 17_922_500, 8_025_000, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0));
 
-    // Make sure 107 actually works. (Verify the error message above.)
-    Assert.assertEquals(parameters(4, 3, 28_140_000, 12_600_000, 150_000_000), compute(8_000_000_000L, 4, 32, 107, 0));
+    final MSQException e2 = Assert.assertThrows(
+        MSQException.class,
+        () -> create(9_000_000_000L, 4, 16, 139, 139, 0)
+    );
+
+    Assert.assertEquals(new TooManyWorkersFault(139, 138), e2.getFault());
   }
 
   @Test
-  public void test_oneWorkerInJvm_negativeUsableMemory()
+  public void test_oneWorkerInJvm_oneByteUsableMemory()
   {
-    Exception e = Assert.assertThrows(
-        IllegalArgumentException.class,
-        () -> WorkerMemoryParameters.createInstance(100, -50, 1, 32, 1)
+    final MSQException e = Assert.assertThrows(
+        MSQException.class,
+        () -> WorkerMemoryParameters.createInstance(1, 1, 1, 32, 1, 1)
     );
+
+    Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1), e.getFault());
   }
 
   @Test
@@ -109,7 +139,7 @@ public class WorkerMemoryParametersTest
     EqualsVerifier.forClass(WorkerMemoryParameters.class).usingGetClass().verify();
   }
 
-  private static WorkerMemoryParameters parameters(
+  private static WorkerMemoryParameters params(
       final int superSorterMaxActiveProcessors,
       final int superSorterMaxChannelsPerProcessor,
       final long appenderatorMemory,
@@ -126,11 +156,12 @@ public class WorkerMemoryParametersTest
     );
   }
 
-  private static WorkerMemoryParameters compute(
+  private static WorkerMemoryParameters create(
       final long maxMemoryInJvm,
       final int numWorkersInJvm,
       final int numProcessingThreadsInJvm,
       final int numInputWorkers,
+      final int numHashOutputPartitions,
       final int totalLookUpFootprint
   )
   {
@@ -139,6 +170,7 @@ public class WorkerMemoryParametersTest
         numWorkersInJvm,
         numProcessingThreadsInJvm,
         numInputWorkers,
+        numHashOutputPartitions,
         totalLookUpFootprint
     );
   }
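
For orientation, here is a hedged sketch of what the new numHashOutputPartitions argument to create() implies, based on the commit description ("we need one frame per partition" for hash partitioning). It is illustrative pseudologic under assumed names (bundleMemory, standardFrameSize), not the actual WorkerMemoryParameters.createInstance computation:

    // Illustrative only: the real sizing logic lives in WorkerMemoryParameters.
    // Hash shuffles need one writable output frame per output partition, so that
    // memory is set aside before the remainder is split among sort/processing buffers.
    static long memoryAfterHashOutputFrames(
        final long bundleMemory,            // assumed: memory available to one work bundle
        final int numHashOutputPartitions,  // 0 when the stage does not hash-partition
        final long standardFrameSize        // assumed: size of one writable frame
    )
    {
      return bundleMemory - (long) numHashOutputPartitions * standardFrameSize;
    }

This is why the hash-partition test cases above report smaller appenderator and broadcast-join budgets than their non-hash counterparts for the same JVM memory.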
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java
index aae47170a2..435c5de065 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java
@@ -31,6 +31,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.util.Arrays;
 
 public class MSQFaultSerdeTest
 {
@@ -69,6 +70,7 @@ public class MSQFaultSerdeTest
     assertFaultSerde(new TooManyClusteredByColumnsFault(10, 8, 1));
     assertFaultSerde(new TooManyInputFilesFault(15, 10, 5));
     assertFaultSerde(new TooManyPartitionsFault(10));
+    assertFaultSerde(new TooManyRowsWithSameKeyFault(Arrays.asList("foo", 123), 1, 2));
     assertFaultSerde(new TooManyWarningsFault(10, "the error"));
     assertFaultSerde(new TooManyWorkersFault(10, 5));
     assertFaultSerde(new TooManyAttemptsForWorker(2, "taskId", 1, "rootError"));
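
As a reading aid, a hedged sketch of the round-trip pattern that assertFaultSerde presumably relies on for the newly added TooManyRowsWithSameKeyFault; the helper body below is an assumption for illustration, not the test's actual implementation:

    // Assumed shape: serialize the fault, read it back polymorphically, compare.
    private static void assertFaultSerdeSketch(final ObjectMapper mapper, final MSQFault fault) throws IOException
    {
      final String json = mapper.writeValueAsString(fault);
      final MSQFault roundTripped = mapper.readValue(json, MSQFault.class);
      Assert.assertEquals(fault, roundTripped);
    }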
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java
index a0eeec58fb..366e9c5099 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java
@@ -24,7 +24,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter;
 import org.apache.druid.indexing.common.TaskReport;
@@ -35,7 +36,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.msq.indexing.error.MSQErrorReport;
 import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
@@ -65,8 +66,8 @@ public class MSQTaskReportTest
                   .builder(0)
                   .processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L))
                   .shuffleSpec(
-                      new MaxCountShuffleSpec(
-                          new ClusterBy(ImmutableList.of(new SortColumn("s", false)), 0),
+                      new GlobalSortMaxCountShuffleSpec(
+                          new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0),
                           2,
                           false
                       )
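
A before/after note on the key-column migration these test updates exercise; both forms are taken from the removed and added lines above, and the use of KeyOrder.NONE for hash keys is per the commit description:

    // Before this commit: ordering was a boolean "descending" flag on SortColumn.
    //   new ClusterBy(ImmutableList.of(new SortColumn("s", false)), 0)
    // After this commit: ordering is an explicit KeyOrder on KeyColumn, which also
    // lets hash keys use KeyOrder.NONE instead of overloading the boolean.
    new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0);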
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java
index fc97fcd708..8a5533d22c 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java
@@ -23,7 +23,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
 import org.apache.druid.segment.TestHelper;
@@ -45,8 +46,8 @@ public class QueryDefinitionTest
                     .builder(0)
                     .processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L))
                     .shuffleSpec(
-                        new MaxCountShuffleSpec(
-                            new ClusterBy(ImmutableList.of(new SortColumn("s", false)), 0),
+                        new GlobalSortMaxCountShuffleSpec(
+                            new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0),
                             2,
                             false
                         )
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/StageDefinitionTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/StageDefinitionTest.java
index 2257f04364..793c4293f1 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/StageDefinitionTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/StageDefinitionTest.java
@@ -23,7 +23,8 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.msq.exec.Limits;
 import org.apache.druid.msq.input.stage.StageInputSpec;
@@ -60,7 +61,7 @@ public class StageDefinitionTest
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
 
-    Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null));
+    Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionBoundariesForShuffle(null));
   }
 
   @Test
@@ -72,16 +73,19 @@ public class StageDefinitionTest
         ImmutableSet.of(),
         new OffsetLimitFrameProcessorFactory(0, 1L),
         RowSignature.empty(),
-        new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 1), 2, false),
+        new GlobalSortMaxCountShuffleSpec(
+            new ClusterBy(ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)), 0),
+            2,
+            false
+        ),
         1,
         false,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
 
-    Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null));
+    Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionBoundariesForShuffle(null));
   }
 
-
   @Test
   public void testGeneratePartitionsForNonNullShuffleWithNonNullCollector()
   {
@@ -91,7 +95,11 @@ public class StageDefinitionTest
         ImmutableSet.of(),
         new OffsetLimitFrameProcessorFactory(0, 1L),
         RowSignature.empty(),
-        new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 0), 1, false),
+        new GlobalSortMaxCountShuffleSpec(
+            new ClusterBy(ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)), 0),
+            1,
+            false
+        ),
         1,
         false,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
@@ -99,8 +107,8 @@ public class StageDefinitionTest
 
     Assert.assertThrows(
         ISE.class,
-        () -> stageDefinition.generatePartitionsForShuffle(ClusterByStatisticsCollectorImpl.create(new ClusterBy(
-            ImmutableList.of(new SortColumn("test", false)),
+        () -> stageDefinition.generatePartitionBoundariesForShuffle(ClusterByStatisticsCollectorImpl.create(new ClusterBy(
+            ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)),
             1
         ), RowSignature.builder().add("test", ColumnType.STRING).build(), 1000, 100, false, false))
     );
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java
index bd8a473e17..f16e35e6e2 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java
@@ -22,13 +22,14 @@ package org.apache.druid.msq.kernel.controller;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.msq.input.InputSpec;
 import org.apache.druid.msq.input.stage.StageInputSpec;
 import org.apache.druid.msq.kernel.FrameProcessorFactory;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.ShuffleSpec;
@@ -115,10 +116,10 @@ public class MockQueryDefinitionBuilder
     ShuffleSpec shuffleSpec;
 
     if (shuffling) {
-      shuffleSpec = new MaxCountShuffleSpec(
+      shuffleSpec = new GlobalSortMaxCountShuffleSpec(
           new ClusterBy(
               ImmutableList.of(
-                  new SortColumn(SHUFFLE_KEY_COLUMN, false)
+                  new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING)
               ),
               0
           ),
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
new file mode 100644
index 0000000000..cfc74d792f
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
@@ -0,0 +1,1080 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.common;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
+import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
+import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
+import org.apache.druid.frame.channel.ReadableFrameChannel;
+import org.apache.druid.frame.channel.ReadableNilFrameChannel;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
+import org.apache.druid.frame.processor.FrameProcessorExecutor;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.segment.FrameStorageAdapter;
+import org.apache.druid.frame.testutil.FrameSequenceBuilder;
+import org.apache.druid.frame.testutil.FrameTestUtil;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.msq.input.ReadableInput;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.kernel.StagePartition;
+import org.apache.druid.msq.test.LimitedFrameWriterFactory;
+import org.apache.druid.segment.RowBasedSegment;
+import org.apache.druid.segment.StorageAdapter;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.join.JoinTestHelper;
+import org.apache.druid.segment.join.JoinType;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.apache.druid.timeline.SegmentId;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+@RunWith(Parameterized.class)
+public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest
+{
+  private static final StagePartition STAGE_PARTITION = new StagePartition(new StageId("q", 0), 0);
+
+  private final int rowsPerInputFrame;
+  private final int rowsPerOutputFrame;
+
+  private FrameProcessorExecutor exec;
+
+  @Rule
+  public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+  public SortMergeJoinFrameProcessorTest(int rowsPerInputFrame, int rowsPerOutputFrame)
+  {
+    this.rowsPerInputFrame = rowsPerInputFrame;
+    this.rowsPerOutputFrame = rowsPerOutputFrame;
+  }
+
+  @Parameterized.Parameters(name = "rowsPerInputFrame = {0}, rowsPerOutputFrame = {1}")
+  public static Iterable<Object[]> constructorFeeder()
+  {
+    final List<Object[]> constructors = new ArrayList<>();
+
+    for (final int rowsPerInputFrame : new int[]{1, 2, 7, Integer.MAX_VALUE}) {
+      for (final int rowsPerOutputFrame : new int[]{1, 2, 7, Integer.MAX_VALUE}) {
+        constructors.add(new Object[]{rowsPerInputFrame, rowsPerOutputFrame});
+      }
+    }
+
+    return constructors;
+  }
+
+  @Before
+  public void setUp()
+  {
+    exec = new FrameProcessorExecutor(MoreExecutors.listeningDecorator(Execs.singleThreaded("test-exec")));
+  }
+
+  @After
+  public void tearDown() throws Exception
+  {
+    exec.getExecutorService().shutdownNow();
+    exec.getExecutorService().awaitTermination(10, TimeUnit.MINUTES);
+  }
+
+  @Test
+  public void testLeftJoinEmptyLeftSide() throws Exception
+  {
+    final ReadableInput factChannel = ReadableInput.channel(
+        ReadableNilFrameChannel.INSTANCE,
+        FrameReader.create(JoinTestHelper.FACT_SIGNATURE),
+        STAGE_PARTITION
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.LEFT
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList());
+  }
+
+  @Test
+  public void testLeftJoinEmptyRightSide() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel = ReadableInput.channel(
+        ReadableNilFrameChannel.INSTANCE,
+        FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE),
+        STAGE_PARTITION
+    );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.LEFT
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Agama mossambica", null, null, null, null),
+        Arrays.asList("Apamea abruzzorum", null, null, null, null),
+        Arrays.asList("Atractus flammigerus", null, null, null, null),
+        Arrays.asList("Rallicula", null, null, null, null),
+        Arrays.asList("Talk:Oswald Tilghman", null, null, null, null),
+        Arrays.asList("Peremptory norm", "AU", null, null, null),
+        Arrays.asList("Didier Leclair", "CA", null, null, null),
+        Arrays.asList("Les Argonautes", "CA", null, null, null),
+        Arrays.asList("Sarah Michelle Gellar", "CA", null, null, null),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", null, null, null),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", null, null, null),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", null, null, null),
+        Arrays.asList("Saison 9 de Secret Story", "FR", null, null, null),
+        Arrays.asList("Glasgow", "GB", null, null, null),
+        Arrays.asList("Giusy Ferreri discography", "IT", null, null, null),
+        Arrays.asList("Roma-Bangkok", "IT", null, null, null),
+        Arrays.asList("青野武", "JP", null, null, null),
+        Arrays.asList("유희왕 GX", "KR", null, null, null),
+        Arrays.asList("History of Fourems", "MMMM", null, null, null),
+        Arrays.asList("Mathis Bolly", "MX", null, null, null),
+        Arrays.asList("Orange Soda", "MatchNothing", null, null, null),
+        Arrays.asList("Алиса в Зазеркалье", "NO", null, null, null),
+        Arrays.asList("Cream Soda", "SU", null, null, null),
+        Arrays.asList("Wendigo", "SV", null, null, null),
+        Arrays.asList("Carlo Curti", "US", null, null, null),
+        Arrays.asList("DirecTV", "US", null, null, null),
+        Arrays.asList("Old Anatolian Turkish", "US", null, null, null),
+        Arrays.asList("Otjiwarongo Airport", "US", null, null, null),
+        Arrays.asList("President of India", "US", null, null, null)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testInnerJoinEmptyRightSide() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel = ReadableInput.channel(
+        ReadableNilFrameChannel.INSTANCE,
+        FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE),
+        STAGE_PARTITION
+    );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList());
+  }
+
+  @Test
+  public void testLeftJoinCountryIsoCode() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.LEFT
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Agama mossambica", null, null, null, null),
+        Arrays.asList("Apamea abruzzorum", null, null, null, null),
+        Arrays.asList("Atractus flammigerus", null, null, null, null),
+        Arrays.asList("Rallicula", null, null, null, null),
+        Arrays.asList("Talk:Oswald Tilghman", null, null, null, null),
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Orange Soda", "MatchNothing", null, null, null),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testCrossJoin() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel = makeChannelFromResourceWithLimit(
+        JoinTestHelper.COUNTRIES_RESOURCE,
+        JoinTestHelper.COUNTRIES_SIGNATURE,
+        ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+        2
+    );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        countriesChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(Collections.emptyList(), Collections.emptyList()),
+        JoinType.INNER
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Agama mossambica", "AU"),
+        Arrays.asList("Agama mossambica", "CA"),
+        Arrays.asList("Apamea abruzzorum", "AU"),
+        Arrays.asList("Apamea abruzzorum", "CA"),
+        Arrays.asList("Atractus flammigerus", "AU"),
+        Arrays.asList("Atractus flammigerus", "CA"),
+        Arrays.asList("Rallicula", "AU"),
+        Arrays.asList("Rallicula", "CA"),
+        Arrays.asList("Talk:Oswald Tilghman", "AU"),
+        Arrays.asList("Talk:Oswald Tilghman", "CA"),
+        Arrays.asList("Peremptory norm", "AU"),
+        Arrays.asList("Peremptory norm", "CA"),
+        Arrays.asList("Didier Leclair", "AU"),
+        Arrays.asList("Didier Leclair", "CA"),
+        Arrays.asList("Les Argonautes", "AU"),
+        Arrays.asList("Les Argonautes", "CA"),
+        Arrays.asList("Sarah Michelle Gellar", "AU"),
+        Arrays.asList("Sarah Michelle Gellar", "CA"),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "AU"),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CA"),
+        Arrays.asList("Diskussion:Sebastian Schulz", "AU"),
+        Arrays.asList("Diskussion:Sebastian Schulz", "CA"),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "AU"),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "CA"),
+        Arrays.asList("Saison 9 de Secret Story", "AU"),
+        Arrays.asList("Saison 9 de Secret Story", "CA"),
+        Arrays.asList("Glasgow", "AU"),
+        Arrays.asList("Glasgow", "CA"),
+        Arrays.asList("Giusy Ferreri discography", "AU"),
+        Arrays.asList("Giusy Ferreri discography", "CA"),
+        Arrays.asList("Roma-Bangkok", "AU"),
+        Arrays.asList("Roma-Bangkok", "CA"),
+        Arrays.asList("青野武", "AU"),
+        Arrays.asList("青野武", "CA"),
+        Arrays.asList("유희왕 GX", "AU"),
+        Arrays.asList("유희왕 GX", "CA"),
+        Arrays.asList("History of Fourems", "AU"),
+        Arrays.asList("History of Fourems", "CA"),
+        Arrays.asList("Mathis Bolly", "AU"),
+        Arrays.asList("Mathis Bolly", "CA"),
+        Arrays.asList("Orange Soda", "AU"),
+        Arrays.asList("Orange Soda", "CA"),
+        Arrays.asList("Алиса в Зазеркалье", "AU"),
+        Arrays.asList("Алиса в Зазеркалье", "CA"),
+        Arrays.asList("Cream Soda", "AU"),
+        Arrays.asList("Cream Soda", "CA"),
+        Arrays.asList("Wendigo", "AU"),
+        Arrays.asList("Wendigo", "CA"),
+        Arrays.asList("Carlo Curti", "AU"),
+        Arrays.asList("Carlo Curti", "CA"),
+        Arrays.asList("DirecTV", "AU"),
+        Arrays.asList("DirecTV", "CA"),
+        Arrays.asList("Old Anatolian Turkish", "AU"),
+        Arrays.asList("Old Anatolian Turkish", "CA"),
+        Arrays.asList("Otjiwarongo Airport", "AU"),
+        Arrays.asList("Otjiwarongo Airport", "CA"),
+        Arrays.asList("President of India", "AU"),
+        Arrays.asList("President of India", "CA")
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testLeftJoinRegions() throws Exception
+  {
+    final ReadableInput factChannel =
+        buildFactInput(
+            ImmutableList.of(
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("page", KeyOrder.ASCENDING)
+            )
+        );
+
+    final ReadableInput regionsChannel =
+        buildRegionsInput(
+            ImmutableList.of(
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)
+            )
+        );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("j0.regionName", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        regionsChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)
+            ),
+            ImmutableList.of(
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)
+            )
+        ),
+        JoinType.LEFT
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Agama mossambica", null, null),
+        Arrays.asList("Apamea abruzzorum", null, null),
+        Arrays.asList("Atractus flammigerus", null, null),
+        Arrays.asList("Rallicula", null, null),
+        Arrays.asList("Talk:Oswald Tilghman", null, null),
+        Arrays.asList("Peremptory norm", "New South Wales", "AU"),
+        Arrays.asList("Didier Leclair", "Ontario", "CA"),
+        Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"),
+        Arrays.asList("Les Argonautes", "Quebec", "CA"),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"),
+        Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "EC"),
+        Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"),
+        Arrays.asList("Glasgow", "Kingston upon Hull", "GB"),
+        Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"),
+        Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"),
+        Arrays.asList("青野武", "Tōkyō", "JP"),
+        Arrays.asList("유희왕 GX", "Seoul", "KR"),
+        Arrays.asList("History of Fourems", "Fourems Province", "MMMM"),
+        Arrays.asList("Mathis Bolly", "Mexico City", "MX"),
+        Arrays.asList("Orange Soda", null, "MatchNothing"),
+        Arrays.asList("Алиса в Зазеркалье", "Finnmark Fylke", "NO"),
+        Arrays.asList("Cream Soda", "Ainigriv", "SU"),
+        Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"),
+        Arrays.asList("Carlo Curti", "California", "US"),
+        Arrays.asList("Otjiwarongo Airport", "California", "US"),
+        Arrays.asList("President of India", "California", "US"),
+        Arrays.asList("DirecTV", "North Carolina", "US"),
+        Arrays.asList("Old Anatolian Turkish", "Virginia", "US")
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testRightJoinRegionCodeOnly() throws Exception
+  {
+    // This join generates duplicates.
+
+    final ReadableInput factChannel =
+        buildFactInput(
+            ImmutableList.of(
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("page", KeyOrder.ASCENDING)
+            )
+        );
+
+    final ReadableInput regionsChannel =
+        buildRegionsInput(
+            ImmutableList.of(
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)
+            )
+        );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("regionName", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        regionsChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.RIGHT
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Agama mossambica", null, null),
+        Arrays.asList("Apamea abruzzorum", null, null),
+        Arrays.asList("Atractus flammigerus", null, null),
+        Arrays.asList("Rallicula", null, null),
+        Arrays.asList("Talk:Oswald Tilghman", null, null),
+        Arrays.asList("유희왕 GX", "Seoul", "KR"),
+        Arrays.asList("青野武", "Tōkyō", "JP"),
+        Arrays.asList("Алиса в Зазеркалье", "Finnmark Fylke", "NO"),
+        Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"),
+        Arrays.asList("Cream Soda", "Ainigriv", "SU"),
+        Arrays.asList("Carlo Curti", "California", "US"),
+        Arrays.asList("Otjiwarongo Airport", "California", "US"),
+        Arrays.asList("President of India", "California", "US"),
+        Arrays.asList("Mathis Bolly", "Mexico City", "MX"),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "EC"),
+        Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"),
+        Arrays.asList("Glasgow", "Kingston upon Hull", "GB"),
+        Arrays.asList("History of Fourems", "Fourems Province", "MMMM"),
+        Arrays.asList("Orange Soda", null, "MatchNothing"),
+        Arrays.asList("DirecTV", "North Carolina", "US"),
+        Arrays.asList("Peremptory norm", "New South Wales", "AU"),
+        Arrays.asList("Didier Leclair", "Ontario", "CA"),
+        Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"),
+        Arrays.asList("Les Argonautes", "Quebec", "CA"),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"),
+        Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"),
+        Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"),
+        Arrays.asList("Giusy Ferreri discography", "Virginia", "IT"),
+        Arrays.asList("Old Anatolian Turkish", "Provincia di Varese", "US"),
+        Arrays.asList("Old Anatolian Turkish", "Virginia", "US"),
+        Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"),
+        Arrays.asList("Roma-Bangkok", "Virginia", "IT")
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testFullOuterJoinRegionCodeOnly() throws Exception
+  {
+    // This join generates duplicates.
+
+    final ReadableInput factChannel =
+        buildFactInput(
+            ImmutableList.of(
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("page", KeyOrder.ASCENDING)
+            )
+        );
+
+    final ReadableInput regionsChannel =
+        buildRegionsInput(
+            ImmutableList.of(
+                new KeyColumn("regionIsoCode", KeyOrder.ASCENDING),
+                new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)
+            )
+        );
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("regionName", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        regionsChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.FULL
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList(null, "Nulland", null),
+        Arrays.asList("Agama mossambica", null, null),
+        Arrays.asList("Apamea abruzzorum", null, null),
+        Arrays.asList("Atractus flammigerus", null, null),
+        Arrays.asList("Rallicula", null, null),
+        Arrays.asList("Talk:Oswald Tilghman", null, null),
+        Arrays.asList("유희왕 GX", "Seoul", "KR"),
+        Arrays.asList("青野武", "Tōkyō", "JP"),
+        Arrays.asList("Алиса в Зазеркалье", "Finnmark Fylke", "NO"),
+        Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"),
+        Arrays.asList(null, "Foureis Province", null),
+        Arrays.asList("Cream Soda", "Ainigriv", "SU"),
+        Arrays.asList("Carlo Curti", "California", "US"),
+        Arrays.asList("Otjiwarongo Airport", "California", "US"),
+        Arrays.asList("President of India", "California", "US"),
+        Arrays.asList("Mathis Bolly", "Mexico City", "MX"),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "EC"),
+        Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"),
+        Arrays.asList("Glasgow", "Kingston upon Hull", "GB"),
+        Arrays.asList("History of Fourems", "Fourems Province", "MMMM"),
+        Arrays.asList("Orange Soda", null, "MatchNothing"),
+        Arrays.asList("DirecTV", "North Carolina", "US"),
+        Arrays.asList("Peremptory norm", "New South Wales", "AU"),
+        Arrays.asList("Didier Leclair", "Ontario", "CA"),
+        Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"),
+        Arrays.asList("Les Argonautes", "Quebec", "CA"),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"),
+        Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"),
+        Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"),
+        Arrays.asList("Giusy Ferreri discography", "Virginia", "IT"),
+        Arrays.asList("Old Anatolian Turkish", "Provincia di Varese", "US"),
+        Arrays.asList("Old Anatolian Turkish", "Virginia", "US"),
+        Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"),
+        Arrays.asList("Roma-Bangkok", "Virginia", "IT"),
+        Arrays.asList(null, "Usca City", null)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testLeftJoinCountryNumber() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryNumber", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING))
+        ),
+        JoinType.LEFT
+    );
+
+    final String countryCodeForNull;
+    final String countryNameForNull;
+    final Long countryNumberForNull;
+
+    if (NullHandling.sqlCompatible()) {
+      countryCodeForNull = null;
+      countryNameForNull = null;
+      countryNumberForNull = null;
+    } else {
+      // In default-value mode, null country number from the left-hand table converts to zero, which matches Australia.
+      countryCodeForNull = "AU";
+      countryNameForNull = "Australia";
+      countryNumberForNull = 0L;
+    }
+
+    final List<List<Object>> expectedRows = Lists.newArrayList(
+        Arrays.asList("Agama mossambica", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Apamea abruzzorum", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Atractus flammigerus", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Rallicula", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Talk:Oswald Tilghman", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Orange Soda", "MatchNothing", null, null, null),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L)
+    );
+
+    if (!NullHandling.sqlCompatible()) {
+      // Sorting order is different in default-value mode, since 0 and null collapse.
+      // "Peremptory norm" moves before "Rallicula".
+      expectedRows.add(3, expectedRows.remove(5));
+    }
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testRightJoinCountryNumber() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryNumber", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("countryName", ColumnType.STRING)
+                    .add("countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        countriesChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING))
+        ),
+        JoinType.RIGHT
+    );
+
+    final String countryCodeForNull;
+    final String countryNameForNull;
+    final Long countryNumberForNull;
+
+    if (NullHandling.sqlCompatible()) {
+      countryCodeForNull = null;
+      countryNameForNull = null;
+      countryNumberForNull = null;
+    } else {
+      // In default-value mode, null country number from the left-hand table converts to zero, which matches Australia.
+      countryCodeForNull = "AU";
+      countryNameForNull = "Australia";
+      countryNumberForNull = 0L;
+    }
+
+    final List<List<Object>> expectedRows = Lists.newArrayList(
+        Arrays.asList("Agama mossambica", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Apamea abruzzorum", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Atractus flammigerus", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Rallicula", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Talk:Oswald Tilghman", null, countryCodeForNull, countryNameForNull, countryNumberForNull),
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Orange Soda", "MatchNothing", null, null, null),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L)
+    );
+
+    if (!NullHandling.sqlCompatible()) {
+      // Sorting order is different in default-value mode, since 0 and null collapse.
+      // "Peremptory norm" moves before "Rallicula".
+      expectedRows.add(3, expectedRows.remove(5));
+    }
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  @Test
+  public void testInnerJoinCountryIsoCode() throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, expectedRows);
+  }
+
+  private void assertResult(
+      final SortMergeJoinFrameProcessor processor,
+      final ReadableFrameChannel readableOutputChannel,
+      final RowSignature joinSignature,
+      final List<List<Object>> expectedRows
+  )
+  {
+    final ListenableFuture<Long> retVal = exec.runFully(processor, null);
+    final Sequence<List<Object>> rowsFromProcessor = FrameTestUtil.readRowsFromFrameChannel(
+        readableOutputChannel,
+        FrameReader.create(joinSignature)
+    );
+
+    FrameTestUtil.assertRowsEqual(Sequences.simple(expectedRows), rowsFromProcessor);
+    Assert.assertEquals(0L, (long) FutureUtils.getUnchecked(retVal, true));
+  }
+
+  private ReadableInput buildFactInput(final List<KeyColumn> keyColumns) throws IOException
+  {
+    return makeChannelFromResource(
+        JoinTestHelper.FACT_RESOURCE,
+        JoinTestHelper.FACT_SIGNATURE,
+        keyColumns
+    );
+  }
+
+  private ReadableInput buildCountriesInput(final List<KeyColumn> keyColumns) throws IOException
+  {
+    return makeChannelFromResource(
+        JoinTestHelper.COUNTRIES_RESOURCE,
+        JoinTestHelper.COUNTRIES_SIGNATURE,
+        keyColumns
+    );
+  }
+
+  private ReadableInput buildRegionsInput(final List<KeyColumn> keyColumns) throws IOException
+  {
+    return makeChannelFromResource(
+        JoinTestHelper.REGIONS_RESOURCE,
+        JoinTestHelper.REGIONS_SIGNATURE,
+        keyColumns
+    );
+  }
+
+  private ReadableInput makeChannelFromResource(
+      final String resource,
+      final RowSignature signature,
+      final List<KeyColumn> keyColumns
+  ) throws IOException
+  {
+    return makeChannelFromResourceWithLimit(resource, signature, keyColumns, -1);
+  }
+
+  private ReadableInput makeChannelFromResourceWithLimit(
+      final String resource,
+      final RowSignature signature,
+      final List<KeyColumn> keyColumns,
+      final long limit
+  ) throws IOException
+  {
+    try (final RowBasedSegment<Map<String, Object>> segment = JoinTestHelper.withRowsFromResource(
+        resource,
+        rows -> new RowBasedSegment<>(
+            SegmentId.dummy(resource),
+            limit < 0 ? Sequences.simple(rows) : Sequences.simple(rows).limit(limit),
+            columnName -> m -> m.get(columnName),
+            signature
+        )
+    )) {
+      final StorageAdapter adapter = segment.asStorageAdapter();
+      return makeChannelFromAdapter(adapter, keyColumns);
+    }
+  }
+
+  private ReadableInput makeChannelFromAdapter(
+      final StorageAdapter adapter,
+      final List<KeyColumn> keyColumns
+  ) throws IOException
+  {
+    // Create a single, sorted frame.
+    final FrameSequenceBuilder singleFrameBuilder =
+        FrameSequenceBuilder.fromAdapter(adapter)
+                            .frameType(FrameType.ROW_BASED)
+                            .maxRowsPerFrame(Integer.MAX_VALUE)
+                            .sortBy(keyColumns);
+
+    final RowSignature signature = singleFrameBuilder.signature();
+    final Frame frame = Iterables.getOnlyElement(singleFrameBuilder.frames().toList());
+
+    // Split it up into frames of at most rowsPerInputFrame rows each. Make the channel capacity large enough to hold all the frames we might ever want to use.
+    final BlockingQueueFrameChannel channel = new BlockingQueueFrameChannel(10_000);
+
+    final FrameReader frameReader = FrameReader.create(signature);
+
+    final FrameSequenceBuilder frameSequenceBuilder =
+        FrameSequenceBuilder.fromAdapter(new FrameStorageAdapter(frame, frameReader, Intervals.ETERNITY))
+                            .frameType(FrameType.ROW_BASED)
+                            .maxRowsPerFrame(rowsPerInputFrame);
+
+    final Sequence<Frame> frames = frameSequenceBuilder.frames();
+    frames.forEach(
+        f -> {
+          try {
+            channel.writable().write(f);
+          }
+          catch (IOException e) {
+            throw new RuntimeException(e);
+          }
+        }
+    );
+
+    channel.writable().close();
+    return ReadableInput.channel(channel.readable(), FrameReader.create(signature), STAGE_PARTITION);
+  }
+
+  private FrameWriterFactory makeFrameWriterFactory(final RowSignature signature)
+  {
+    return new LimitedFrameWriterFactory(
+        FrameWriters.makeFrameWriterFactory(
+            FrameType.ROW_BASED,
+            new SingleMemoryAllocatorFactory(ArenaMemoryAllocator.createOnHeap(1_000_000)),
+            signature,
+            Collections.emptyList()
+        ),
+        rowsPerOutputFrame
+    );
+  }
+}
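
For readers skimming the tests above, the constructor wiring is identical in every case; only the inputs, key columns, and join type change. A recap of the argument order, as a sketch of the pattern used above rather than a new test:

    // new SortMergeJoinFrameProcessor(
    //     leftInput,                        // a ReadableInput built by buildFactInput / buildCountriesInput
    //     rightInput,
    //     outputChannel.writable(),         // from BlockingQueueFrameChannel.minimal()
    //     makeFrameWriterFactory(joinSignature),
    //     "j0.",                            // column prefix used in the join output signature
    //     ImmutableList.of(leftKeyColumns, rightKeyColumns),  // one KeyColumn list per input, in order
    //     joinType                          // JoinType.RIGHT or JoinType.INNER in the tests shown here
    // );

assertResult then runs the processor on the test executor and compares the output channel's contents against the expected rows.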
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
index d93e8df42d..c6c76ae2e9 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
@@ -22,19 +22,18 @@ package org.apache.druid.msq.querykit.scan;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
-import org.apache.datasketches.memory.WritableMemory;
 import org.apache.druid.collections.ResourceHolder;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.FrameType;
 import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
 import org.apache.druid.frame.allocation.HeapMemoryAllocator;
+import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
 import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.processor.FrameProcessorExecutor;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.testutil.FrameSequenceBuilder;
 import org.apache.druid.frame.testutil.FrameTestUtil;
-import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.jackson.DefaultObjectMapper;
@@ -46,10 +45,10 @@ import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.StagePartition;
 import org.apache.druid.msq.querykit.LazyResourceHolder;
+import org.apache.druid.msq.test.LimitedFrameWriterFactory;
 import org.apache.druid.query.Druids;
 import org.apache.druid.query.scan.ScanQuery;
 import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
-import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.TestIndex;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter;
@@ -117,10 +116,10 @@ public class ScanQueryFrameProcessorTest extends InitializedNullHandlingTest
     final StagePartition stagePartition = new StagePartition(new StageId("query", 0), 0);
 
     // Limit output frames to 1 row to ensure we test edge cases
-    final FrameWriterFactory frameWriterFactory = limitedFrameWriterFactory(
+    final FrameWriterFactory frameWriterFactory = new LimitedFrameWriterFactory(
         FrameWriters.makeFrameWriterFactory(
             FrameType.ROW_BASED,
-            HeapMemoryAllocator.unlimited(),
+            new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
             signature,
             Collections.emptyList()
         ),
@@ -171,72 +170,4 @@ public class ScanQueryFrameProcessorTest extends InitializedNullHandlingTest
 
     Assert.assertEquals(adapter.getNumRows(), (long) retVal.get());
   }
-
-  /**
-   * Wraps a {@link FrameWriterFactory}, creating a new factory that returns {@link FrameWriter} which write
-   * a limited number of rows.
-   */
-  private static FrameWriterFactory limitedFrameWriterFactory(final FrameWriterFactory baseFactory, final int rowLimit)
-  {
-    return new FrameWriterFactory()
-    {
-      @Override
-      public FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory)
-      {
-        return new LimitedFrameWriter(baseFactory.newFrameWriter(columnSelectorFactory), rowLimit);
-      }
-
-      @Override
-      public long allocatorCapacity()
-      {
-        return baseFactory.allocatorCapacity();
-      }
-    };
-  }
-
-  private static class LimitedFrameWriter implements FrameWriter
-  {
-    private final FrameWriter baseWriter;
-    private final int rowLimit;
-
-    public LimitedFrameWriter(FrameWriter baseWriter, int rowLimit)
-    {
-      this.baseWriter = baseWriter;
-      this.rowLimit = rowLimit;
-    }
-
-    @Override
-    public boolean addSelection()
-    {
-      if (baseWriter.getNumRows() >= rowLimit) {
-        return false;
-      } else {
-        return baseWriter.addSelection();
-      }
-    }
-
-    @Override
-    public int getNumRows()
-    {
-      return baseWriter.getNumRows();
-    }
-
-    @Override
-    public long getTotalSize()
-    {
-      return baseWriter.getTotalSize();
-    }
-
-    @Override
-    public long writeTo(WritableMemory memory, long position)
-    {
-      return baseWriter.writeTo(memory, position);
-    }
-
-    @Override
-    public void close()
-    {
-      baseWriter.close();
-    }
-  }
 }
diff --git a/processing/src/test/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactoryTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactoryTest.java
similarity index 93%
rename from processing/src/test/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactoryTest.java
rename to extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactoryTest.java
index e5cb1d2224..c17c916a94 100644
--- a/processing/src/test/java/org/apache/druid/frame/processor/DurableStorageOutputChannelFactoryTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/DurableStorageOutputChannelFactoryTest.java
@@ -17,8 +17,9 @@
  * under the License.
  */
 
-package org.apache.druid.frame.processor;
+package org.apache.druid.msq.shuffle;
 
+import org.apache.druid.frame.processor.OutputChannelFactoryTest;
 import org.apache.druid.storage.local.LocalFileStorageConnector;
 import org.junit.ClassRule;
 import org.junit.rules.TemporaryFolder;
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
index 005a7a5bfa..70543e3dbd 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
@@ -28,10 +28,11 @@ import com.google.common.math.LongMath;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.KeyTestUtils;
 import org.apache.druid.frame.key.RowKey;
 import org.apache.druid.frame.key.RowKeyReader;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.StringUtils;
@@ -73,16 +74,20 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
                                                             .build();
 
   private static final ClusterBy CLUSTER_BY_X = new ClusterBy(
-      ImmutableList.of(new SortColumn("x", false)),
+      ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)),
       0
   );
 
   private static final ClusterBy CLUSTER_BY_XY_BUCKET_BY_X = new ClusterBy(
-      ImmutableList.of(new SortColumn("x", false), new SortColumn("y", false)),
+      ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING), new KeyColumn("y", KeyOrder.ASCENDING)),
       1
   );
   private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy(
-      ImmutableList.of(new SortColumn("x", false), new SortColumn("y", false), new SortColumn("z", false)),
+      ImmutableList.of(
+          new KeyColumn("x", KeyOrder.ASCENDING),
+          new KeyColumn("y", KeyOrder.ASCENDING),
+          new KeyColumn("z", KeyOrder.ASCENDING)
+      ),
       1
   );
 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
index bf27234ecc..c5efeaa039 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
@@ -22,9 +22,10 @@ package org.apache.druid.msq.statistics;
 import com.google.common.collect.ImmutableList;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.KeyTestUtils;
 import org.apache.druid.frame.key.RowKey;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
@@ -38,7 +39,7 @@ import java.util.NoSuchElementException;
 
 public class DelegateOrMinKeyCollectorTest
 {
-  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0);
+  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
   private final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();
   private final Comparator<RowKey> comparator = clusterBy.keyComparator();
 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
index 6d3622612d..9ac750958c 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
@@ -24,8 +24,9 @@ import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.RowKey;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.Pair;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
@@ -40,7 +41,7 @@ import java.util.NoSuchElementException;
 
 public class DistinctKeyCollectorTest
 {
-  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0);
+  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
   private final Comparator<RowKey> comparator = clusterBy.keyComparator();
   private final int numKeys = 500_000;
 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/KeyCollectorTestUtils.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/KeyCollectorTestUtils.java
index f62eaeb58b..3ec0d3920a 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/KeyCollectorTestUtils.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/KeyCollectorTestUtils.java
@@ -23,9 +23,10 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.KeyTestUtils;
 import org.apache.druid.frame.key.RowKey;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.segment.column.ColumnType;
@@ -289,8 +290,8 @@ public class KeyCollectorTestUtils
   private static RowKey createSingleLongKey(final long n)
   {
     final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();
-    final List<SortColumn> sortColumns = ImmutableList.of(new SortColumn("x", false));
-    final RowSignature keySignature = KeyTestUtils.createKeySignature(sortColumns, signature);
+    final List<KeyColumn> keyColumns = ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING));
+    final RowSignature keySignature = KeyTestUtils.createKeySignature(keyColumns, signature);
     return KeyTestUtils.createKey(keySignature, n);
   }
 }
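
The statistics-test changes above are a mechanical rename from SortColumn(name, descending) to KeyColumn(name, KeyOrder). Only the descending = false case appears in these hunks; a sketch of the full mapping, with the descending = true case assumed by analogy:

    // Hypothetical helper, not part of this patch, shown only to illustrate the rename.
    static KeyColumn toKeyColumn(final String columnName, final boolean descending)
    {
      return new KeyColumn(columnName, descending ? KeyOrder.DESCENDING : KeyOrder.ASCENDING);
    }

KeyOrder.NONE, the third order, is used for keys with no sort direction; as the ByteRowKeyComparator change below shows, such columns are rejected wherever a sortable key is required.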
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
index 79b8b06b70..64288e1b2f 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
@@ -24,9 +24,10 @@ import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
 import org.apache.druid.frame.key.KeyTestUtils;
 import org.apache.druid.frame.key.RowKey;
-import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
@@ -41,7 +42,7 @@ import java.util.NoSuchElementException;
 
 public class QuantilesSketchKeyCollectorTest
 {
-  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0);
+  private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
   private final Comparator<RowKey> comparator = clusterBy.keyComparator();
   private final int numKeys = 500_000;
 
@@ -167,12 +168,12 @@ public class QuantilesSketchKeyCollectorTest
   @Test
   public void testAverageKeyLength()
   {
-    final QuantilesSketchKeyCollector collector = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
-
+    final QuantilesSketchKeyCollector collector =
+        QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
     final QuantilesSketchKeyCollector other = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
 
     RowSignature smallKeySignature = KeyTestUtils.createKeySignature(
-        new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0).getColumns(),
+        new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0).getColumns(),
         RowSignature.builder().add("x", ColumnType.LONG).build()
     );
     RowKey smallKey = KeyTestUtils.createKey(smallKeySignature, 1L);
@@ -180,11 +181,12 @@ public class QuantilesSketchKeyCollectorTest
     RowSignature largeKeySignature = KeyTestUtils.createKeySignature(
         new ClusterBy(
             ImmutableList.of(
-                new SortColumn("x", false),
-                new SortColumn("y", false),
-                new SortColumn("z", false)
+                new KeyColumn("x", KeyOrder.ASCENDING),
+                new KeyColumn("y", KeyOrder.ASCENDING),
+                new KeyColumn("z", KeyOrder.ASCENDING)
             ),
-            0).getColumns(),
+            0
+        ).getColumns(),
         RowSignature.builder()
                     .add("x", ColumnType.LONG)
                     .add("y", ColumnType.LONG)
@@ -201,7 +203,11 @@ public class QuantilesSketchKeyCollectorTest
     Assert.assertEquals(largeKey.estimatedObjectSizeBytes(), other.getAverageKeyLength(), 0);
 
     collector.addAll(other);
-    Assert.assertEquals((smallKey.estimatedObjectSizeBytes() * 3 + largeKey.estimatedObjectSizeBytes() * 5) / 8.0, collector.getAverageKeyLength(), 0);
+    Assert.assertEquals(
+        (smallKey.estimatedObjectSizeBytes() * 3 + largeKey.estimatedObjectSizeBytes() * 5) / 8.0,
+        collector.getAverageKeyLength(),
+        0
+    );
   }
 
   @Test
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryTestMSQ.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryTestMSQ.java
index 1042387020..9d31d4f206 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryTestMSQ.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryTestMSQ.java
@@ -39,8 +39,6 @@ import org.junit.Ignore;
  */
 public class CalciteSelectQueryTestMSQ extends CalciteQueryTest
 {
-
-  private MSQTestOverlordServiceClient indexingServiceClient;
   private TestGroupByBuffers groupByBuffers;
 
   @Before
@@ -76,9 +74,10 @@ public class CalciteSelectQueryTestMSQ extends CalciteQueryTest
             2,
             10,
             2,
+            0,
             0
         );
-    indexingServiceClient = new MSQTestOverlordServiceClient(
+    final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient(
         queryJsonMapper,
         injector,
         new MSQTestTaskActionClient(queryJsonMapper),
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/LimitedFrameWriterFactory.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/LimitedFrameWriterFactory.java
new file mode 100644
index 0000000000..6bde53957f
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/LimitedFrameWriterFactory.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.test;
+
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.write.FrameWriter;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.column.RowSignature;
+
+public class LimitedFrameWriterFactory implements FrameWriterFactory
+{
+  private final FrameWriterFactory baseFactory;
+  private final int rowLimit;
+
+  /**
+   * Wraps a {@link FrameWriterFactory}, creating a new factory whose {@link FrameWriter} instances write
+   * no more than a fixed number of rows per frame.
+   */
+  public LimitedFrameWriterFactory(FrameWriterFactory baseFactory, int rowLimit)
+  {
+    this.baseFactory = baseFactory;
+    this.rowLimit = rowLimit;
+  }
+
+  @Override
+  public FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory)
+  {
+    return new LimitedFrameWriter(baseFactory.newFrameWriter(columnSelectorFactory), rowLimit);
+  }
+
+  @Override
+  public long allocatorCapacity()
+  {
+    return baseFactory.allocatorCapacity();
+  }
+
+  @Override
+  public RowSignature signature()
+  {
+    return baseFactory.signature();
+  }
+
+  @Override
+  public FrameType frameType()
+  {
+    return baseFactory.frameType();
+  }
+
+  private static class LimitedFrameWriter implements FrameWriter
+  {
+    private final FrameWriter baseWriter;
+    private final int rowLimit;
+
+    public LimitedFrameWriter(FrameWriter baseWriter, int rowLimit)
+    {
+      this.baseWriter = baseWriter;
+      this.rowLimit = rowLimit;
+    }
+
+    @Override
+    public boolean addSelection()
+    {
+      if (baseWriter.getNumRows() >= rowLimit) {
+        return false;
+      } else {
+        return baseWriter.addSelection();
+      }
+    }
+
+    @Override
+    public int getNumRows()
+    {
+      return baseWriter.getNumRows();
+    }
+
+    @Override
+    public long getTotalSize()
+    {
+      return baseWriter.getTotalSize();
+    }
+
+    @Override
+    public long writeTo(WritableMemory memory, long position)
+    {
+      return baseWriter.writeTo(memory, position);
+    }
+
+    @Override
+    public void close()
+    {
+      baseWriter.close();
+    }
+  }
+
+}
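
Usage mirrors the call sites updated earlier in this patch: wrap the real factory so each produced frame holds at most a fixed number of rows, which forces multi-frame code paths in tests. A sketch based on that usage (the row limit and "signature" are whatever the test needs):

    final FrameWriterFactory frameWriterFactory = new LimitedFrameWriterFactory(
        FrameWriters.makeFrameWriterFactory(
            FrameType.ROW_BASED,
            new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
            signature,
            Collections.emptyList()
        ),
        1  // each output frame carries at most one row
    );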
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index 50da3c1f75..a165eb3822 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -286,6 +286,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
           2,
           10,
           2,
+          1,
           0
       )
   );
diff --git a/processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java b/processing/src/main/java/org/apache/druid/frame/allocation/ArenaMemoryAllocatorFactory.java
similarity index 61%
copy from processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java
copy to processing/src/main/java/org/apache/druid/frame/allocation/ArenaMemoryAllocatorFactory.java
index 2df1c8f3d1..80d4dcdadd 100644
--- a/processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java
+++ b/processing/src/main/java/org/apache/druid/frame/allocation/ArenaMemoryAllocatorFactory.java
@@ -17,19 +17,29 @@
  * under the License.
  */
 
-package org.apache.druid.frame.write;
-
-import org.apache.druid.segment.ColumnSelectorFactory;
+package org.apache.druid.frame.allocation;
 
 /**
- * Interface for creating {@link FrameWriter}.
+ * Creates {@link ArenaMemoryAllocator} on each call to {@link #newAllocator()}.
  */
-public interface FrameWriterFactory
+public class ArenaMemoryAllocatorFactory implements MemoryAllocatorFactory
 {
-  /**
-   * Create a writer where {@link FrameWriter#addSelection()} adds the current row from a {@link ColumnSelectorFactory}.
-   */
-  FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory);
+  private final int capacity;
+
+  public ArenaMemoryAllocatorFactory(final int capacity)
+  {
+    this.capacity = capacity;
+  }
+
+  @Override
+  public MemoryAllocator newAllocator()
+  {
+    return ArenaMemoryAllocator.createOnHeap(capacity);
+  }
 
-  long allocatorCapacity();
+  @Override
+  public long allocatorCapacity()
+  {
+    return capacity;
+  }
 }
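
A short sketch of the contract described in the javadoc above: every call to newAllocator() produces a fresh on-heap arena of the configured capacity, so callers that need independent allocators can simply keep calling the factory.

    final MemoryAllocatorFactory factory = new ArenaMemoryAllocatorFactory(1_000_000);
    final MemoryAllocator first = factory.newAllocator();
    final MemoryAllocator second = factory.newAllocator();  // a separate arena; does not share memory with "first"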
diff --git a/processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java b/processing/src/main/java/org/apache/druid/frame/allocation/MemoryAllocatorFactory.java
similarity index 65%
copy from processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java
copy to processing/src/main/java/org/apache/druid/frame/allocation/MemoryAllocatorFactory.java
index 2df1c8f3d1..edb74ef19f 100644
--- a/processing/src/main/java/org/apache/druid/frame/write/FrameWriterFactory.java
+++ b/processing/src/main/java/org/apache/druid/frame/allocation/MemoryAllocatorFactory.java
@@ -17,19 +17,23 @@
  * under the License.
  */
 
-package org.apache.druid.frame.write;
-
-import org.apache.druid.segment.ColumnSelectorFactory;
+package org.apache.druid.frame.allocation;
 
 /**
- * Interface for creating {@link FrameWriter}.
+ * Factory for {@link MemoryAllocator}.
+ *
+ * Used by {@link org.apache.druid.frame.write.FrameWriters#makeFrameWriterFactory} to create
+ * {@link org.apache.druid.frame.write.FrameWriterFactory}.
  */
-public interface FrameWriterFactory
+public interface MemoryAllocatorFactory
 {
   /**
-   * Create a writer where {@link FrameWriter#addSelection()} adds the current row from a {@link ColumnSelectorFactory}.
+   * Returns a new allocator with capacity {@link #allocatorCapacity()}.
    */
-  FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory);
+  MemoryAllocator newAllocator();
 
+  /**
+   * Capacity of allocators returned by {@link #newAllocator()}.
+   */
   long allocatorCapacity();
 }
diff --git a/processing/src/main/java/org/apache/druid/frame/allocation/MemoryRange.java b/processing/src/main/java/org/apache/druid/frame/allocation/MemoryRange.java
index f464e4155d..656f6446ff 100644
--- a/processing/src/main/java/org/apache/druid/frame/allocation/MemoryRange.java
+++ b/processing/src/main/java/org/apache/druid/frame/allocation/MemoryRange.java
@@ -22,8 +22,8 @@ package org.apache.druid.frame.allocation;
 import org.apache.datasketches.memory.Memory;
 
 /**
- * Reference to a particular region of some {@link Memory}. This is used because it is cheaper to create than
- * calling {@link Memory#region}.
+ * Reference to a particular region of some {@link Memory}. This is used because it is cheaper to reuse this object
+ * than to call {@link Memory#region} for each row.
  *
  * Not immutable. The pointed-to range may change as this object gets reused.
  */
@@ -39,8 +39,8 @@ public class MemoryRange<T extends Memory>
   }
 
   /**
-   * Returns the underlying memory *without* clipping it to this particular range. Callers must remember to continue
-   * applying the offset given by {@link #start} and capacity given by {@link #length}.
+   * Returns the underlying memory *without* clipping it to this particular range. Callers must apply the offset
+   * given by {@link #start} and capacity given by {@link #length}.
    */
   public T memory()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/allocation/SingleMemoryAllocatorFactory.java b/processing/src/main/java/org/apache/druid/frame/allocation/SingleMemoryAllocatorFactory.java
new file mode 100644
index 0000000000..94ba574987
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/frame/allocation/SingleMemoryAllocatorFactory.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.allocation;
+
+import org.apache.druid.java.util.common.ISE;
+
+/**
+ * Wraps a single {@link MemoryAllocator}.
+ *
+ * The same instance is returned on each call to {@link #newAllocator()}, after validating that it is 100% free.
+ * Calling {@link #newAllocator()} before freeing all previously-allocated memory leads to an IllegalStateException.
+ */
+public class SingleMemoryAllocatorFactory implements MemoryAllocatorFactory
+{
+  private final MemoryAllocator allocator;
+  private final long capacity;
+
+  public SingleMemoryAllocatorFactory(final MemoryAllocator allocator)
+  {
+    this.allocator = allocator;
+    this.capacity = allocator.capacity();
+  }
+
+  @Override
+  public MemoryAllocator newAllocator()
+  {
+    // The single allocator is reused, which allows each call to "newAllocator" to use the same arena (if it is arena-based).
+    // Validate that all previously-allocated memory has been freed before handing the allocator to another caller.
+
+    if (allocator.available() != allocator.capacity()) {
+      throw new ISE("Allocator in use");
+    }
+
+    return allocator;
+  }
+
+  @Override
+  public long allocatorCapacity()
+  {
+    return capacity;
+  }
+}
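
By contrast with ArenaMemoryAllocatorFactory, this factory hands out the one wrapped allocator again and again, but only while it is completely free. A sketch of the contract (MemoryAllocator's allocate/free API is not part of this patch, so the failing case is described rather than shown):

    final MemoryAllocatorFactory factory =
        new SingleMemoryAllocatorFactory(ArenaMemoryAllocator.createOnHeap(1_000_000));

    final MemoryAllocator first = factory.newAllocator();   // returns the wrapped arena itself
    final MemoryAllocator second = factory.newAllocator();  // same instance again: nothing has been allocated, so it is 100% free
    // Had any memory allocated from the arena not been freed, the second call would throw ISE("Allocator in use").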
diff --git a/processing/src/main/java/org/apache/druid/frame/channel/BlockingQueueFrameChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/BlockingQueueFrameChannel.java
index 1b0aa009a3..9010f620e5 100644
--- a/processing/src/main/java/org/apache/druid/frame/channel/BlockingQueueFrameChannel.java
+++ b/processing/src/main/java/org/apache/druid/frame/channel/BlockingQueueFrameChannel.java
@@ -172,8 +172,6 @@ public class BlockingQueueFrameChannel
           // If this happens, it's a bug, potentially due to incorrectly using this class with multiple writers.
           throw new ISE("Could not write error to channel");
         }
-
-        close();
       }
     }
 
@@ -181,8 +179,8 @@ public class BlockingQueueFrameChannel
     public void close()
     {
       synchronized (lock) {
-        if (isFinished()) {
-          throw new ISE("Already done");
+        if (isClosed()) {
+          throw new ISE("Already closed");
         }
 
         if (!queue.offer(END_MARKER)) {
@@ -193,6 +191,15 @@ public class BlockingQueueFrameChannel
         notifyReader();
       }
     }
+
+    @Override
+    public boolean isClosed()
+    {
+      synchronized (lock) {
+        final Optional<Either<Throwable, FrameWithPartition>> lastElement = queue.peekLast();
+        return lastElement != null && END_MARKER.equals(lastElement);
+      }
+    }
   }
 
   private class Readable implements ReadableFrameChannel
diff --git a/processing/src/main/java/org/apache/druid/frame/channel/ComposingWritableFrameChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/ComposingWritableFrameChannel.java
index b9ca905ed2..7f2a61e437 100644
--- a/processing/src/main/java/org/apache/druid/frame/channel/ComposingWritableFrameChannel.java
+++ b/processing/src/main/java/org/apache/druid/frame/channel/ComposingWritableFrameChannel.java
@@ -92,9 +92,16 @@ public class ComposingWritableFrameChannel implements WritableFrameChannel
   {
     if (currentIndex < channels.size()) {
       channels.get(currentIndex).get().close();
+      currentIndex = channels.size();
     }
   }
 
+  @Override
+  public boolean isClosed()
+  {
+    return currentIndex == channels.size();
+  }
+
   @Override
   public ListenableFuture<?> writabilityFuture()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameChannel.java
index ccb0166acc..bb1527305f 100644
--- a/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameChannel.java
+++ b/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameChannel.java
@@ -73,6 +73,11 @@ public interface WritableFrameChannel extends Closeable
   @Override
   void close() throws IOException;
 
+  /**
+   * Whether {@link #close()} has been called on this channel.
+   */
+  boolean isClosed();
+
   /**
    * Returns a future that resolves when {@link #write} is able to receive a new frame without blocking or throwing
    * an exception. The future never resolves to an exception.
diff --git a/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameFileChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameFileChannel.java
index 8ed0a5b721..6463abe974 100644
--- a/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameFileChannel.java
+++ b/processing/src/main/java/org/apache/druid/frame/channel/WritableFrameFileChannel.java
@@ -32,6 +32,7 @@ import java.io.IOException;
 public class WritableFrameFileChannel implements WritableFrameChannel
 {
   private final FrameFileWriter writer;
+  private boolean closed;
 
   public WritableFrameFileChannel(final FrameFileWriter writer)
   {
@@ -55,6 +56,13 @@ public class WritableFrameFileChannel implements WritableFrameChannel
   public void close() throws IOException
   {
     writer.close();
+    closed = true;
+  }
+
+  @Override
+  public boolean isClosed()
+  {
+    return closed;
   }
 
   @Override
diff --git a/processing/src/main/java/org/apache/druid/frame/field/ComplexFieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/ComplexFieldReader.java
index 64ce17e06a..045835cbb8 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/ComplexFieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/ComplexFieldReader.java
@@ -83,6 +83,12 @@ public class ComplexFieldReader implements FieldReader
     return DimensionSelector.constant(null, extractionFn);
   }
 
+  @Override
+  public boolean isNull(Memory memory, long position)
+  {
+    return memory.getByte(position) == ComplexFieldWriter.NULL_BYTE;
+  }
+
   @Override
   public boolean isComparable()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/field/DoubleFieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/DoubleFieldReader.java
index 3e5c04e6cb..a2a3d576ba 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/DoubleFieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/DoubleFieldReader.java
@@ -66,6 +66,12 @@ public class DoubleFieldReader implements FieldReader
     );
   }
 
+  @Override
+  public boolean isNull(Memory memory, long position)
+  {
+    return memory.getByte(position) == DoubleFieldWriter.NULL_BYTE;
+  }
+
   @Override
   public boolean isComparable()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/field/FieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/FieldReader.java
index c6f9c61312..b0dfdfa499 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/FieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/FieldReader.java
@@ -52,6 +52,11 @@ public interface FieldReader
       @Nullable ExtractionFn extractionFn
   );
 
+  /**
+   * Whether the provided memory position points to a null value.
+   */
+  boolean isNull(Memory memory, long position);
+
   /**
    * Whether this field is comparable. Comparable fields can be compared as unsigned bytes.
    */
diff --git a/processing/src/main/java/org/apache/druid/frame/field/FloatFieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/FloatFieldReader.java
index 0d29a66f15..7c4b59cdc2 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/FloatFieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/FloatFieldReader.java
@@ -66,6 +66,12 @@ public class FloatFieldReader implements FieldReader
     );
   }
 
+  @Override
+  public boolean isNull(Memory memory, long position)
+  {
+    return memory.getByte(position) == FloatFieldWriter.NULL_BYTE;
+  }
+
   @Override
   public boolean isComparable()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/field/LongFieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/LongFieldReader.java
index 7cdcea61b2..0d7f88a3a0 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/LongFieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/LongFieldReader.java
@@ -66,6 +66,12 @@ public class LongFieldReader implements FieldReader
     );
   }
 
+  @Override
+  public boolean isNull(Memory memory, long position)
+  {
+    return memory.getByte(position) == LongFieldWriter.NULL_BYTE;
+  }
+
   @Override
   public boolean isComparable()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/field/StringFieldReader.java b/processing/src/main/java/org/apache/druid/frame/field/StringFieldReader.java
index 204ecceef1..4ddbc6698c 100644
--- a/processing/src/main/java/org/apache/druid/frame/field/StringFieldReader.java
+++ b/processing/src/main/java/org/apache/druid/frame/field/StringFieldReader.java
@@ -87,6 +87,16 @@ public class StringFieldReader implements FieldReader
     return new Selector(memory, fieldPointer, extractionFn, false);
   }
 
+  @Override
+  public boolean isNull(Memory memory, long position)
+  {
+    final byte nullByte = memory.getByte(position);
+    assert nullByte == StringFieldWriter.NULL_BYTE || nullByte == StringFieldWriter.NOT_NULL_BYTE;
+    return nullByte == StringFieldWriter.NULL_BYTE
+           && memory.getByte(position + 1) == StringFieldWriter.VALUE_TERMINATOR
+           && memory.getByte(position + 2) == StringFieldWriter.ROW_TERMINATOR;
+  }
+
   @Override
   public boolean isComparable()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/key/ByteRowKeyComparator.java b/processing/src/main/java/org/apache/druid/frame/key/ByteRowKeyComparator.java
index 5e00946736..e7b7e2871d 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/ByteRowKeyComparator.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/ByteRowKeyComparator.java
@@ -23,6 +23,7 @@ import com.google.common.primitives.Ints;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
 import it.unimi.dsi.fastutil.ints.IntList;
 import org.apache.druid.frame.read.FrameReaderUtils;
+import org.apache.druid.java.util.common.IAE;
 
 import java.util.Arrays;
 import java.util.Comparator;
@@ -49,7 +50,7 @@ public class ByteRowKeyComparator implements Comparator<byte[]>
     this.ascDescRunLengths = ascDescRunLengths;
   }
 
-  public static ByteRowKeyComparator create(final List<SortColumn> keyColumns)
+  public static ByteRowKeyComparator create(final List<KeyColumn> keyColumns)
   {
     return new ByteRowKeyComparator(
         computeFirstFieldPosition(keyColumns.size()),
@@ -74,18 +75,24 @@ public class ByteRowKeyComparator implements Comparator<byte[]>
    *
    * Public so {@link FrameComparisonWidgetImpl} can use it.
    */
-  public static int[] computeAscDescRunLengths(final List<SortColumn> sortColumns)
+  public static int[] computeAscDescRunLengths(final List<KeyColumn> keyColumns)
   {
     final IntList ascDescRunLengths = new IntArrayList(4);
 
-    boolean descending = false;
+    KeyOrder order = KeyOrder.ASCENDING;
     int runLength = 0;
 
-    for (final SortColumn column : sortColumns) {
-      if (column.descending() != descending) {
+    for (final KeyColumn column : keyColumns) {
+      if (column.order() == KeyOrder.NONE) {
+        throw new IAE("Key must be sortable");
+      }
+
+      if (column.order() != order) {
         ascDescRunLengths.add(runLength);
         runLength = 0;
-        descending = !descending;
+
+        // Invert "order".
+        order = order == KeyOrder.ASCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING;
       }
 
       runLength++;
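
A worked example of computeAscDescRunLengths, assuming the final run length is appended after the loop (that part of the method falls outside this hunk): for key columns [a ASC, b ASC, c DESC, d ASC] the result is [2, 1, 1], meaning a run of two ascending columns, then one descending column, then one ascending column. Any KeyOrder.NONE column makes the key non-sortable and is rejected with IAE("Key must be sortable").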
diff --git a/processing/src/main/java/org/apache/druid/frame/key/ClusterBy.java b/processing/src/main/java/org/apache/druid/frame/key/ClusterBy.java
index 367802e000..3be8cf22cb 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/ClusterBy.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/ClusterBy.java
@@ -43,12 +43,13 @@ import java.util.Objects;
  */
 public class ClusterBy
 {
-  private final List<SortColumn> columns;
+  private final List<KeyColumn> columns;
   private final int bucketByCount;
+  private final boolean sortable;
 
   @JsonCreator
   public ClusterBy(
-      @JsonProperty("columns") List<SortColumn> columns,
+      @JsonProperty("columns") List<KeyColumn> columns,
       @JsonProperty("bucketByCount") int bucketByCount
   )
   {
@@ -58,6 +59,21 @@ public class ClusterBy
     if (bucketByCount < 0 || bucketByCount > columns.size()) {
       throw new IAE("Invalid bucketByCount [%d]", bucketByCount);
     }
+
+    // The key must be either entirely sortable or entirely non-sortable. Empty keys count as sortable.
+    boolean sortable = true;
+
+    for (int i = 0; i < columns.size(); i++) {
+      final KeyColumn column = columns.get(i);
+
+      if (i == 0) {
+        sortable = column.order().sortable();
+      } else if (sortable != column.order().sortable()) {
+        throw new IAE("Cannot mix sortable and unsortable key columns");
+      }
+    }
+
+    this.sortable = sortable;
   }
 
   /**
@@ -72,7 +88,7 @@ public class ClusterBy
    * The columns that comprise this key, in order.
    */
   @JsonProperty
-  public List<SortColumn> getColumns()
+  public List<KeyColumn> getColumns()
   {
     return columns;
   }
@@ -86,7 +102,7 @@ public class ClusterBy
    *
    * Will always be less than, or equal to, the size of {@link #getColumns()}.
    *
-   * Not relevant when a ClusterBy instance is used as an ordering key rather than a partitioning key.
+   * Only relevant when a ClusterBy instance is used as a partitioning key.
    */
   @JsonProperty
   @JsonInclude(JsonInclude.Include.NON_DEFAULT)
@@ -95,6 +111,22 @@ public class ClusterBy
     return bucketByCount;
   }
 
+  /**
+   * Whether this key is empty.
+   */
+  public boolean isEmpty()
+  {
+    return columns.isEmpty();
+  }
+
+  /**
+   * Whether this key is sortable. Empty keys (with no columns) are considered sortable.
+   */
+  public boolean sortable()
+  {
+    return sortable;
+  }
+
   /**
    * Create a reader for keys for this instance.
    *
@@ -105,8 +137,8 @@ public class ClusterBy
   {
     final RowSignature.Builder newSignature = RowSignature.builder();
 
-    for (final SortColumn sortColumn : columns) {
-      final String columnName = sortColumn.columnName();
+    for (final KeyColumn keyColumn : columns) {
+      final String columnName = keyColumn.columnName();
       final ColumnCapabilities capabilities = inspector.getColumnCapabilities(columnName);
       final ColumnType columnType =
           Preconditions.checkNotNull(capabilities, "Type for column [%s]", columnName).toColumnType();
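
The constructor change in this file makes ClusterBy keys all-or-nothing with respect to sortability. A sketch of the resulting behavior, assuming KeyOrder.NONE is the non-sortable order (it is the order ByteRowKeyComparator rejects earlier in this patch):

    // Accepted: every column sortable, or every column non-sortable.
    new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
    new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.NONE)), 0);

    // Rejected: mixing the two throws IAE("Cannot mix sortable and unsortable key columns").
    new ClusterBy(
        ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING), new KeyColumn("y", KeyOrder.NONE)),
        0
    );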
diff --git a/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidget.java b/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidget.java
index 031f988f18..e495b7f517 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidget.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidget.java
@@ -34,6 +34,11 @@ public interface FrameComparisonWidget
    */
   RowKey readKey(int row);
 
+  /**
+   * Whether a particular row has a null field in its comparison key.
+   */
+  boolean isPartiallyNullKey(int row);
+
   /**
    * Compare a specific row of this frame to the provided key. The key must have been created with sortColumns
    * that match the ones used to create this widget, or else results are undefined.
@@ -42,7 +47,7 @@ public interface FrameComparisonWidget
 
   /**
    * Compare a specific row of this frame to a specific row of another frame. The other frame must have the same
-   * signature, or else results are undefined. The other frame may be the same object as this frame; for example,
+   * sort key, or else results are undefined. The other frame may be the same object as this frame; for example,
    * this is used by {@link org.apache.druid.frame.write.FrameSort} to sort frames in-place.
    */
   int compare(int row, FrameComparisonWidget otherWidget, int otherRow);
diff --git a/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidgetImpl.java b/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidgetImpl.java
index 0aeed47760..fe11c02755 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidgetImpl.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/FrameComparisonWidgetImpl.java
@@ -23,10 +23,13 @@ import com.google.common.primitives.Ints;
 import org.apache.datasketches.memory.Memory;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.field.FieldReader;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.read.FrameReaderUtils;
 import org.apache.druid.frame.write.FrameWriterUtils;
 import org.apache.druid.frame.write.RowBasedFrameWriter;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.segment.column.RowSignature;
 
 import java.util.List;
 
@@ -39,28 +42,30 @@ import java.util.List;
 public class FrameComparisonWidgetImpl implements FrameComparisonWidget
 {
   private final Frame frame;
-  private final FrameReader frameReader;
+  private final RowSignature signature;
   private final Memory rowOffsetRegion;
   private final Memory dataRegion;
   private final int keyFieldCount;
+  private final List<FieldReader> keyFieldReaders;
   private final long firstFieldPosition;
   private final int[] ascDescRunLengths;
 
   private FrameComparisonWidgetImpl(
       final Frame frame,
-      final FrameReader frameReader,
+      final RowSignature signature,
       final Memory rowOffsetRegion,
       final Memory dataRegion,
-      final int keyFieldCount,
+      final List<FieldReader> keyFieldReaders,
       final long firstFieldPosition,
       final int[] ascDescRunLengths
   )
   {
     this.frame = frame;
-    this.frameReader = frameReader;
+    this.signature = signature;
     this.rowOffsetRegion = rowOffsetRegion;
     this.dataRegion = dataRegion;
-    this.keyFieldCount = keyFieldCount;
+    this.keyFieldCount = keyFieldReaders.size();
+    this.keyFieldReaders = keyFieldReaders;
     this.firstFieldPosition = firstFieldPosition;
     this.ascDescRunLengths = ascDescRunLengths;
   }
@@ -68,41 +73,46 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
   /**
    * Create a {@link FrameComparisonWidget} for the given frame.
    *
-   * Only possible for frames of type {@link FrameType#ROW_BASED}. The provided sortColumns must be a prefix
+   * Only possible for frames of type {@link FrameType#ROW_BASED}. The provided keyColumns must be a prefix
    * of {@link FrameReader#signature()}.
    *
-   * @param frame       frame, must be {@link FrameType#ROW_BASED}
-   * @param frameReader reader for the frame
-   * @param sortColumns columns to sort by
+   * @param frame            frame, must be {@link FrameType#ROW_BASED}
+   * @param signature        signature for the frame
+   * @param keyColumns       columns to sort by
+   * @param keyColumnReaders readers for key columns
    */
   public static FrameComparisonWidgetImpl create(
       final Frame frame,
-      final FrameReader frameReader,
-      final List<SortColumn> sortColumns
+      final RowSignature signature,
+      final List<KeyColumn> keyColumns,
+      final List<FieldReader> keyColumnReaders
   )
   {
-    FrameWriterUtils.verifySortColumns(sortColumns, frameReader.signature());
+    FrameWriterUtils.verifySortColumns(keyColumns, signature);
+
+    if (keyColumnReaders.size() != keyColumns.size()) {
+      throw new ISE("Mismatched lengths for keyColumnReaders and keyColumns");
+    }
 
     return new FrameComparisonWidgetImpl(
         FrameType.ROW_BASED.ensureType(frame),
-        frameReader,
+        signature,
         frame.region(RowBasedFrameWriter.ROW_OFFSET_REGION),
         frame.region(RowBasedFrameWriter.ROW_DATA_REGION),
-        sortColumns.size(),
-        ByteRowKeyComparator.computeFirstFieldPosition(frameReader.signature().size()),
-        ByteRowKeyComparator.computeAscDescRunLengths(sortColumns)
+        keyColumnReaders,
+        ByteRowKeyComparator.computeFirstFieldPosition(signature.size()),
+        ByteRowKeyComparator.computeAscDescRunLengths(keyColumns)
     );
   }
 
   @Override
   public RowKey readKey(int row)
   {
-    final int keyFieldPointersEndInRow = keyFieldCount * Integer.BYTES;
-
     if (keyFieldCount == 0) {
       return RowKey.empty();
     }
 
+    final int keyFieldPointersEndInRow = keyFieldCount * Integer.BYTES;
     final long rowPosition = getRowPositionInDataRegion(row);
     final int keyEndInRow =
         dataRegion.getInt(rowPosition + (long) (keyFieldCount - 1) * Integer.BYTES);
@@ -110,7 +120,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
     final long keyLength = keyEndInRow - firstFieldPosition;
     final byte[] keyBytes = new byte[Ints.checkedCast(keyFieldPointersEndInRow + keyEndInRow - firstFieldPosition)];
 
-    final int headerSizeAdjustment = (frameReader.signature().size() - keyFieldCount) * Integer.BYTES;
+    final int headerSizeAdjustment = (signature.size() - keyFieldCount) * Integer.BYTES;
     for (int i = 0; i < keyFieldCount; i++) {
       final int fieldEndPosition = dataRegion.getInt(rowPosition + ((long) Integer.BYTES * i));
       final int adjustedFieldEndPosition = fieldEndPosition - headerSizeAdjustment;
@@ -127,6 +137,28 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
     return RowKey.wrap(keyBytes);
   }
 
+  @Override
+  public boolean isPartiallyNullKey(int row)
+  {
+    if (keyFieldCount == 0) {
+      return false;
+    }
+
+    final long rowPosition = getRowPositionInDataRegion(row);
+    long keyFieldPosition = rowPosition + (long) signature.size() * Integer.BYTES;
+
+    for (int i = 0; i < keyFieldCount; i++) {
+      final boolean isNull = keyFieldReaders.get(i).isNull(dataRegion, keyFieldPosition);
+      if (isNull) {
+        return true;
+      } else {
+        keyFieldPosition = rowPosition + dataRegion.getInt(rowPosition + (long) i * Integer.BYTES);
+      }
+    }
+
+    return false;
+  }
+
   @Override
   public int compare(int row, RowKey key)
   {
@@ -187,7 +219,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
     final long otherRowPosition = otherWidgetImpl.getRowPositionInDataRegion(otherRow);
 
     long comparableBytesStartPositionInRow = firstFieldPosition;
-    long otherComparableBytesStartPositionInRow = firstFieldPosition;
+    long otherComparableBytesStartPositionInRow = otherWidgetImpl.firstFieldPosition;
 
     boolean ascending = true;
     int field = 0;
@@ -240,8 +272,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
 
   long getFieldEndPositionInRow(final long rowPosition, final int fieldNumber)
   {
-    assert fieldNumber >= 0 && fieldNumber < frameReader.signature().size();
-
+    assert fieldNumber >= 0 && fieldNumber < signature.size();
     return dataRegion.getInt(rowPosition + (long) fieldNumber * Integer.BYTES);
   }
 
diff --git a/processing/src/main/java/org/apache/druid/frame/key/SortColumn.java b/processing/src/main/java/org/apache/druid/frame/key/KeyColumn.java
similarity index 73%
rename from processing/src/main/java/org/apache/druid/frame/key/SortColumn.java
rename to processing/src/main/java/org/apache/druid/frame/key/KeyColumn.java
index 2d9d53e47b..7811616477 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/SortColumn.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/KeyColumn.java
@@ -20,7 +20,6 @@
 package org.apache.druid.frame.key;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
-import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.StringUtils;
@@ -28,17 +27,17 @@ import org.apache.druid.java.util.common.StringUtils;
 import java.util.Objects;
 
 /**
- * Represents a component of an order-by key.
+ * Represents a component of a hash or sorting key.
  */
-public class SortColumn
+public class KeyColumn
 {
   private final String columnName;
-  private final boolean descending;
+  private final KeyOrder order;
 
   @JsonCreator
-  public SortColumn(
+  public KeyColumn(
       @JsonProperty("columnName") String columnName,
-      @JsonProperty("descending") boolean descending
+      @JsonProperty("order") KeyOrder order
   )
   {
     if (columnName == null || columnName.isEmpty()) {
@@ -46,7 +45,7 @@ public class SortColumn
     }
 
     this.columnName = columnName;
-    this.descending = descending;
+    this.order = order;
   }
 
   @JsonProperty
@@ -56,10 +55,9 @@ public class SortColumn
   }
 
   @JsonProperty
-  @JsonInclude(JsonInclude.Include.NON_DEFAULT)
-  public boolean descending()
+  public KeyOrder order()
   {
-    return descending;
+    return order;
   }
 
   @Override
@@ -71,19 +69,19 @@ public class SortColumn
     if (o == null || getClass() != o.getClass()) {
       return false;
     }
-    SortColumn that = (SortColumn) o;
-    return descending == that.descending && Objects.equals(columnName, that.columnName);
+    KeyColumn that = (KeyColumn) o;
+    return order == that.order && Objects.equals(columnName, that.columnName);
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(columnName, descending);
+    return Objects.hash(columnName, order);
   }
 
   @Override
   public String toString()
   {
-    return StringUtils.format("%s%s", columnName, descending ? " DESC" : "");
+    return StringUtils.format("%s%s", columnName, order == KeyOrder.NONE ? "" : " " + order);
   }
 }
diff --git a/processing/src/main/java/org/apache/druid/frame/key/KeyOrder.java b/processing/src/main/java/org/apache/druid/frame/key/KeyOrder.java
new file mode 100644
index 0000000000..fe568483bf
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/frame/key/KeyOrder.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.key;
+
+/**
+ * Ordering associated with a {@link KeyColumn}.
+ */
+public enum KeyOrder
+{
+  /**
+   * Not ordered.
+   *
+   * Used when the key serves only non-sorting purposes, such as hash partitioning without any sort.
+   */
+  NONE(false),
+
+  /**
+   * Ordered ascending.
+   *
+   * Note that sortable key order does not necessarily mean that we are using range-based partitioning. We may be
+   * using hash-based partitioning along with each partition internally being sorted by a key.
+   */
+  ASCENDING(true),
+
+  /**
+   * Ordered descending.
+   *
+   * Note that sortable key order does not necessarily mean that we are using range-based partitioning. We may be
+   * using hash-based partitioning along with each partition internally being sorted by a key.
+   */
+  DESCENDING(true);
+
+  private final boolean sortable;
+
+  KeyOrder(boolean sortable)
+  {
+    this.sortable = sortable;
+  }
+
+  public boolean sortable()
+  {
+    return sortable;
+  }
+}
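
Taken together, KeyColumn and KeyOrder describe both sort keys and hash keys: sort-based shuffles use ASCENDING or DESCENDING, while hash-based shuffles use NONE. A rough sketch of the two kinds of key (column names are illustrative only):

    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;

    import java.util.Arrays;
    import java.util.List;

    class KeyColumnExamples
    {
      // Sort key: every order is sortable, so it can drive sort-based processors such as FrameChannelMerger.
      static final List<KeyColumn> SORT_KEY = Arrays.asList(
          new KeyColumn("countryName", KeyOrder.ASCENDING),
          new KeyColumn("added", KeyOrder.DESCENDING)
      );

      // Hash key: KeyOrder.NONE is not sortable (sortable() returns false), so it is usable for
      // hash partitioning but rejected by sort-based processors.
      static final List<KeyColumn> HASH_KEY = Arrays.asList(
          new KeyColumn("countryName", KeyOrder.NONE)
      );
    }
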
diff --git a/processing/src/main/java/org/apache/druid/frame/key/RowKeyComparator.java b/processing/src/main/java/org/apache/druid/frame/key/RowKeyComparator.java
index c4d73696b1..3e7e6faed5 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/RowKeyComparator.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/RowKeyComparator.java
@@ -36,7 +36,7 @@ public class RowKeyComparator implements Comparator<RowKey>
     this.byteRowKeyComparatorDelegate = byteRowKeyComparatorDelegate;
   }
 
-  public static RowKeyComparator create(final List<SortColumn> keyColumns)
+  public static RowKeyComparator create(final List<KeyColumn> keyColumns)
   {
     return new RowKeyComparator(ByteRowKeyComparator.create(keyColumns));
   }
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/BlockingQueueOutputChannelFactory.java b/processing/src/main/java/org/apache/druid/frame/processor/BlockingQueueOutputChannelFactory.java
index 20426b346b..8af823da8d 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/BlockingQueueOutputChannelFactory.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/BlockingQueueOutputChannelFactory.java
@@ -38,10 +38,10 @@ public class BlockingQueueOutputChannelFactory implements OutputChannelFactory
   public OutputChannel openChannel(final int partitionNumber)
   {
     final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
-    return OutputChannel.pair(
+    return OutputChannel.immediatelyReadablePair(
         channel.writable(),
         ArenaMemoryAllocator.createOnHeap(frameSize),
-        channel::readable,
+        channel.readable(),
         partitionNumber
     );
   }
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelHashPartitioner.java b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelHashPartitioner.java
new file mode 100644
index 0000000000..4ce8f3e097
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelHashPartitioner.java
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+import com.google.common.collect.ImmutableList;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.datasketches.memory.Memory;
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.allocation.MemoryRange;
+import org.apache.druid.frame.channel.ReadableFrameChannel;
+import org.apache.druid.frame.channel.WritableFrameChannel;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.read.FrameReaderUtils;
+import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
+import org.apache.druid.frame.write.FrameWriter;
+import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.frame.write.FrameWriterUtils;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.Cursor;
+import org.apache.druid.segment.DimensionSelector;
+import org.apache.druid.segment.LongColumnSelector;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.VirtualColumns;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.column.ColumnType;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.LongSupplier;
+import java.util.function.Supplier;
+
+/**
+ * Processor that hash-partitions rows from any number of input channels, and writes partitioned frames to output
+ * channels.
+ *
+ * Input frames must be {@link FrameType#ROW_BASED}, and input signature must be the same as output signature.
+ * This processor hashes each row using {@link Memory#xxHash64} with a seed of {@link #HASH_SEED}.
+ */
+public class FrameChannelHashPartitioner implements FrameProcessor<Long>
+{
+  private static final String PARTITION_COLUMN_NAME =
+      StringUtils.format("%s_part", FrameWriterUtils.RESERVED_FIELD_PREFIX);
+  private static final long HASH_SEED = 0;
+
+  private final List<ReadableFrameChannel> inputChannels;
+  private final List<WritableFrameChannel> outputChannels;
+  private final FrameReader frameReader;
+  private final int keyFieldCount;
+  private final FrameWriterFactory frameWriterFactory;
+  private final IntSet awaitSet;
+
+  private Cursor cursor;
+  private LongSupplier cursorRowPartitionNumberSupplier;
+  private long rowsWritten;
+
+  // Indirection allows FrameWriters to follow "cursor" even when it is replaced with a new instance.
+  private final MultiColumnSelectorFactory cursorColumnSelectorFactory;
+  private final FrameWriter[] frameWriters;
+
+  public FrameChannelHashPartitioner(
+      final List<ReadableFrameChannel> inputChannels,
+      final List<WritableFrameChannel> outputChannels,
+      final FrameReader frameReader,
+      final int keyFieldCount,
+      final FrameWriterFactory frameWriterFactory
+  )
+  {
+    this.inputChannels = inputChannels;
+    this.outputChannels = outputChannels;
+    this.frameReader = frameReader;
+    this.keyFieldCount = keyFieldCount;
+    this.frameWriterFactory = frameWriterFactory;
+    this.awaitSet = FrameProcessors.rangeSet(inputChannels.size());
+    this.frameWriters = new FrameWriter[outputChannels.size()];
+    this.cursorColumnSelectorFactory = new MultiColumnSelectorFactory(
+        Collections.singletonList(() -> cursor.getColumnSelectorFactory()),
+        frameReader.signature()
+    ).withRowMemoryAndSignatureColumns();
+
+    if (!frameReader.signature().equals(frameWriterFactory.signature())) {
+      throw new IAE("Input signature does not match output signature");
+    }
+  }
+
+  @Override
+  public List<ReadableFrameChannel> inputChannels()
+  {
+    return inputChannels;
+  }
+
+  @Override
+  public List<WritableFrameChannel> outputChannels()
+  {
+    return outputChannels;
+  }
+
+  @Override
+  public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException
+  {
+    if (cursor == null) {
+      readNextFrame(readableInputs);
+    }
+
+    if (cursor != null) {
+      processCursor();
+    }
+
+    if (cursor != null) {
+      return ReturnOrAwait.runAgain();
+    } else if (awaitSet.isEmpty()) {
+      flushFrameWriters();
+      return ReturnOrAwait.returnObject(rowsWritten);
+    } else {
+      return ReturnOrAwait.awaitAny(awaitSet);
+    }
+  }
+
+  @Override
+  public void cleanup() throws IOException
+  {
+    FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriters);
+  }
+
+  private void processCursor() throws IOException
+  {
+    assert cursor != null;
+
+    while (!cursor.isDone()) {
+      final int partition = (int) cursorRowPartitionNumberSupplier.getAsLong();
+      final FrameWriter frameWriter = getOrCreateFrameWriter(partition);
+
+      if (frameWriter.addSelection()) {
+        cursor.advance();
+      } else if (frameWriter.getNumRows() > 0) {
+        writeFrame(partition);
+        return;
+      } else {
+        throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+      }
+    }
+
+    cursor = null;
+    cursorRowPartitionNumberSupplier = null;
+  }
+
+  private void readNextFrame(final IntSet readableInputs)
+  {
+    if (cursor != null) {
+      throw new ISE("Already reading a frame");
+    }
+
+    final IntSet readySet = new IntAVLTreeSet(readableInputs);
+
+    for (int channelNumber : readableInputs) {
+      final ReadableFrameChannel channel = inputChannels.get(channelNumber);
+
+      if (channel.isFinished()) {
+        awaitSet.remove(channelNumber);
+        readySet.remove(channelNumber);
+      }
+    }
+
+    if (!readySet.isEmpty()) {
+      // Read a random channel: avoid biasing towards lower-numbered channels.
+      final int channelNumber = FrameProcessors.selectRandom(readySet);
+      final ReadableFrameChannel channel = inputChannels.get(channelNumber);
+
+      if (!channel.isFinished()) {
+        // Need row-based frame so we can hash memory directly.
+        final Frame frame = FrameType.ROW_BASED.ensureType(channel.read());
+
+        final HashPartitionVirtualColumn hashPartitionVirtualColumn =
+            new HashPartitionVirtualColumn(PARTITION_COLUMN_NAME, frameReader, keyFieldCount, outputChannels.size());
+
+        cursor = FrameProcessors.makeCursor(
+            frame,
+            frameReader,
+            VirtualColumns.create(Collections.singletonList(hashPartitionVirtualColumn))
+        );
+
+        cursorRowPartitionNumberSupplier =
+            cursor.getColumnSelectorFactory().makeColumnValueSelector(PARTITION_COLUMN_NAME)::getLong;
+      }
+    }
+  }
+
+  private void flushFrameWriters() throws IOException
+  {
+    for (int i = 0; i < frameWriters.length; i++) {
+      if (frameWriters[i] != null) {
+        writeFrame(i);
+      }
+    }
+  }
+
+  private FrameWriter getOrCreateFrameWriter(final int partition)
+  {
+    if (frameWriters[partition] == null) {
+      frameWriters[partition] = frameWriterFactory.newFrameWriter(cursorColumnSelectorFactory);
+    }
+
+    return frameWriters[partition];
+  }
+
+  private void writeFrame(final int partition) throws IOException
+  {
+    if (frameWriters[partition] == null || frameWriters[partition].getNumRows() == 0) {
+      throw new ISE("Nothing to write for partition [%,d]", partition);
+    }
+
+    final Frame frame = Frame.wrap(frameWriters[partition].toByteArray());
+    outputChannels.get(partition).write(frame);
+    frameWriters[partition].close();
+    frameWriters[partition] = null;
+    rowsWritten += frame.numRows();
+  }
+
+  /**
+   * Virtual column providing the output partition number for each {@link FrameType#ROW_BASED} frame row wrapped in
+   * the provided {@link ColumnSelectorFactory}, computed by hashing key fields accessed via
+   * {@link FrameReaderUtils#makeRowMemorySupplier}.
+   */
+  private static class HashPartitionVirtualColumn implements VirtualColumn
+  {
+    private final String name;
+    private final FrameReader frameReader;
+    private final int keyFieldCount;
+    private final int partitionCount;
+
+    public HashPartitionVirtualColumn(
+        final String name,
+        final FrameReader frameReader,
+        final int keyFieldCount,
+        final int partitionCount
+    )
+    {
+      this.name = name;
+      this.frameReader = frameReader;
+      this.keyFieldCount = keyFieldCount;
+      this.partitionCount = partitionCount;
+    }
+
+    @Override
+    public String getOutputName()
+    {
+      return name;
+    }
+
+    @Override
+    public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec, ColumnSelectorFactory factory)
+    {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public ColumnValueSelector<?> makeColumnValueSelector(String columnName, ColumnSelectorFactory factory)
+    {
+      final Supplier<MemoryRange<Memory>> rowMemorySupplier =
+          FrameReaderUtils.makeRowMemorySupplier(factory, frameReader.signature());
+      final int frameFieldCount = frameReader.signature().size();
+
+      return new LongColumnSelector()
+      {
+        @Override
+        public long getLong()
+        {
+          if (keyFieldCount == 0) {
+            return 0;
+          }
+
+          final MemoryRange<Memory> rowMemoryRange = rowMemorySupplier.get();
+          final Memory memory = rowMemoryRange.memory();
+          final long keyStartPosition = (long) Integer.BYTES * frameFieldCount;
+          final long keyEndPosition =
+              memory.getInt(rowMemoryRange.start() + (long) Integer.BYTES * (keyFieldCount - 1));
+          final int keyLength = (int) (keyEndPosition - keyStartPosition);
+          final long hash = memory.xxHash64(rowMemoryRange.start() + keyStartPosition, keyLength, HASH_SEED);
+          return Math.abs(hash % partitionCount);
+        }
+
+        @Override
+        public boolean isNull()
+        {
+          return false;
+        }
+
+        @Override
+        public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+        {
+          // Nothing to do.
+        }
+      };
+    }
+
+    @Override
+    public ColumnCapabilities capabilities(String columnName)
+    {
+      return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG).setHasNulls(false);
+    }
+
+    @Override
+    public List<String> requiredColumns()
+    {
+      return ImmutableList.of(
+          FrameColumnSelectorFactory.ROW_MEMORY_COLUMN,
+          FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN
+      );
+    }
+
+    @Override
+    public boolean usesDotNotation()
+    {
+      return false;
+    }
+
+    @Override
+    public byte[] getCacheKey()
+    {
+      throw new UnsupportedOperationException();
+    }
+  }
+}
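
A row's output partition is derived from xxHash64 over its serialized key fields, reduced modulo the partition count. A standalone sketch of the same mapping (the helper operates on a copied byte array for clarity; the processor above hashes frame memory in place):

    import org.apache.datasketches.memory.Memory;

    class HashPartitionExample
    {
      // Sketch: map a key's bytes to a partition in [0, partitionCount).
      static int partitionForKeyBytes(final byte[] keyBytes, final int partitionCount)
      {
        final long hash = Memory.wrap(keyBytes).xxHash64(0, keyBytes.length, 0 /* HASH_SEED */);
        // The remainder's magnitude is always below partitionCount, so Math.abs is safe here.
        return (int) Math.abs(hash % partitionCount);
      }
    }
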
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
index d1258e47e4..8117bd63d9 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
@@ -27,20 +27,17 @@ import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.frame.key.FrameComparisonWidget;
+import org.apache.druid.frame.key.KeyColumn;
 import org.apache.druid.frame.key.RowKey;
 import org.apache.druid.frame.read.FrameReader;
-import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.Cursor;
-import org.apache.druid.segment.column.ColumnType;
-import org.apache.druid.segment.column.RowSignature;
 
 import javax.annotation.Nullable;
 import java.io.IOException;
@@ -57,7 +54,7 @@ import java.util.function.Supplier;
  * Frames from input channels must be {@link org.apache.druid.frame.FrameType#ROW_BASED}. Output frames will
  * be row-based as well.
  *
- * For unsorted output, use {@link FrameChannelMuxer} instead.
+ * For unsorted output, use {@link FrameChannelMixer} instead.
  */
 public class FrameChannelMerger implements FrameProcessor<Long>
 {
@@ -66,7 +63,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
   private final List<ReadableFrameChannel> inputChannels;
   private final WritableFrameChannel outputChannel;
   private final FrameReader frameReader;
-  private final ClusterBy clusterBy;
+  private final List<KeyColumn> sortKey;
   private final ClusterByPartitions partitions;
   private final IntPriorityQueue priorityQueue;
   private final FrameWriterFactory frameWriterFactory;
@@ -83,7 +80,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
       final FrameReader frameReader,
       final WritableFrameChannel outputChannel,
       final FrameWriterFactory frameWriterFactory,
-      final ClusterBy clusterBy,
+      final List<KeyColumn> sortKey,
       @Nullable final ClusterByPartitions partitions,
       final long rowLimit
   )
@@ -102,11 +99,15 @@ public class FrameChannelMerger implements FrameProcessor<Long>
       throw new IAE("Partitions must all abut each other");
     }
 
+    if (!sortKey.stream().allMatch(keyColumn -> keyColumn.order().sortable())) {
+      throw new IAE("Key is not sortable");
+    }
+
     this.inputChannels = inputChannels;
     this.outputChannel = outputChannel;
     this.frameReader = frameReader;
     this.frameWriterFactory = frameWriterFactory;
-    this.clusterBy = clusterBy;
+    this.sortKey = sortKey;
     this.partitions = partitionsToUse;
     this.rowLimit = rowLimit;
     this.currentFrames = new FramePlus[inputChannels.size()];
@@ -127,18 +128,10 @@ public class FrameChannelMerger implements FrameProcessor<Long>
       frameColumnSelectorFactorySuppliers.add(() -> currentFrames[frameNumber].cursor.getColumnSelectorFactory());
     }
 
-    this.mergedColumnSelectorFactory =
-        new MultiColumnSelectorFactory(
-            frameColumnSelectorFactorySuppliers,
-
-            // Include ROW_SIGNATURE_COLUMN, ROW_MEMORY_COLUMN to potentially enable direct row memory copying.
-            // If these columns don't actually exist in the underlying column selector factories, they'll be ignored.
-            RowSignature.builder()
-                        .addAll(frameReader.signature())
-                        .add(FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN, ColumnType.UNKNOWN_COMPLEX)
-                        .add(FrameColumnSelectorFactory.ROW_MEMORY_COLUMN, ColumnType.UNKNOWN_COMPLEX)
-                        .build()
-        );
+    this.mergedColumnSelectorFactory = new MultiColumnSelectorFactory(
+        frameColumnSelectorFactorySuppliers,
+        frameReader.signature()
+    ).withRowMemoryAndSignatureColumns();
   }
 
   @Override
@@ -244,7 +237,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
             if (channel.canRead()) {
               // Read next frame from this channel.
               final Frame frame = channel.read();
-              currentFrames[currentChannel] = new FramePlus(frame, frameReader, clusterBy);
+              currentFrames[currentChannel] = new FramePlus(frame, frameReader, sortKey);
               priorityQueue.enqueue(currentChannel);
             } else if (channel.isFinished()) {
               // Done reading this channel. Fall through and continue with other channels.
@@ -281,7 +274,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
 
         if (channel.canRead()) {
           final Frame frame = channel.read();
-          currentFrames[i] = new FramePlus(frame, frameReader, clusterBy);
+          currentFrames[i] = new FramePlus(frame, frameReader, sortKey);
           priorityQueue.enqueue(i);
         } else if (!channel.isFinished()) {
           await.add(i);
@@ -301,10 +294,10 @@ public class FrameChannelMerger implements FrameProcessor<Long>
     private final FrameComparisonWidget comparisonWidget;
     private int rowNumber;
 
-    private FramePlus(Frame frame, FrameReader frameReader, ClusterBy clusterBy)
+    private FramePlus(Frame frame, FrameReader frameReader, List<KeyColumn> sortKey)
     {
       this.cursor = FrameProcessors.makeCursor(frame, frameReader);
-      this.comparisonWidget = frameReader.makeComparisonWidget(frame, clusterBy.getColumns());
+      this.comparisonWidget = frameReader.makeComparisonWidget(frame, sortKey);
       this.rowNumber = 0;
     }
 
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMuxer.java b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMixer.java
similarity index 63%
rename from processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMuxer.java
rename to processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMixer.java
index bc3e8a2f8f..7f1c48f7ff 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMuxer.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMixer.java
@@ -19,8 +19,7 @@
 
 package org.apache.druid.frame.processor;
 
-import it.unimi.dsi.fastutil.ints.IntIterator;
-import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
@@ -29,7 +28,6 @@ import org.apache.druid.frame.channel.WritableFrameChannel;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
-import java.util.concurrent.ThreadLocalRandom;
 
 /**
  * Processor that merges frames from inputChannels into a single outputChannel. No sorting is done: input frames are
@@ -37,21 +35,22 @@ import java.util.concurrent.ThreadLocalRandom;
  *
  * For sorted output, use {@link FrameChannelMerger} instead.
  */
-public class FrameChannelMuxer implements FrameProcessor<Long>
+public class FrameChannelMixer implements FrameProcessor<Long>
 {
   private final List<ReadableFrameChannel> inputChannels;
   private final WritableFrameChannel outputChannel;
 
-  private final IntSet remainingChannels = new IntOpenHashSet();
+  private final IntSet awaitSet;
   private long rowsRead = 0L;
 
-  public FrameChannelMuxer(
+  public FrameChannelMixer(
       final List<ReadableFrameChannel> inputChannels,
       final WritableFrameChannel outputChannel
   )
   {
     this.inputChannels = inputChannels;
     this.outputChannel = outputChannel;
+    this.awaitSet = FrameProcessors.rangeSet(inputChannels.size());
   }
 
   @Override
@@ -69,39 +68,33 @@ public class FrameChannelMuxer implements FrameProcessor<Long>
   @Override
   public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException
   {
-    if (remainingChannels.isEmpty()) {
-      // First run.
-      for (int i = 0; i < inputChannels.size(); i++) {
-        final ReadableFrameChannel channel = inputChannels.get(i);
-        if (!channel.isFinished()) {
-          remainingChannels.add(i);
-        }
+    final IntSet readySet = new IntAVLTreeSet(readableInputs);
+
+    for (int channelNumber : readableInputs) {
+      final ReadableFrameChannel channel = inputChannels.get(channelNumber);
+
+      if (channel.isFinished()) {
+        awaitSet.remove(channelNumber);
+        readySet.remove(channelNumber);
       }
     }
 
-    if (!readableInputs.isEmpty()) {
-      // Avoid biasing towards lower-numbered channels.
-      final int channelIdx = ThreadLocalRandom.current().nextInt(readableInputs.size());
-
-      int i = 0;
-      for (IntIterator iterator = readableInputs.iterator(); iterator.hasNext(); i++) {
-        final int channelNumber = iterator.nextInt();
-        final ReadableFrameChannel channel = inputChannels.get(channelNumber);
+    if (!readySet.isEmpty()) {
+      // Read a random channel: avoid biasing towards lower-numbered channels.
+      final int channelNumber = FrameProcessors.selectRandom(readySet);
+      final ReadableFrameChannel channel = inputChannels.get(channelNumber);
 
-        if (channel.isFinished()) {
-          remainingChannels.remove(channelNumber);
-        } else if (i == channelIdx) {
-          final Frame frame = channel.read();
-          outputChannel.write(frame);
-          rowsRead += frame.numRows();
-        }
+      if (!channel.isFinished()) {
+        final Frame frame = channel.read();
+        outputChannel.write(frame);
+        rowsRead += frame.numRows();
       }
     }
 
-    if (remainingChannels.isEmpty()) {
+    if (awaitSet.isEmpty()) {
       return ReturnOrAwait.returnObject(rowsRead);
     } else {
-      return ReturnOrAwait.awaitAny(remainingChannels);
+      return ReturnOrAwait.awaitAny(awaitSet);
     }
   }
 
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessor.java b/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessor.java
index 291687806a..7fc1f1d133 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessor.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessor.java
@@ -71,7 +71,7 @@ public interface FrameProcessor<T>
    *
    * Implementations typically call {@link ReadableFrameChannel#close()} and
    * {@link WritableFrameChannel#close()} on all input and output channels, as well as releasing any additional
-   * resources that may be held.
+   * resources that may be held, such as {@link org.apache.druid.frame.write.FrameWriter} instances.
    *
    * In cases of cancellation, this method may be called even if {@link #runIncrementally} has not yet returned a
    * result via {@link ReturnOrAwait#returnObject}.
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessors.java b/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessors.java
index b1c9db3381..3380c76883 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessors.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/FrameProcessors.java
@@ -20,23 +20,28 @@
 package org.apache.druid.frame.processor;
 
 import com.google.common.collect.Lists;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntIterator;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.segment.FrameCursor;
 import org.apache.druid.frame.segment.FrameStorageAdapter;
+import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.guava.Yielders;
 import org.apache.druid.java.util.common.io.Closer;
-import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.VirtualColumns;
 
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 public class FrameProcessors
@@ -92,17 +97,64 @@ public class FrameProcessors
     return new FrameProcessorWithBaggage();
   }
 
-  public static Cursor makeCursor(final Frame frame, final FrameReader frameReader)
+  /**
+   * Returns a {@link FrameCursor} for the provided {@link Frame}, allowing both sequential and random access.
+   */
+  public static FrameCursor makeCursor(final Frame frame, final FrameReader frameReader)
+  {
+    return makeCursor(frame, frameReader, VirtualColumns.EMPTY);
+  }
+
+  /**
+   * Returns a {@link FrameCursor} for the provided {@link Frame} and {@link VirtualColumns}, allowing both sequential
+   * and random access.
+   */
+  public static FrameCursor makeCursor(
+      final Frame frame,
+      final FrameReader frameReader,
+      final VirtualColumns virtualColumns
+  )
   {
-    // Safe to never close the Sequence that the Cursor comes from, because it does not do anything when it is closed.
+    // Safe to never close the Sequence that the FrameCursor comes from, because it does not need to be closed.
     // Refer to FrameStorageAdapter#makeCursors.
 
-    return Yielders.each(
+    return (FrameCursor) Yielders.each(
         new FrameStorageAdapter(frame, frameReader, Intervals.ETERNITY)
-            .makeCursors(null, Intervals.ETERNITY, VirtualColumns.EMPTY, Granularities.ALL, false, null)
+            .makeCursors(null, Intervals.ETERNITY, virtualColumns, Granularities.ALL, false, null)
     ).get();
   }
 
+  /**
+   * Creates a mutable sorted set from 0 to "size" (exclusive).
+   *
+   * @throws IllegalArgumentException if size is negative
+   */
+  public static IntSortedSet rangeSet(final int size)
+  {
+    if (size < 0) {
+      throw new IAE("Size must be nonnegative");
+    }
+
+    final IntSortedSet set = new IntAVLTreeSet();
+
+    for (int i = 0; i < size; i++) {
+      set.add(i);
+    }
+
+    return set;
+  }
+
+  /**
+   * Selects a random element from a set of ints.
+   */
+  public static int selectRandom(final IntSet ints)
+  {
+    final int idx = ThreadLocalRandom.current().nextInt(ints.size());
+    final IntIterator iterator = ints.iterator();
+    iterator.skip(idx);
+    return iterator.nextInt();
+  }
+
   /**
    * Helper method for implementing {@link FrameProcessor#cleanup()}.
    *
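
The new rangeSet and selectRandom utilities back the await-and-read loop shared by FrameChannelMixer and FrameChannelHashPartitioner: start by awaiting every input channel, then on each pass read from a randomly selected ready channel. A minimal sketch of that pattern (channel reading itself is elided):

    import it.unimi.dsi.fastutil.ints.IntSortedSet;
    import org.apache.druid.frame.processor.FrameProcessors;

    class AwaitSetExample
    {
      static void pickChannel(final int inputChannelCount)
      {
        // Initially await every input channel: {0, 1, ..., inputChannelCount - 1}.
        final IntSortedSet awaitSet = FrameProcessors.rangeSet(inputChannelCount);

        if (!awaitSet.isEmpty()) {
          // Choose one ready channel at random to avoid biasing towards lower-numbered channels.
          final int channelNumber = FrameProcessors.selectRandom(awaitSet);
          // ... read a frame from inputChannels.get(channelNumber), drop finished channels, repeat ...
        }
      }
    }
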
diff --git a/processing/src/main/java/org/apache/druid/frame/processor/MultiColumnSelectorFactory.java b/processing/src/main/java/org/apache/druid/frame/processor/MultiColumnSelectorFactory.java
index 33ebc3d393..7121e205fe 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/MultiColumnSelectorFactory.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/MultiColumnSelectorFactory.java
@@ -20,16 +20,18 @@
 package org.apache.druid.frame.processor;
 
 import com.google.common.base.Predicate;
+import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
 import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.query.filter.ValueMatcher;
 import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
-import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.DimensionSelector;
 import org.apache.druid.segment.DimensionSelectorUtils;
 import org.apache.druid.segment.IdLookup;
 import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.data.IndexedInts;
... 3528 lines suppressed ...


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org