You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/07 16:42:56 UTC

[arrow-datafusion] branch master updated: Always track the final size of the in-mem sorted arrays (#3753)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ab0e0f64 Always track the final size of the in-mem sorted arrays (#3753)
0ab0e0f64 is described below

commit 0ab0e0f6489e6cf560cba5a0530a19c4516c15b1
Author: Batuhan Taskaya <is...@gmail.com>
AuthorDate: Fri Oct 7 19:42:49 2022 +0300

    Always track the final size of the in-mem sorted arrays (#3753)
---
 datafusion/core/src/execution/memory_manager.rs |  7 +++++
 datafusion/core/src/physical_plan/sorts/sort.rs | 37 +++++++++++++---------
 datafusion/core/tests/sql/decimal.rs            | 42 +++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 15 deletions(-)

diff --git a/datafusion/core/src/execution/memory_manager.rs b/datafusion/core/src/execution/memory_manager.rs
index f148e331e..48d4ca3c3 100644
--- a/datafusion/core/src/execution/memory_manager.rs
+++ b/datafusion/core/src/execution/memory_manager.rs
@@ -195,6 +195,13 @@ pub trait MemoryConsumer: Send + Sync {
         Ok(())
     }
 
+    /// Grow without spilling to the disk. It grows the memory directly,
+    /// so it should only be used when the consumer has already allocated the
+    /// memory and it is safe to grow without spilling.
+    fn grow(&self, required: usize) {
+        self.memory_manager().record_free_then_acquire(0, required);
+    }
+
     /// Return `freed` memory to the memory manager,
     /// may wake up other requesters waiting for their minimum memory quota.
     fn shrink(&self, freed: usize) {
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4586106bb..763c7c553 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -50,7 +50,7 @@ use futures::lock::Mutex;
 use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{debug, error};
 use std::any::Any;
-use std::cmp::min;
+use std::cmp::{min, Ordering};
 use std::fmt;
 use std::fmt::{Debug, Formatter};
 use std::fs::File;
@@ -124,20 +124,27 @@ impl ExternalSorter {
             // calls to `timer.done()` below.
             let _timer = tracking_metrics.elapsed_compute().timer();
             let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
-            // The resulting batch might be smaller than the input batch if there
-            // is an propagated limit.
-
-            if self.fetch.is_some() {
-                let new_size = batch_byte_size(&partial.sorted_batch);
-                let size_delta = size.checked_sub(new_size).ok_or_else(|| {
-                    DataFusionError::Internal(format!(
-                        "The size of the sorted batch is larger than the size of the input batch: {} > {}",
-                        new_size,
-                        size
-                    ))
-                })?;
-                self.shrink(size_delta);
-                self.metrics.mem_used().sub(size_delta);
+
+            // The resulting batch might be smaller (or larger, see #3747) than the input
+            // batch due to either a propagated limit or the reconstruction of arrays. So,
+            // to be reliable, we need to reflect the memory usage of the partial batch.
+            let new_size = batch_byte_size(&partial.sorted_batch);
+            match new_size.cmp(&size) {
+                Ordering::Greater => {
+                    // We don't have to call try_grow here, since we have already used the
+                    // memory (so spilling right here wouldn't help at all for the current
+                    // operation). But we still have to record it so that other requesters
+                    // would know about this unexpected increase in memory consumption.
+                    let new_size_delta = new_size - size;
+                    self.grow(new_size_delta);
+                    self.metrics.mem_used().add(new_size_delta);
+                }
+                Ordering::Less => {
+                    let size_delta = size - new_size;
+                    self.shrink(size_delta);
+                    self.metrics.mem_used().sub(size_delta);
+                }
+                Ordering::Equal => {}
             }
             in_mem_batches.push(partial);
         }
diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs
index 9d32f1c31..2e3e3d2ab 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -690,6 +690,48 @@ async fn decimal_sort() -> Result<()> {
     ];
     assert_batches_eq!(expected, &actual);
 
+    let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 limit 10";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal128(10, 6),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+----------+----------------+-----+-------+-----------+",
+        "| c1       | c2             | c3  | c4    | c5        |",
+        "+----------+----------------+-----+-------+-----------+",
+        "| 0.000040 | 0.000000000004 | 5   | true  | 0.0000440 |",
+        "| 0.000040 | 0.000000000004 | 12  | false | 0.0000400 |",
+        "| 0.000040 | 0.000000000004 | 14  | true  | 0.0000400 |",
+        "| 0.000040 | 0.000000000004 | 8   | false | 0.0000440 |",
+        "| 0.000050 | 0.000000000005 | 9   | true  | 0.0000520 |",
+        "| 0.000050 | 0.000000000005 | 4   | true  | 0.0000780 |",
+        "| 0.000050 | 0.000000000005 | 8   | false | 0.0000330 |",
+        "| 0.000050 | 0.000000000005 | 100 | true  | 0.0000680 |",
+        "| 0.000050 | 0.000000000005 | 1   | false | 0.0001000 |",
+        "+----------+----------------+-----+-------+-----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 limit 5";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(
+        &DataType::Decimal128(10, 6),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+----------+----------------+----+-------+-----------+",
+        "| c1       | c2             | c3 | c4    | c5        |",
+        "+----------+----------------+----+-------+-----------+",
+        "| 0.000040 | 0.000000000004 | 5  | true  | 0.0000440 |",
+        "| 0.000040 | 0.000000000004 | 12 | false | 0.0000400 |",
+        "| 0.000040 | 0.000000000004 | 14 | true  | 0.0000400 |",
+        "| 0.000040 | 0.000000000004 | 8  | false | 0.0000440 |",
+        "| 0.000050 | 0.000000000005 | 9  | true  | 0.0000520 |",
+        "+----------+----------------+----+-------+-----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
     let sql = "select * from decimal_simple where c1 >= 0.00004 order by c1 desc";
     let actual = execute_to_batches(&ctx, sql).await;
     assert_eq!(