You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by GitBox <gi...@apache.org> on 2022/06/20 07:39:39 UTC

[GitHub] [doris] cambyzju commented on a diff in pull request #10233: [feature-wip](array-type) add function arrays_overlap

cambyzju commented on code in PR #10233:
URL: https://github.com/apache/doris/pull/10233#discussion_r901347390


##########
be/src/vec/functions/array/function_arrays_overlap.h:
##########
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include <string_view>
+
+#include "vec/columns/column_array.h"
+#include "vec/columns/column_string.h"
+#include "vec/common/hash_table/hash_set.h"
+#include "vec/common/string_ref.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/functions/array/function_array_utils.h"
+#include "vec/functions/function.h"
+
+namespace doris::vectorized {
+
+/// Hash set used by arrays_overlap for fixed-width element columns of type T.
+/// insert_array() loads the elements of one array; find_any() then reports whether
+/// any element of another array is already present.
+template <typename T>
+struct OverlapSetImpl {
+    using ElementNativeType = typename NativeType<typename T::value_type>::Type;
+    using Set = HashSetWithStackMemory<ElementNativeType, DefaultHash<ElementNativeType>, 4>;
+
+    Set set;
+
+    /// Adds elements [start, start + size) of `column` (asserted to be a T) to the set.
+    void insert_array(const IColumn* column, size_t start, size_t size) {
+        const auto& values = assert_cast<const T&>(*column).get_data();
+        const size_t end = start + size;
+        for (size_t pos = start; pos < end; ++pos) {
+            set.insert(values[pos]);
+        }
+    }
+
+    /// Returns true as soon as one element in [start, start + size) is found in the set.
+    bool find_any(const IColumn* column, size_t start, size_t size) {
+        const auto& values = assert_cast<const T&>(*column).get_data();
+        bool hit = false;
+        for (size_t pos = start, end = start + size; !hit && pos < end; ++pos) {
+            hit = static_cast<bool>(set.find(values[pos]));
+        }
+        return hit;
+    }
+};
+
+/// Specialisation for string elements: the set is keyed on the StringRef values
+/// returned by IColumn::get_data_at().
+/// NOTE(review): StringRef looks like a non-owning view, so the column presumably
+/// must outlive the set — confirm.
+template <>
+struct OverlapSetImpl<ColumnString> {
+    using Set = HashSetWithStackMemory<StringRef, DefaultHash<StringRef>, 4>;
+
+    Set set;
+
+    /// Inserts rows [start, start + size) of `column` into the set.
+    void insert_array(const IColumn* column, size_t start, size_t size) {
+        const size_t end = start + size;
+        for (size_t row = start; row < end; ++row) {
+            set.insert(column->get_data_at(row));
+        }
+    }
+
+    /// Returns true as soon as one row in [start, start + size) of `column` is in the set.
+    bool find_any(const IColumn* column, size_t start, size_t size) {
+        bool hit = false;
+        for (size_t row = start, end = start + size; !hit && row < end; ++row) {
+            hit = static_cast<bool>(set.find(column->get_data_at(row)));
+        }
+        return hit;
+    }
+};
+
+class FunctionArraysOverlap : public IFunction {
+public:
+    /// Name under which the function is exposed.
+    static constexpr auto name = "arrays_overlap";
+
+    /// Returns a newly allocated FunctionArraysOverlap on every call.
+    static FunctionPtr create() {
+        return std::make_shared<FunctionArraysOverlap>();
+    }
+
+    /// Returns the function name (`name`).
+    String get_name() const override {
+        return name;
+    }
+
+    /// Opts out of the default NULL handling: execute_impl builds the result null map itself.
+    bool use_default_implementation_for_nulls() const override {
+        return false;
+    }
+
+    /// Fixed arity; see get_number_of_arguments().
+    bool is_variadic() const override {
+        return false;
+    }
+
+    /// arrays_overlap(left_array, right_array).
+    size_t get_number_of_arguments() const override {
+        return 2;
+    }
+
+    /// Both arguments must be (possibly nullable) arrays. The result is a nullable UInt8
+    /// flag, because a NULL array or NULL element in either argument yields NULL.
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        DCHECK(is_array(remove_nullable(arguments[0]))) << arguments[0]->get_name();
+        // Log the type of the argument actually being checked (was arguments[0]).
+        DCHECK(is_array(remove_nullable(arguments[1]))) << arguments[1]->get_name();
+        return make_nullable(std::make_shared<DataTypeUInt8>());
+    }
+
+    /// Evaluates arrays_overlap(left, right) for every row of the block.
+    ///
+    /// Constant arguments are expanded to full columns first. Both arguments must be
+    /// array columns whose types are equal once outer nullability is removed. The column
+    /// stored at `result` is Nullable(UInt8): NULL when either array is NULL or holds a
+    /// NULL element, otherwise the per-row flag written by _execute_internal.
+    /// Returns a RuntimeError (and leaves `result` untouched) when the arguments are not
+    /// arrays, their types differ, or the element type is not one of the dispatched types.
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        const auto& left_arg = block.get_by_position(arguments[0]);
+        const auto& right_arg = block.get_by_position(arguments[1]);
+
+        // Built only on failure paths: formatting the type names up front cost a string
+        // allocation on every successful call.
+        auto unsupported_types = [&]() {
+            return Status::RuntimeError(
+                    fmt::format("execute failed, unsupported types for function {}({}, {})",
+                                get_name(), left_arg.type->get_name(),
+                                right_arg.type->get_name()));
+        };
+
+        auto left_column = left_arg.column->convert_to_full_column_if_const();
+        auto right_column = right_arg.column->convert_to_full_column_if_const();
+        ColumnArrayExecutionData left_exec_data;
+        ColumnArrayExecutionData right_exec_data;
+
+        // extract array column
+        if (!extract_column_array_info(*left_column, left_exec_data) ||
+            !extract_column_array_info(*right_column, right_exec_data)) {
+            return unsupported_types();
+        }
+
+        // data type compare
+        auto left_data_type = remove_nullable(left_arg.type);
+        auto right_data_type = remove_nullable(right_arg.type);
+        if (!left_data_type->equals(*right_data_type)) {
+            return unsupported_types();
+        }
+
+        // prepare return column
+        auto dst_nested_col = ColumnVector<UInt8>::create(input_rows_count, 0);
+        auto dst_null_map = ColumnVector<UInt8>::create(input_rows_count, 0);
+        UInt8* dst_null_map_data = dst_null_map->get_data().data();
+        UInt8* dst_data = dst_nested_col->get_data().data();
+
+        // any array is null or any elements in array is null, return null
+        RETURN_IF_ERROR(_execute_nullable(left_exec_data, dst_null_map_data));
+        RETURN_IF_ERROR(_execute_nullable(right_exec_data, dst_null_map_data));
+
+        // execute overlap check; the argument types were checked equal above, so the
+        // left nested column decides the element type for both sides
+        const auto& nested = *left_exec_data.nested_col;
+        Status ret = Status::OK();
+        if (nested.is_column_string()) {
+            ret = _execute_internal<ColumnString>(left_exec_data, right_exec_data,
+                                                  dst_null_map_data, dst_data);
+        } else if (nested.is_date_type()) {
+            ret = _execute_internal<ColumnDate>(left_exec_data, right_exec_data,
+                                                dst_null_map_data, dst_data);
+        } else if (nested.is_datetime_type()) {
+            ret = _execute_internal<ColumnDateTime>(left_exec_data, right_exec_data,
+                                                    dst_null_map_data, dst_data);
+        } else if (nested.is_numeric()) {
+            if (check_column<ColumnUInt8>(nested)) {
+                ret = _execute_internal<ColumnUInt8>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt8>(nested)) {
+                ret = _execute_internal<ColumnInt8>(left_exec_data, right_exec_data,
+                                                    dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt16>(nested)) {
+                ret = _execute_internal<ColumnInt16>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt32>(nested)) {
+                ret = _execute_internal<ColumnInt32>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt64>(nested)) {
+                ret = _execute_internal<ColumnInt64>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt128>(nested)) {
+                ret = _execute_internal<ColumnInt128>(left_exec_data, right_exec_data,
+                                                      dst_null_map_data, dst_data);
+            } else if (check_column<ColumnFloat32>(nested)) {
+                ret = _execute_internal<ColumnFloat32>(left_exec_data, right_exec_data,
+                                                       dst_null_map_data, dst_data);
+            } else if (check_column<ColumnFloat64>(nested)) {
+                ret = _execute_internal<ColumnFloat64>(left_exec_data, right_exec_data,
+                                                       dst_null_map_data, dst_data);
+            } else {
+                return unsupported_types();
+            }
+        } else if (nested.is_column_decimal() && check_column<ColumnDecimal128>(nested)) {
+            ret = _execute_internal<ColumnDecimal128>(left_exec_data, right_exec_data,
+                                                      dst_null_map_data, dst_data);
+        } else {
+            return unsupported_types();
+        }
+
+        if (!ret.ok()) {
+            return ret;
+        }
+
+        block.replace_by_position(result, ColumnNullable::create(std::move(dst_nested_col),
+                                                                 std::move(dst_null_map)));
+        return Status::OK();
+    }
+
+private:
+    /// Sets dst_nullmap_data[row] = 1 for each row whose array is NULL or contains at
+    /// least one NULL element. Rows that are already marked (e.g. by the other argument)
+    /// are skipped, so one call per argument ORs the two conditions together.
+    /// Always returns OK.
+    Status _execute_nullable(const ColumnArrayExecutionData& data, UInt8* dst_nullmap_data) {
+        const auto& offsets = *data.offsets_ptr;
+        // Compare signed with signed (the old `ssize_t < size_t` test tripped -Wsign-compare).
+        // `row` stays signed because offsets[row - 1] is read at row 0, exactly as before;
+        // presumably the offsets array is left-padded so that slot holds 0 — confirm.
+        const auto rows = static_cast<ssize_t>(offsets.size());
+        for (ssize_t row = 0; row < rows; ++row) {
+            if (dst_nullmap_data[row]) {
+                continue;
+            }
+
+            if (data.array_nullmap_data && data.array_nullmap_data[row]) {
+                dst_nullmap_data[row] = 1;
+                continue;
+            }
+
+            // Without an element null map no element is treated as NULL: skip the scan
+            // instead of re-testing the pointer for every element.
+            if (!data.nested_nullmap_data) {
+                continue;
+            }
+
+            // any element inside array is NULL, return NULL
+            const ssize_t begin = offsets[row - 1];
+            const ssize_t end = offsets[row];
+            for (ssize_t i = begin; i < end; ++i) {
+                if (data.nested_nullmap_data[i]) {
+                    dst_nullmap_data[row] = 1;
+                    break;
+                }
+            }
+        }
+        return Status::OK();
+    }
+
+    template <typename T>
+    Status _execute_internal(const ColumnArrayExecutionData& left_data,
+                             const ColumnArrayExecutionData& right_data,
+                             const UInt8* dst_nullmap_data, UInt8* dst_data) {
+        using ExecutorImpl = OverlapSetImpl<T>;
+        for (ssize_t row = 0; row < left_data.offsets_ptr->size(); ++row) {
+            if (dst_nullmap_data[row]) {
+                continue;
+            }
+
+            ssize_t left_start = (*left_data.offsets_ptr)[row - 1];
+            ssize_t left_size = (*left_data.offsets_ptr)[row] - left_start;
+            ssize_t right_start = (*right_data.offsets_ptr)[row - 1];
+            ssize_t right_size = (*right_data.offsets_ptr)[row] - right_start;
+            if (left_size == 0 || right_size == 0) {
+                dst_data[row] = 0;

Review Comment:
   done



##########
be/src/vec/functions/array/function_arrays_overlap.h:
##########
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include <string_view>
+
+#include "vec/columns/column_array.h"
+#include "vec/columns/column_string.h"
+#include "vec/common/hash_table/hash_set.h"
+#include "vec/common/string_ref.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/functions/array/function_array_utils.h"
+#include "vec/functions/function.h"
+
+namespace doris::vectorized {
+
+/// Hash set used by arrays_overlap for fixed-width element columns of type T.
+/// insert_array() loads the elements of one array; find_any() then reports whether
+/// any element of another array is already present.
+template <typename T>
+struct OverlapSetImpl {
+    using ElementNativeType = typename NativeType<typename T::value_type>::Type;
+    using Set = HashSetWithStackMemory<ElementNativeType, DefaultHash<ElementNativeType>, 4>;
+
+    Set set;
+
+    /// Adds elements [start, start + size) of `column` (asserted to be a T) to the set.
+    void insert_array(const IColumn* column, size_t start, size_t size) {
+        const auto& values = assert_cast<const T&>(*column).get_data();
+        const size_t end = start + size;
+        for (size_t pos = start; pos < end; ++pos) {
+            set.insert(values[pos]);
+        }
+    }
+
+    /// Returns true as soon as one element in [start, start + size) is found in the set.
+    bool find_any(const IColumn* column, size_t start, size_t size) {
+        const auto& values = assert_cast<const T&>(*column).get_data();
+        bool hit = false;
+        for (size_t pos = start, end = start + size; !hit && pos < end; ++pos) {
+            hit = static_cast<bool>(set.find(values[pos]));
+        }
+        return hit;
+    }
+};
+
+/// Specialisation for string elements: the set is keyed on the StringRef values
+/// returned by IColumn::get_data_at().
+/// NOTE(review): StringRef looks like a non-owning view, so the column presumably
+/// must outlive the set — confirm.
+template <>
+struct OverlapSetImpl<ColumnString> {
+    using Set = HashSetWithStackMemory<StringRef, DefaultHash<StringRef>, 4>;
+
+    Set set;
+
+    /// Inserts rows [start, start + size) of `column` into the set.
+    void insert_array(const IColumn* column, size_t start, size_t size) {
+        const size_t end = start + size;
+        for (size_t row = start; row < end; ++row) {
+            set.insert(column->get_data_at(row));
+        }
+    }
+
+    /// Returns true as soon as one row in [start, start + size) of `column` is in the set.
+    bool find_any(const IColumn* column, size_t start, size_t size) {
+        bool hit = false;
+        for (size_t row = start, end = start + size; !hit && row < end; ++row) {
+            hit = static_cast<bool>(set.find(column->get_data_at(row)));
+        }
+        return hit;
+    }
+};
+
+class FunctionArraysOverlap : public IFunction {
+public:
+    /// Name under which the function is exposed.
+    static constexpr auto name = "arrays_overlap";
+
+    /// Returns a newly allocated FunctionArraysOverlap on every call.
+    static FunctionPtr create() {
+        return std::make_shared<FunctionArraysOverlap>();
+    }
+
+    /// Returns the function name (`name`).
+    String get_name() const override {
+        return name;
+    }
+
+    /// Opts out of the default NULL handling: execute_impl builds the result null map itself.
+    bool use_default_implementation_for_nulls() const override {
+        return false;
+    }
+
+    /// Fixed arity; see get_number_of_arguments().
+    bool is_variadic() const override {
+        return false;
+    }
+
+    /// arrays_overlap(left_array, right_array).
+    size_t get_number_of_arguments() const override {
+        return 2;
+    }
+
+    /// Both arguments must be (possibly nullable) arrays. The result is a nullable UInt8
+    /// flag, because a NULL array or NULL element in either argument yields NULL.
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        DCHECK(is_array(remove_nullable(arguments[0]))) << arguments[0]->get_name();
+        // Log the type of the argument actually being checked (was arguments[0]).
+        DCHECK(is_array(remove_nullable(arguments[1]))) << arguments[1]->get_name();
+        return make_nullable(std::make_shared<DataTypeUInt8>());
+    }
+
+    /// Evaluates arrays_overlap(left, right) for every row of the block.
+    ///
+    /// Constant arguments are expanded to full columns first. Both arguments must be
+    /// array columns whose types are equal once outer nullability is removed. The column
+    /// stored at `result` is Nullable(UInt8): NULL when either array is NULL or holds a
+    /// NULL element, otherwise the per-row flag written by _execute_internal.
+    /// Returns a RuntimeError (and leaves `result` untouched) when the arguments are not
+    /// arrays, their types differ, or the element type is not one of the dispatched types.
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        const auto& left_arg = block.get_by_position(arguments[0]);
+        const auto& right_arg = block.get_by_position(arguments[1]);
+
+        // Built only on failure paths: formatting the type names up front cost a string
+        // allocation on every successful call.
+        auto unsupported_types = [&]() {
+            return Status::RuntimeError(
+                    fmt::format("execute failed, unsupported types for function {}({}, {})",
+                                get_name(), left_arg.type->get_name(),
+                                right_arg.type->get_name()));
+        };
+
+        auto left_column = left_arg.column->convert_to_full_column_if_const();
+        auto right_column = right_arg.column->convert_to_full_column_if_const();
+        ColumnArrayExecutionData left_exec_data;
+        ColumnArrayExecutionData right_exec_data;
+
+        // extract array column
+        if (!extract_column_array_info(*left_column, left_exec_data) ||
+            !extract_column_array_info(*right_column, right_exec_data)) {
+            return unsupported_types();
+        }
+
+        // data type compare
+        auto left_data_type = remove_nullable(left_arg.type);
+        auto right_data_type = remove_nullable(right_arg.type);
+        if (!left_data_type->equals(*right_data_type)) {
+            return unsupported_types();
+        }
+
+        // prepare return column
+        auto dst_nested_col = ColumnVector<UInt8>::create(input_rows_count, 0);
+        auto dst_null_map = ColumnVector<UInt8>::create(input_rows_count, 0);
+        UInt8* dst_null_map_data = dst_null_map->get_data().data();
+        UInt8* dst_data = dst_nested_col->get_data().data();
+
+        // any array is null or any elements in array is null, return null
+        RETURN_IF_ERROR(_execute_nullable(left_exec_data, dst_null_map_data));
+        RETURN_IF_ERROR(_execute_nullable(right_exec_data, dst_null_map_data));
+
+        // execute overlap check; the argument types were checked equal above, so the
+        // left nested column decides the element type for both sides
+        const auto& nested = *left_exec_data.nested_col;
+        Status ret = Status::OK();
+        if (nested.is_column_string()) {
+            ret = _execute_internal<ColumnString>(left_exec_data, right_exec_data,
+                                                  dst_null_map_data, dst_data);
+        } else if (nested.is_date_type()) {
+            ret = _execute_internal<ColumnDate>(left_exec_data, right_exec_data,
+                                                dst_null_map_data, dst_data);
+        } else if (nested.is_datetime_type()) {
+            ret = _execute_internal<ColumnDateTime>(left_exec_data, right_exec_data,
+                                                    dst_null_map_data, dst_data);
+        } else if (nested.is_numeric()) {
+            if (check_column<ColumnUInt8>(nested)) {
+                ret = _execute_internal<ColumnUInt8>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt8>(nested)) {
+                ret = _execute_internal<ColumnInt8>(left_exec_data, right_exec_data,
+                                                    dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt16>(nested)) {
+                ret = _execute_internal<ColumnInt16>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt32>(nested)) {
+                ret = _execute_internal<ColumnInt32>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt64>(nested)) {
+                ret = _execute_internal<ColumnInt64>(left_exec_data, right_exec_data,
+                                                     dst_null_map_data, dst_data);
+            } else if (check_column<ColumnInt128>(nested)) {
+                ret = _execute_internal<ColumnInt128>(left_exec_data, right_exec_data,
+                                                      dst_null_map_data, dst_data);
+            } else if (check_column<ColumnFloat32>(nested)) {
+                ret = _execute_internal<ColumnFloat32>(left_exec_data, right_exec_data,
+                                                       dst_null_map_data, dst_data);
+            } else if (check_column<ColumnFloat64>(nested)) {
+                ret = _execute_internal<ColumnFloat64>(left_exec_data, right_exec_data,
+                                                       dst_null_map_data, dst_data);
+            } else {
+                return unsupported_types();
+            }
+        } else if (nested.is_column_decimal() && check_column<ColumnDecimal128>(nested)) {
+            ret = _execute_internal<ColumnDecimal128>(left_exec_data, right_exec_data,
+                                                      dst_null_map_data, dst_data);
+        } else {
+            return unsupported_types();
+        }
+
+        if (!ret.ok()) {
+            return ret;
+        }
+
+        block.replace_by_position(result, ColumnNullable::create(std::move(dst_nested_col),
+                                                                 std::move(dst_null_map)));
+        return Status::OK();
+    }
+
+private:
+    /// Sets dst_nullmap_data[row] = 1 for each row whose array is NULL or contains at
+    /// least one NULL element. Rows that are already marked (e.g. by the other argument)
+    /// are skipped, so one call per argument ORs the two conditions together.
+    /// Always returns OK.
+    Status _execute_nullable(const ColumnArrayExecutionData& data, UInt8* dst_nullmap_data) {
+        const auto& offsets = *data.offsets_ptr;
+        // Compare signed with signed (the old `ssize_t < size_t` test tripped -Wsign-compare).
+        // `row` stays signed because offsets[row - 1] is read at row 0, exactly as before;
+        // presumably the offsets array is left-padded so that slot holds 0 — confirm.
+        const auto rows = static_cast<ssize_t>(offsets.size());
+        for (ssize_t row = 0; row < rows; ++row) {
+            if (dst_nullmap_data[row]) {
+                continue;
+            }
+
+            if (data.array_nullmap_data && data.array_nullmap_data[row]) {
+                dst_nullmap_data[row] = 1;
+                continue;
+            }
+
+            // Without an element null map no element is treated as NULL: skip the scan
+            // instead of re-testing the pointer for every element.
+            if (!data.nested_nullmap_data) {
+                continue;
+            }
+
+            // any element inside array is NULL, return NULL
+            const ssize_t begin = offsets[row - 1];
+            const ssize_t end = offsets[row];
+            for (ssize_t i = begin; i < end; ++i) {
+                if (data.nested_nullmap_data[i]) {
+                    dst_nullmap_data[row] = 1;
+                    break;
+                }
+            }
+        }
+        return Status::OK();
+    }
+
+    template <typename T>
+    Status _execute_internal(const ColumnArrayExecutionData& left_data,
+                             const ColumnArrayExecutionData& right_data,
+                             const UInt8* dst_nullmap_data, UInt8* dst_data) {
+        using ExecutorImpl = OverlapSetImpl<T>;
+        for (ssize_t row = 0; row < left_data.offsets_ptr->size(); ++row) {
+            if (dst_nullmap_data[row]) {
+                continue;
+            }
+
+            ssize_t left_start = (*left_data.offsets_ptr)[row - 1];
+            ssize_t left_size = (*left_data.offsets_ptr)[row] - left_start;
+            ssize_t right_start = (*right_data.offsets_ptr)[row - 1];
+            ssize_t right_size = (*right_data.offsets_ptr)[row] - right_start;
+            if (left_size == 0 || right_size == 0) {
+                dst_data[row] = 0;
+            }
+
+            ExecutorImpl impl;

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org