You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was lost in the plain-text conversion.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/07/29 08:55:29 UTC

[doris] branch master updated: [improvement]Use phmap::flat_hash_set in AggregateFunctionUniq (#11257)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new a7199fb98e [improvement]Use phmap::flat_hash_set in AggregateFunctionUniq (#11257)
a7199fb98e is described below

commit a7199fb98e18b925664b38460b667d04cbee8e01
Author: Jerry Hu <mr...@gmail.com>
AuthorDate: Fri Jul 29 16:55:22 2022 +0800

    [improvement]Use phmap::flat_hash_set in AggregateFunctionUniq (#11257)
---
 .../vec/aggregate_functions/aggregate_function.h   |  15 +++
 .../aggregate_function_nothing.h                   |   3 +
 .../aggregate_functions/aggregate_function_null.h  |  12 ++
 .../aggregate_functions/aggregate_function_uniq.h  | 137 ++++++++++++++++-----
 be/src/vec/exec/vaggregation_node.cpp              |  12 +-
 5 files changed, 136 insertions(+), 43 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h
index 677c189002..c7c7fc38ca 100644
--- a/be/src/vec/aggregate_functions/aggregate_function.h
+++ b/be/src/vec/aggregate_functions/aggregate_function.h
@@ -107,6 +107,10 @@ public:
     virtual void deserialize_vec(AggregateDataPtr places, ColumnString* column, Arena* arena,
                                  size_t num_rows) const = 0;
 
+    /// Deserializes one serialized state from `buf` and merges it into the state at `place`.
+    virtual void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
+                                       Arena* arena) const = 0;
+
     /// Returns true if a function requires Arena to handle own states (see add(), merge(), deserialize()).
     virtual bool allocates_memory_in_arena() const { return false; }
 
@@ -253,6 +257,17 @@ public:
     size_t align_of_data() const override { return alignof(Data); }
 
     void reset(AggregateDataPtr place) const override {}
+
+    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
+                               Arena* arena) const override {
+        // Deserialize into a temporary state, then merge it into `place`. The declaration
+        // below already constructs Data, so create() must not be run on it a second time.
+        Data deserialized_data;
+        AggregateDataPtr deserialized_place = (AggregateDataPtr)&deserialized_data;
+        auto derived = static_cast<const Derived*>(this);
+        derived->deserialize(deserialized_place, buf, arena);
+        derived->merge(place, deserialized_place, arena);
+    }
 };
 
 using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_nothing.h b/be/src/vec/aggregate_functions/aggregate_function_nothing.h
index c0ae740be4..64af14a6cf 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_nothing.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_nothing.h
@@ -64,6 +64,9 @@ public:
     void insert_result_into(ConstAggregateDataPtr, IColumn& to) const override {
         to.insert_default();
     }
+    // No-op: nothing is read from buf and nothing is merged into place.
+    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
+                               Arena* arena) const override {}
 };
 
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 5b804b82a7..89960bc9f0 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -151,6 +151,18 @@ public:
         }
     }
 
+    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
+                               Arena* arena) const override {
+        bool flag = true; // Only nullable results carry a flag; otherwise treat as set.
+        if constexpr (result_is_nullable) {
+            read_binary(flag, buf);
+        }
+        if (flag) { // Flag set: record it on place, then merge the nested payload.
+            set_flag(place);
+            nested_function->deserialize_and_merge(nested_place(place), buf, arena);
+        }
+    }
+
     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
         if constexpr (result_is_nullable) {
             ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
index c717307c72..988e9bdb01 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
@@ -20,6 +20,8 @@
 
 #pragma once
 
+#include <parallel_hashmap/phmap.h>
+
 #include <type_traits>
 
 #include "gutil/hash/city.h"
@@ -34,29 +36,26 @@
 
 namespace doris::vectorized {
 
+// Rows of look-ahead when prefetching hash-set keys in add_batch*(); chosen empirically.
+inline constexpr size_t HASH_MAP_PREFETCH_DIST = 16;
+
 /// uniqExact
 
 template <typename T>
 struct AggregateFunctionUniqExactData {
-    using Key = T;
-
-    /// When creating, the hash table must be small.
-    using Set = HashSet<Key, HashCRC32<Key>, HashTableGrower<4>,
-                        HashTableAllocatorWithStackMemory<sizeof(Key) * (1 << 4)>>;
-
-    Set set;
-
-    static String get_name() { return "uniqExact"; }
-};
-
-/// For rows, we put the SipHash values (128 bits) into the hash table.
-template <>
-struct AggregateFunctionUniqExactData<String> {
-    using Key = UInt128;
-
-    /// When creating, the hash table must be small.
-    using Set = HashSet<Key, UInt128TrivialHash, HashTableGrower<3>,
-                        HashTableAllocatorWithStackMemory<sizeof(Key) * (1 << 3)>>;
+    static constexpr bool is_string_key = std::is_same_v<T, String>;
+    using Key = std::conditional_t<is_string_key, UInt128, T>;
+    using Hash = std::conditional_t<is_string_key, UInt128TrivialHash, HashCRC32<Key>>;
+
+    using Set = phmap::flat_hash_set<Key, Hash>;
+
+    static UInt128 ALWAYS_INLINE get_key(const StringRef& value) {
+        SipHash hasher; // Strings are keyed by the 128-bit SipHash of their bytes.
+        hasher.update(value.data, value.size);
+        UInt128 key;
+        hasher.get128(key.low, key.high);
+        return key;
+    }
 
     Set set;
 
@@ -73,16 +72,9 @@ struct OneAdder {
     static void ALWAYS_INLINE add(Data& data, const IColumn& column, size_t row_num) {
         if constexpr (std::is_same_v<T, String>) {
             StringRef value = column.get_data_at(row_num);
-
-            UInt128 key;
-            SipHash hash;
-            hash.update(value.data, value.size);
-            hash.get128(key.low, key.high);
-
-            data.set.insert(key);
-        } else if constexpr (std::is_same_v<T, Decimal128>) {
-            data.set.insert(
-                    assert_cast<const ColumnDecimal<Decimal128>&>(column).get_data()[row_num]);
+            data.set.insert(Data::get_key(value));
+        } else if constexpr (IsDecimalNumber<T>) {
+            data.set.insert(assert_cast<const ColumnDecimal<T>&>(column).get_data()[row_num]);
         } else {
             data.set.insert(assert_cast<const ColumnVector<T>&>(column).get_data()[row_num]);
         }
@@ -96,6 +88,7 @@ template <typename T, typename Data>
 class AggregateFunctionUniq final
         : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>> {
 public:
+    using KeyType = std::conditional_t<std::is_same_v<T, String>, UInt128, T>;
     AggregateFunctionUniq(const DataTypes& argument_types_)
             : IAggregateFunctionDataHelper<Data, AggregateFunctionUniq<T, Data>>(argument_types_,
                                                                                  {}) {}
@@ -109,18 +102,96 @@ public:
         detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
     }
 
+    static ALWAYS_INLINE const KeyType* get_keys(std::vector<KeyType>& keys_container,
+                                                 const IColumn& column, size_t batch_size) {
+        // Numeric/decimal keys are read in place; only strings need hashing into the buffer.
+        if constexpr (!std::is_same_v<T, String>) {
+            using ColumnType =
+                    std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
+            return assert_cast<const ColumnType&>(column).get_data().data();
+        } else {
+            keys_container.resize(batch_size);
+            for (size_t row = 0; row < batch_size; ++row) {
+                keys_container[row] = Data::get_key(column.get_data_at(row));
+            }
+            return keys_container.data();
+        }
+    }
+
+    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
+                   const IColumn** columns, Arena* arena) const override {
+        std::vector<KeyType> keys_container;
+        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);
+
+        // Resolve every row's target set up front so the prefetch below can look ahead.
+        std::vector<typename Data::Set*> sets(batch_size);
+        for (size_t row = 0; row < batch_size; ++row) {
+            sets[row] = &(this->data(places[row] + place_offset).set);
+        }
+
+        for (size_t row = 0; row < batch_size; ++row) {
+            // Prefetch the key HASH_MAP_PREFETCH_DIST rows ahead into its target set.
+            const size_t ahead = row + HASH_MAP_PREFETCH_DIST;
+            if (ahead < batch_size) {
+                sets[ahead]->prefetch(keys[ahead]);
+            }
+            sets[row]->insert(keys[row]);
+        }
+    }
+
     void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
                Arena*) const override {
-        this->data(place).set.merge(this->data(rhs).set);
+        const auto& rhs_set = this->data(rhs).set;
+        if (rhs_set.empty()) return;
+
+        auto& set = this->data(place).set;
+        // reserve() takes an element count; rehash() takes a bucket count and can leave too
+        // little headroom under the max load factor, forcing a second resize mid-merge.
+        set.reserve(set.size() + rhs_set.size());
+
+        set.insert(rhs_set.begin(), rhs_set.end());
+    }
+
+    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
+                                Arena* arena) const override {
+        auto& set = this->data(place).set;
+        std::vector<KeyType> keys_container;
+        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);
+        // All rows target one set; prefetch HASH_MAP_PREFETCH_DIST keys ahead of each insert.
+        for (size_t row = 0; row < batch_size; ++row) {
+            if (const size_t ahead = row + HASH_MAP_PREFETCH_DIST; ahead < batch_size) {
+                set.prefetch(keys[ahead]);
+            }
+            set.insert(keys[row]);
+        }
     }
 
     void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
-        this->data(place).set.write(buf);
+        const auto& set = this->data(place).set;
+        write_var_uint(set.size(), buf); // Layout: element count, then each key as raw POD.
+        for (const auto& key : set) {
+            write_pod_binary(key, buf);
+        }
+    }
+
+    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
+                               Arena* arena) const override {
+        auto& set = this->data(place).set;
+        size_t size = 0; // Layout written by serialize(): count, then each key as raw POD.
+        read_var_uint(size, buf);
+        // reserve() counts elements (rehash() counts buckets and could still resize below).
+        set.reserve(size + set.size());
+
+        for (size_t i = 0; i < size; ++i) {
+            KeyType key;
+            read_pod_binary(key, buf);
+            set.insert(key);
+        }
     }
 
     void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
-                     Arena*) const override {
-        this->data(place).set.read(buf);
+                     Arena* arena) const override {
+        deserialize_and_merge(place, buf, arena); // Merges; equals a plain read if place is empty.
     }
 
     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index d4325bc17e..54f6b8d15a 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -618,18 +618,10 @@ Status AggregationNode::_merge_without_key(Block* block) {
 
             for (int j = 0; j < rows; ++j) {
                 VectorBufferReader buffer_reader(((ColumnString*)(column.get()))->get_data_at(j));
-                _create_agg_status(deserialize_buffer.get());
 
-                _aggregate_evaluators[i]->function()->deserialize(
-                        deserialize_buffer.get() + _offsets_of_aggregate_states[i], buffer_reader,
+                _aggregate_evaluators[i]->function()->deserialize_and_merge(
+                        _agg_data.without_key + _offsets_of_aggregate_states[i], buffer_reader,
                         &_agg_arena_pool);
-
-                _aggregate_evaluators[i]->function()->merge(
-                        _agg_data.without_key + _offsets_of_aggregate_states[i],
-                        deserialize_buffer.get() + _offsets_of_aggregate_states[i],
-                        &_agg_arena_pool);
-
-                _destroy_agg_status(deserialize_buffer.get());
             }
         } else {
             _aggregate_evaluators[i]->execute_single_add(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org