You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/26 19:30:50 UTC
[arrow-datafusion] branch main updated: Order Preserving RepartitionExec Implementation (#6742)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f24a724543 Order Preserving RepartitionExec Implementation (#6742)
f24a724543 is described below
commit f24a724543c43e99e65a7069ed170f976e432332
Author: Mustafa Akur <10...@users.noreply.github.com>
AuthorDate: Mon Jun 26 22:30:44 2023 +0300
Order Preserving RepartitionExec Implementation (#6742)
* Write tests for functionality
* Implement sort preserving repartition exec
* Minor changes
* Implement second design (per partition merge)
* Simplifications
* Address reviews
* Move the fuzz test to appropriate folder, improve comments
* Decrease code duplication
* simplifications
* Update comment
---------
Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
Co-authored-by: Mustafa Akur <ak...@gmail.com>
---
.../repartition/distributor_channels.rs | 13 ++
.../core/src/physical_plan/repartition/mod.rs | 179 +++++++++++++---
datafusion/core/tests/fuzz_cases/mod.rs | 1 +
.../fuzz_cases/sort_preserving_repartition_fuzz.rs | 237 +++++++++++++++++++++
4 files changed, 405 insertions(+), 25 deletions(-)
diff --git a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs
index d9466d647c..e71b88467b 100644
--- a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs
+++ b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs
@@ -83,6 +83,19 @@ pub fn channels<T>(
(senders, receivers)
}
+type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>; // indexed [input][output]
+type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>; // indexed [input][output]
+
+/// Creates `n_out` empty channels for each of the `n_in` inputs (`n_in * n_out` in total).
+/// Every (input, output) pair thus communicates over its own dedicated channel, which lets
+/// the receiving side tell which input partition a piece of data came from.
+pub fn partition_aware_channels<T>(
+ n_in: usize,
+ n_out: usize,
+) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
+ (0..n_in).map(|_| channels(n_out)).unzip()
+}
+
/// Erroring during [send](DistributionSender::send).
///
/// This occurs when the [receiver](DistributionReceiver) is gone.
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 0dc16eaf1d..72ff0c3713 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -16,7 +16,8 @@
// under the License.
//! The repartition operator maps N input partitions to M output partitions based on a
-//! partitioning scheme.
+//! partitioning scheme. If the `preserve_order` flag is set and the input is ordered, that
+//! ordering is preserved during repartitioning.
use std::pin::Pin;
use std::sync::Arc;
@@ -24,7 +25,9 @@ use std::task::{Context, Poll};
use std::{any::Any, vec};
use crate::physical_plan::hash_utils::create_hashes;
-use crate::physical_plan::repartition::distributor_channels::channels;
+use crate::physical_plan::repartition::distributor_channels::{
+ channels, partition_aware_channels,
+};
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
};
@@ -42,6 +45,9 @@ use super::expressions::PhysicalSortExpr;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream};
+use crate::physical_plan::common::transpose;
+use crate::physical_plan::metrics::BaselineMetrics;
+use crate::physical_plan::sorts::streaming_merge;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::Stream;
@@ -53,6 +59,8 @@ use tokio::task::JoinHandle;
mod distributor_channels;
type MaybeBatch = Option<Result<RecordBatch>>;
+type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>; // one per input partition
+type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>; // one per input if order-preserving, else a single one
/// Inner state of [`RepartitionExec`].
#[derive(Debug)]
@@ -62,8 +70,8 @@ struct RepartitionExecState {
channels: HashMap<
usize,
(
- DistributionSender<MaybeBatch>,
- DistributionReceiver<MaybeBatch>,
+ InputPartitionsToCurrentPartitionSender,
+ InputPartitionsToCurrentPartitionReceiver,
SharedMemoryReservation,
),
>,
@@ -245,6 +253,9 @@ pub struct RepartitionExec {
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
+
+ /// Boolean flag to decide whether to preserve ordering
+ preserve_order: bool,
}
#[derive(Debug, Clone)]
@@ -298,6 +309,15 @@ impl RepartitionExec {
pub fn partitioning(&self) -> &Partitioning {
&self.partitioning
}
+
+ /// Name of this operator, as used in plan display, trace logs and memory-consumer names
+ pub fn name(&self) -> &str {
+ if self.preserve_order {
+ "SortPreservingRepartitionExec"
+ } else {
+ "RepartitionExec"
+ }
+ }
}
impl ExecutionPlan for RepartitionExec {
@@ -345,8 +365,12 @@ impl ExecutionPlan for RepartitionExec {
}
fn maintains_input_order(&self) -> Vec<bool> {
- // We preserve ordering when input partitioning is 1
- vec![self.input().output_partitioning().partition_count() <= 1]
+ if self.preserve_order {
+ vec![true] // each output partition merges its per-input streams (see `execute`)
+ } else {
+ // Without the flag, ordering survives only when there is at most one input partition
+ vec![self.input().output_partitioning().partition_count() <= 1]
+ }
}
fn equivalence_properties(&self) -> EquivalenceProperties {
@@ -359,7 +383,8 @@ impl ExecutionPlan for RepartitionExec {
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
trace!(
- "Start RepartitionExec::execute for partition: {}",
+ "Start {}::execute for partition: {}",
+ self.name(),
partition
);
// lock mutexes
@@ -370,13 +395,29 @@ impl ExecutionPlan for RepartitionExec {
// if this is the first partition to be invoked then we need to set up initial state
if state.channels.is_empty() {
- // create one channel per *output* partition
- // note we use a custom channel that ensures there is always data for each receiver
- // but limits the amount of buffering if required.
- let (txs, rxs) = channels(num_output_partitions);
+ let (txs, rxs) = if self.preserve_order {
+ let (txs, rxs) =
+ partition_aware_channels(num_input_partitions, num_output_partitions);
+ // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
+ let txs = transpose(txs);
+ let rxs = transpose(rxs);
+ (txs, rxs)
+ } else {
+ // create one channel per *output* partition
+ // note we use a custom channel that ensures there is always data for each receiver
+ // but limits the amount of buffering if required.
+ let (txs, rxs) = channels(num_output_partitions);
+ // Clone sender for ech input partitions
+ let txs = txs
+ .into_iter()
+ .map(|item| vec![item; num_input_partitions])
+ .collect::<Vec<_>>();
+ let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
+ (txs, rxs)
+ };
for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
let reservation = Arc::new(Mutex::new(
- MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+ MemoryConsumer::new(format!("{}[{partition}]", self.name()))
.register(context.memory_pool()),
));
state.channels.insert(partition, (tx, rx, reservation));
@@ -389,7 +430,7 @@ impl ExecutionPlan for RepartitionExec {
.channels
.iter()
.map(|(partition, (tx, _rx, reservation))| {
- (*partition, (tx.clone(), Arc::clone(reservation)))
+ (*partition, (tx[i].clone(), Arc::clone(reservation)))
})
.collect();
@@ -420,24 +461,53 @@ impl ExecutionPlan for RepartitionExec {
}
trace!(
- "Before returning stream in RepartitionExec::execute for partition: {}",
+ "Before returning stream in {}::execute for partition: {}",
+ self.name(),
partition
);
// now return stream for the specified *output* partition which will
// read from the channel
- let (_tx, rx, reservation) = state
+ let (_tx, mut rx, reservation) = state
.channels
.remove(&partition)
.expect("partition not used yet");
- Ok(Box::pin(RepartitionStream {
- num_input_partitions,
- num_input_partitions_processed: 0,
- schema: self.input.schema(),
- input: rx,
- drop_helper: Arc::clone(&state.abort_helper),
- reservation,
- }))
+
+ if self.preserve_order {
+ // Store streams from all the input partitions:
+ let input_streams = rx
+ .into_iter()
+ .map(|receiver| {
+ Box::pin(PerPartitionStream {
+ schema: self.schema(),
+ receiver,
+ drop_helper: Arc::clone(&state.abort_helper),
+ reservation: reservation.clone(),
+ }) as SendableRecordBatchStream
+ })
+ .collect::<Vec<_>>();
+ // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
+
+ // Get existing ordering:
+ let sort_exprs = self.input.output_ordering().unwrap_or(&[]);
+ // Merge streams (while preserving ordering) coming from input partitions to this partition:
+ streaming_merge(
+ input_streams,
+ self.schema(),
+ sort_exprs,
+ BaselineMetrics::new(&self.metrics, partition),
+ context.session_config().batch_size(),
+ )
+ } else {
+ Ok(Box::pin(RepartitionStream {
+ num_input_partitions,
+ num_input_partitions_processed: 0,
+ schema: self.input.schema(),
+ input: rx.swap_remove(0),
+ drop_helper: Arc::clone(&state.abort_helper),
+ reservation,
+ }))
+ }
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -453,7 +523,8 @@ impl ExecutionPlan for RepartitionExec {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
- "RepartitionExec: partitioning={}, input_partitions={}",
+ "{}: partitioning={}, input_partitions={}",
+ self.name(),
self.partitioning,
self.input.output_partitioning().partition_count()
)
@@ -480,9 +551,16 @@ impl RepartitionExec {
abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
})),
metrics: ExecutionPlanMetricsSet::new(),
+ preserve_order: false,
})
}
+ /// Enables order preservation: each output partition merges its per-input streams in order
+ pub fn with_preserve_order(mut self) -> Self {
+ self.preserve_order = true; // NOTE(review): set even when the input has no output ordering — confirm intended
+ self
+ }
+
/// Pulls data from the specified input plan, feeding it to the
/// output partitions based on the desired partitioning
///
@@ -575,7 +653,7 @@ impl RepartitionExec {
/// channels.
async fn wait_for_task(
input_task: AbortOnDropSingle<Result<()>>,
- txs: HashMap<usize, DistributionSender<Option<Result<RecordBatch>>>>,
+ txs: HashMap<usize, DistributionSender<MaybeBatch>>,
) {
// wait for completion, and propagate error
// note we ignore errors on send (.ok) as that means the receiver has already shutdown.
@@ -681,6 +759,56 @@ impl RecordBatchStream for RepartitionStream {
}
}
+/// Adapts the receiving end of a single (input partition -> output partition) channel into
+/// a stream, so the per-input streams of one output partition can be merged in order.
+struct PerPartitionStream {
+ /// Schema of the batches produced by this stream
+ schema: SchemaRef,
+
+ /// Receiving end of the channel from one input partition to this output partition
+ receiver: DistributionReceiver<MaybeBatch>,
+
+ /// Handle to ensure background tasks are killed when no longer needed.
+ #[allow(dead_code)] // never read; held so the handle lives as long as the stream
+ drop_helper: Arc<AbortOnDropMany<()>>,
+
+ /// Memory reservation shared with the sending side; shrunk as each batch is received.
+ reservation: SharedMemoryReservation,
+}
+
+impl Stream for PerPartitionStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ match self.receiver.recv().poll_unpin(cx) {
+ Poll::Ready(Some(Some(v))) => {
+ if let Ok(batch) = &v { // batch leaves the channel: shrink the reservation by its size
+ self.reservation
+ .lock()
+ .shrink(batch.get_array_memory_size());
+ }
+ Poll::Ready(Some(v))
+ }
+ Poll::Ready(Some(None)) => {
+ // The input partition signalled that it has finished sending batches
+ Poll::Ready(None)
+ }
+ Poll::Ready(None) => Poll::Ready(None), // `recv` yielded `None`: end the stream as well
+ Poll::Pending => Poll::Pending,
+ }
+ }
+}
+
+impl RecordBatchStream for PerPartitionStream {
+ /// Returns the schema of the produced batches
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -705,6 +833,7 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use futures::FutureExt;
use std::collections::HashSet;
+ use tokio::task::JoinHandle;
#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index bf415c7f1c..c49eb65988 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -19,4 +19,5 @@ mod aggregate_fuzz;
mod join_fuzz;
mod merge_fuzz;
mod order_spill_fuzz;
+mod sort_preserving_repartition_fuzz;
mod window_fuzz;
diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
new file mode 100644
index 0000000000..d0ce58f6e6
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
@@ -0,0 +1,237 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[cfg(test)]
+mod sp_repartition_fuzz_tests {
+ use arrow::compute::concat_batches;
+ use arrow_array::{ArrayRef, Int64Array, RecordBatch};
+ use arrow_schema::SortOptions;
+ use datafusion::physical_plan::memory::MemoryExec;
+ use datafusion::physical_plan::repartition::RepartitionExec;
+ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
+ use datafusion::physical_plan::{collect, ExecutionPlan, Partitioning};
+ use datafusion::prelude::SessionContext;
+ use datafusion_execution::config::SessionConfig;
+ use datafusion_physical_expr::expressions::col;
+ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+ use rand::rngs::StdRng;
+ use rand::{Rng, SeedableRng};
+ use std::sync::Arc;
+ use test_utils::add_empty_batches;
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+ async fn sort_preserving_repartition_test() {
+ let seed_start = 0;
+ let seed_end = 100;
+ let n_row = 1000;
+ // The test orders by (a, b, c), which covers every column of the table
+ // (the table consists only of columns a, b and c), so the result does not
+ // depend on whether the sort is stable. `n_distinct` could therefore be any
+ // value; we choose a large one to reduce the probability of duplicate rows
+ // in the table.
+ let n_distinct = 1_000_000;
+ for (is_first_roundrobin, is_first_sort_preserving) in
+ [(false, false), (false, true), (true, false), (true, true)]
+ {
+ for is_second_roundrobin in [false, true] {
+ let mut handles = Vec::new();
+
+ for seed in seed_start..seed_end {
+ let job = tokio::spawn(run_sort_preserving_repartition_test(
+ make_staggered_batches::<true>(n_row, n_distinct, seed as u64),
+ is_first_roundrobin,
+ is_first_sort_preserving,
+ is_second_roundrobin,
+ ));
+ handles.push(job);
+ }
+
+ for job in handles {
+ job.await.unwrap();
+ }
+ }
+ }
+ }
+
+ /// Checks that the physical plans below
+ /// "SortPreservingMergeExec: [a@0 ASC,b@1 ASC,c@2 ASC]",
+ /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=2", (partitioning can also be round-robin)
+ /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=1", (partitioning can also be round-robin)
+ /// " MemoryExec: partitions=1, partition_sizes=[75]",
+ /// and / or
+ /// "SortPreservingMergeExec: [a@0 ASC,b@1 ASC,c@2 ASC]",
+ /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=2", (partitioning can also be round-robin)
+ /// " RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=1", (partitioning can also be round-robin)
+ /// " MemoryExec: partitions=1, partition_sizes=[75]",
+ /// preserve ordering, i.e. that the plan's output is identical to the (already sorted) input.
+ async fn run_sort_preserving_repartition_test(
+ input1: Vec<RecordBatch>,
+ // If `true`, first repartition executor after `MemoryExec` will be in `RoundRobin` mode,
+ // otherwise it will be in `Hash` mode
+ is_first_roundrobin: bool,
+ // If `true`, first repartition executor after `MemoryExec` will be `SortPreservingRepartitionExec`
+ // If `false`, first repartition executor after `MemoryExec` will be `RepartitionExec` (since its input
+ // has a single partition, a plain `RepartitionExec` preserves ordering as well).
+ is_first_sort_preserving: bool,
+ // If `true`, second repartition executor after `MemoryExec` will be in `RoundRobin` mode,
+ // otherwise it will be in `Hash` mode
+ is_second_roundrobin: bool,
+ ) {
+ let schema = input1[0].schema();
+ let session_config = SessionConfig::new().with_batch_size(50);
+ let ctx = SessionContext::with_config(session_config);
+ let mut sort_keys = vec![];
+ for ordering_col in ["a", "b", "c"] {
+ sort_keys.push(PhysicalSortExpr {
+ expr: col(ordering_col, &schema).unwrap(),
+ options: SortOptions::default(),
+ })
+ }
+
+ let concat_input_record = concat_batches(&schema, &input1).unwrap();
+
+ let running_source = Arc::new(
+ MemoryExec::try_new(&[input1.clone()], schema.clone(), None)
+ .unwrap()
+ .with_sort_information(sort_keys.clone()),
+ );
+ let hash_exprs = vec![col("c", &schema).unwrap()];
+
+ let intermediate = match (is_first_roundrobin, is_first_sort_preserving) {
+ (true, true) => sort_preserving_repartition_exec_round_robin(running_source),
+ (true, false) => repartition_exec_round_robin(running_source),
+ (false, true) => {
+ sort_preserving_repartition_exec_hash(running_source, hash_exprs.clone())
+ }
+ (false, false) => repartition_exec_hash(running_source, hash_exprs.clone()),
+ };
+
+ let intermediate = if is_second_roundrobin {
+ sort_preserving_repartition_exec_round_robin(intermediate)
+ } else {
+ sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone())
+ };
+
+ let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate);
+ let task_ctx = ctx.task_ctx();
+
+ let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap();
+ let concat_res = concat_batches(&schema, &collected_running).unwrap();
+ assert_eq!(concat_res, concat_input_record);
+ }
+
+ fn sort_preserving_repartition_exec_round_robin(
+ input: Arc<dyn ExecutionPlan>,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2))
+ .unwrap()
+ .with_preserve_order(),
+ )
+ }
+
+ fn repartition_exec_round_robin(
+ input: Arc<dyn ExecutionPlan>,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2)).unwrap(),
+ )
+ }
+
+ fn sort_preserving_repartition_exec_hash(
+ input: Arc<dyn ExecutionPlan>,
+ hash_expr: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2))
+ .unwrap()
+ .with_preserve_order(),
+ )
+ }
+
+ fn repartition_exec_hash(
+ input: Arc<dyn ExecutionPlan>,
+ hash_expr: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2)).unwrap(),
+ )
+ }
+
+ fn sort_preserving_merge_exec(
+ sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+ input: Arc<dyn ExecutionPlan>,
+ ) -> Arc<dyn ExecutionPlan> {
+ let sort_exprs = sort_exprs.into_iter().collect();
+ Arc::new(SortPreservingMergeExec::new(sort_exprs, input))
+ }
+
+ /// Returns randomly sized record batches with three int64 columns `a`, `b`, `c` whose
+ /// values lie in `0..n_distinct`; rows are sorted lexicographically by (a, b, c).
+ pub(crate) fn make_staggered_batches<const STREAM: bool>(
+ len: usize,
+ n_distinct: usize,
+ random_seed: u64,
+ ) -> Vec<RecordBatch> {
+ // use a random number generator to pick a random sized output
+ let mut rng = StdRng::seed_from_u64(random_seed);
+ let mut input123: Vec<(i64, i64, i64)> = vec![(0, 0, 0); len];
+ input123.iter_mut().for_each(|v| {
+ *v = (
+ rng.gen_range(0..n_distinct) as i64,
+ rng.gen_range(0..n_distinct) as i64,
+ rng.gen_range(0..n_distinct) as i64,
+ )
+ });
+ input123.sort();
+ let input1 =
+ Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0));
+ let input2 =
+ Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.1));
+ let input3 =
+ Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.2));
+
+ // split into several record batches
+ let mut remainder = RecordBatch::try_from_iter(vec![
+ ("a", Arc::new(input1) as ArrayRef),
+ ("b", Arc::new(input2) as ArrayRef),
+ ("c", Arc::new(input3) as ArrayRef),
+ ])
+ .unwrap();
+
+ let mut batches = vec![];
+ if STREAM {
+ while remainder.num_rows() > 0 {
+ let batch_size = rng.gen_range(0..50); // may be 0, producing an empty batch
+ if remainder.num_rows() < batch_size {
+ break; // NOTE: the remaining rows are dropped from the output
+ }
+ batches.push(remainder.slice(0, batch_size));
+ remainder =
+ remainder.slice(batch_size, remainder.num_rows() - batch_size);
+ }
+ } else {
+ while remainder.num_rows() > 0 {
+ let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+ batches.push(remainder.slice(0, batch_size));
+ remainder =
+ remainder.slice(batch_size, remainder.num_rows() - batch_size);
+ }
+ }
+ add_empty_batches(batches, &mut rng)
+ }
+}