You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by pa...@apache.org on 2022/07/14 03:23:48 UTC

[doris] branch master updated: [vectorized][udf] improvement java-udaf with group by clause (#10296)

This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e361eb385e [vectorized][udf] improvement java-udaf with group by clause (#10296)
e361eb385e is described below

commit e361eb385e443c1ed616965b6dcc98f3245c2f42
Author: zhangstar333 <87...@users.noreply.github.com>
AuthorDate: Thu Jul 14 11:23:42 2022 +0800

    [vectorized][udf] improvement java-udaf with group by clause (#10296)
    
    save for file about udaf
    add bool _destory_deserialize
    update some code according reviewer
    change destroy all data at once
---
 .../aggregate_function_java_udaf.h                 | 166 ++++++++++++++-------
 .../java/org/apache/doris/udf/UdafExecutor.java    |  70 +++++----
 gensrc/thrift/Types.thrift                         |   4 +
 3 files changed, 154 insertions(+), 86 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index 8594cd30bf..09ff047998 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -45,10 +45,11 @@ namespace doris::vectorized {
 const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
 const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
 const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
-const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(JJ)V";
-const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "()[B";
-const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "([B)V";
-const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(J)Z";
+const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
+const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V";
+const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
+const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
+const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z";
 // Calling Java method about those signture means: "(argument-types)return-type"
 // https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
 
@@ -60,6 +61,7 @@ public:
         input_values_buffer_ptr.reset(new int64_t[num_args]);
         input_nulls_buffer_ptr.reset(new int64_t[num_args]);
         input_offsets_ptrs.reset(new int64_t[num_args]);
+        input_place_ptrs.reset(new int64_t);
         output_value_buffer.reset(new int64_t);
         output_null_value.reset(new int64_t);
         output_offsets_ptr.reset(new int64_t);
@@ -96,6 +98,7 @@ public:
             ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get());
             ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get());
             ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get());
+            ctor_params.__set_input_places_ptr((int64_t)input_place_ptrs.get());
 
             ctor_params.__set_output_null_ptr((int64_t)output_null_value.get());
             ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get());
@@ -114,8 +117,8 @@ public:
         return Status::OK();
     }
 
-    Status add(const IColumn** columns, size_t row_num_start, size_t row_num_end,
-               const DataTypes& argument_types) {
+    Status add(const int64_t places_address[], bool is_single_place, const IColumn** columns,
+               size_t row_num_start, size_t row_num_end, const DataTypes& argument_types) {
         JNIEnv* env = nullptr;
         RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function");
         for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) {
@@ -144,29 +147,30 @@ public:
                                             argument_types[arg_idx]->get_name()));
             }
         }
-        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, row_num_start,
-                                      row_num_end);
+        *input_place_ptrs = reinterpret_cast<int64_t>(places_address);
+        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, is_single_place,
+                                      row_num_start, row_num_end);
         return JniUtil::GetJniExceptionMsg(env);
     }
 
-    Status merge(const AggregateJavaUdafData& rhs) {
+    Status merge(const AggregateJavaUdafData& rhs, int64_t place) {
         JNIEnv* env = nullptr;
         RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf merge function");
         serialize_data = rhs.serialize_data;
         long len = serialize_data.length();
         jbyteArray arr = env->NewByteArray(len);
         env->SetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data()));
-        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, arr);
+        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, place, arr);
         return JniUtil::GetJniExceptionMsg(env);
     }
 
-    Status write(BufferWritable& buf) {
+    Status write(BufferWritable& buf, int64_t place) {
         JNIEnv* env = nullptr;
         RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf write function");
         // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to
         // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again.
-        jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod(executor_obj, executor_cl,
-                                                                      executor_serialize_id));
+        jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod(
+                executor_obj, executor_cl, executor_serialize_id, place));
         int len = env->GetArrayLength(arr);
         serialize_data.resize(len);
         env->GetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data()));
@@ -176,7 +180,14 @@ public:
 
     void read(BufferReadable& buf) { read_binary(serialize_data, buf); }
 
-    Status get(IColumn& to, const DataTypePtr& result_type) const {
+    Status destroy() {
+        JNIEnv* env = nullptr;
+        RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf destroy function");
+        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_destroy_id);
+        return JniUtil::GetJniExceptionMsg(env);
+    }
+
+    Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
         to.insert_default();
         JNIEnv* env = nullptr;
         RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function");
@@ -187,34 +198,34 @@ public:
             auto& data_col = nullable.get_nested_column();
 
 #ifndef EVALUATE_JAVA_UDAF
-#define EVALUATE_JAVA_UDAF                                                                        \
-    if (data_col.is_column_string()) {                                                            \
-        const ColumnString* str_col = check_and_get_column<ColumnString>(data_col);               \
-        ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());      \
-        ColumnString::Offsets& offsets =                                                          \
-                const_cast<ColumnString::Offsets&>(str_col->get_offsets());                       \
-        int increase_buffer_size = 0;                                                             \
-        *output_value_buffer = reinterpret_cast<int64_t>(chars.data());                           \
-        *output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data());                          \
-        *output_intermediate_state_ptr = chars.size();                                            \
-        jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl,                \
-                                                        executor_result_id, to.size() - 1);       \
-        while (res != JNI_TRUE) {                                                                 \
-            int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);      \
-            increase_buffer_size++;                                                               \
-            chars.reserve(chars.size() + buffer_size);                                            \
-            chars.resize(chars.size() + buffer_size);                                             \
-            *output_intermediate_state_ptr = chars.size();                                        \
-            res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
-                                                   to.size() - 1);                                \
-        }                                                                                         \
-    } else if (data_col.is_numeric() || data_col.is_column_decimal()) {                           \
-        *output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data);           \
-        env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,           \
-                                         to.size() - 1);                                          \
-    } else {                                                                                      \
-        return Status::InvalidArgument(strings::Substitute(                                       \
-                "Java UDAF doesn't support return type is $0 now !", result_type->get_name()));   \
+#define EVALUATE_JAVA_UDAF                                                                         \
+    if (data_col.is_column_string()) {                                                             \
+        const ColumnString* str_col = check_and_get_column<ColumnString>(data_col);                \
+        ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars());       \
+        ColumnString::Offsets& offsets =                                                           \
+                const_cast<ColumnString::Offsets&>(str_col->get_offsets());                        \
+        int increase_buffer_size = 0;                                                              \
+        *output_value_buffer = reinterpret_cast<int64_t>(chars.data());                            \
+        *output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data());                           \
+        *output_intermediate_state_ptr = chars.size();                                             \
+        jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl,                 \
+                                                        executor_result_id, to.size() - 1, place); \
+        while (res != JNI_TRUE) {                                                                  \
+            int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size);       \
+            increase_buffer_size++;                                                                \
+            chars.reserve(chars.size() + buffer_size);                                             \
+            chars.resize(chars.size() + buffer_size);                                              \
+            *output_intermediate_state_ptr = chars.size();                                         \
+            res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,  \
+                                                   to.size() - 1, place);                          \
+        }                                                                                          \
+    } else if (data_col.is_numeric() || data_col.is_column_decimal()) {                            \
+        *output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data);            \
+        env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,            \
+                                         to.size() - 1, place);                                    \
+    } else {                                                                                       \
+        return Status::InvalidArgument(strings::Substitute(                                        \
+                "Java UDAF doesn't support return type is $0 now !", result_type->get_name()));    \
     }
 #endif
             EVALUATE_JAVA_UDAF;
@@ -224,7 +235,7 @@ public:
             auto& data_col = to;
             EVALUATE_JAVA_UDAF;
             env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,
-                                             to.size() - 1);
+                                             to.size() - 1, place);
         }
         return JniUtil::GetJniExceptionMsg(env);
     }
@@ -250,10 +261,14 @@ private:
                 register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id));
         RETURN_IF_ERROR(
                 register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id));
+        RETURN_IF_ERROR(
+                register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id));
         return Status::OK();
     }
 
 private:
+    // TODO: too many variables are held here, which wastes a lot of memory;
+    // it's time to refactor this.
     jclass executor_cl;
     jobject executor_obj;
     jmethodID executor_ctor_id;
@@ -263,10 +278,12 @@ private:
     jmethodID executor_serialize_id;
     jmethodID executor_result_id;
     jmethodID executor_close_id;
+    jmethodID executor_destroy_id;
 
     std::unique_ptr<int64_t[]> input_values_buffer_ptr;
     std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
     std::unique_ptr<int64_t[]> input_offsets_ptrs;
+    std::unique_ptr<int64_t> input_place_ptrs;
     std::unique_ptr<int64_t> output_value_buffer;
     std::unique_ptr<int64_t> output_null_value;
     std::unique_ptr<int64_t> output_offsets_ptr;
@@ -283,7 +300,9 @@ public:
                       const DataTypePtr& return_type)
             : IAggregateFunctionDataHelper(argument_types, parameters),
               _fn(fn),
-              _return_type(return_type) {}
+              _return_type(return_type),
+              _first_created(true),
+              _exec_place(nullptr) {}
     ~AggregateJavaUdaf() = default;
 
     static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types,
@@ -292,54 +311,87 @@ public:
     }
 
     void create(AggregateDataPtr __restrict place) const override {
-        new (place) Data(argument_types.size());
-        Status status = Status::OK();
-        RETURN_IF_STATUS_ERROR(status, data(place).init_udaf(_fn));
+        if (_first_created) {
+            new (place) Data(argument_types.size());
+            Status status = Status::OK();
+            RETURN_IF_STATUS_ERROR(status, this->data(place).init_udaf(_fn));
+            _first_created = false;
+            _exec_place = place;
+        }
+    }
+
+    // To avoid making many JNI calls, all data is destroyed at once here.
+    void destroy(AggregateDataPtr __restrict place) const noexcept override {
+        if (place == _exec_place) {
+            this->data(_exec_place).destroy();
+            this->data(_exec_place).~Data();
+        }
     }
 
     String get_name() const override { return _fn.name.function_name; }
 
     DataTypePtr get_return_type() const override { return _return_type; }
 
-    // TODO: here calling add operator maybe only hava done one row, this performance may be poorly
-    // so it's possible to maintain a hashtable in FE, the key is place address, value is the object
-    // then we can calling add_bacth function and calculate the whole batch at once,
-    // and avoid calling jni multiple times.
     void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
              Arena*) const override {
-        this->data(place).add(columns, row_num, row_num + 1, argument_types);
+        LOG(WARNING) << " shouldn't going add function, there maybe some error about function "
+                     << _fn.name.function_name;
+    }
+
+    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
+                   const IColumn** columns, Arena* arena) const override {
+        int64_t places_address[batch_size];
+        for (size_t i = 0; i < batch_size; ++i) {
+            places_address[i] = reinterpret_cast<int64_t>(places[i]);
+        }
+        this->data(_exec_place).add(places_address, false, columns, 0, batch_size, argument_types);
     }
 
     // TODO: Here we calling method by jni, And if we get a thrown from FE,
     // But can't let user known the error, only return directly and output error to log file.
     void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                 Arena* arena) const override {
-        this->data(place).add(columns, 0, batch_size, argument_types);
+        int64_t places_address[1];
+        places_address[0] = reinterpret_cast<int64_t>(place);
+        this->data(_exec_place).add(places_address, true, columns, 0, batch_size, argument_types);
     }
 
-    void reset(AggregateDataPtr place) const override {}
+    // TODO: the reset function should also be implemented in the data struct
+    void reset(AggregateDataPtr place) const override {
+        LOG(WARNING) << " shouldn't going reset function, there maybe some error about function "
+                     << _fn.name.function_name;
+    }
 
     void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
                Arena*) const override {
-        this->data(place).merge(this->data(rhs));
+        this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place));
     }
 
     void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
-        this->data(const_cast<AggregateDataPtr&>(place)).write(buf);
+        this->data(const_cast<AggregateDataPtr&>(_exec_place))
+                .write(buf, reinterpret_cast<int64_t>(place));
     }
 
+    // During the merge-finalize phase, deserialize and merge run first: each row
+    // goes through create --- deserialize --- merge --- destroy.
+    // So deserialize needs to do new (place) to create a Data and read the buffer into it
+    // before merge is called. When such a deserialized Data is destroyed, init_udaf has not
+    // been run on it, so ~Data can't be called; only the _destory_deserialize flag is changed.
     void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                      Arena*) const override {
+        new (place) Data(argument_types.size());
         this->data(place).read(buf);
     }
 
     void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
-        this->data(place).get(to, _return_type);
+        this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place));
     }
 
 private:
     TFunction _fn;
     DataTypePtr _return_type;
+    mutable bool _first_created;
+    mutable AggregateDataPtr _exec_place;
 };
 
 } // namespace doris::vectorized
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index c3632a15ad..f944cfe2d1 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -57,21 +57,22 @@ public class UdafExecutor {
     public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
     public static final String UDAF_MERGE_FUNCTION = "merge";
     public static final String UDAF_RESULT_FUNCTION = "getValue";
-    private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
+    private static final Logger LOG = Logger.getLogger(UdafExecutor.class);
     private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
     private final long inputBufferPtrs;
     private final long inputNullsPtrs;
     private final long inputOffsetsPtrs;
+    private final long inputPlacesPtr;
     private final long outputBufferPtr;
     private final long outputNullPtr;
     private final long outputOffsetsPtr;
     private final long outputIntermediateStatePtr;
     private Object udaf;
     private HashMap<String, Method> allMethods;
+    private HashMap<Long, Object> stateObjMap;
     private URLClassLoader classLoader;
     private JavaUdfDataType[] argTypes;
     private JavaUdfDataType retType;
-    private Object stateObj;
 
     /**
      * Constructor to create an object.
@@ -91,17 +92,18 @@ public class UdafExecutor {
         inputBufferPtrs = request.input_buffer_ptrs;
         inputNullsPtrs = request.input_nulls_ptrs;
         inputOffsetsPtrs = request.input_offsets_ptrs;
+        inputPlacesPtr = request.input_places_ptr;
 
         outputBufferPtr = request.output_buffer_ptr;
         outputNullPtr = request.output_null_ptr;
         outputOffsetsPtr = request.output_offsets_ptr;
         outputIntermediateStatePtr = request.output_intermediate_state_ptr;
         allMethods = new HashMap<>();
+        stateObjMap = new HashMap<>();
         String className = request.fn.aggregate_fn.symbol;
         String jarFile = request.location;
         Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
         init(jarFile, className, funcRetType, parameterTypes);
-        stateObj = create();
     }
 
     /**
@@ -110,7 +112,6 @@ public class UdafExecutor {
     public void close() {
         if (classLoader != null) {
             try {
-                destroy();
                 classLoader.close();
             } catch (Exception e) {
                 // Log and ignore.
@@ -132,24 +133,23 @@ public class UdafExecutor {
     /**
      * invoke add function, add row in loop [rowStart, rowEnd).
      */
-    public void add(long rowStart, long rowEnd) throws UdfRuntimeException {
+    public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException {
         try {
-            Object[] inputArgs = new Object[argTypes.length + 1];
-            inputArgs[0] = stateObj;
-            for (long row = rowStart; row < rowEnd; ++row) {
-                Object[] inputObjects = allocateInputObjects(row);
-                for (int i = 0; i < argTypes.length; ++i) {
-                    if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1
-                            || UdfUtils.UNSAFE.getByte(null,
-                                    UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row)
-                                    == 0) {
+            long idx = rowStart;
+            do {
+                Long curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
+                Object[] inputArgs = new Object[argTypes.length + 1];
+                stateObjMap.putIfAbsent(curPlace, createAggState());
+                inputArgs[0] = stateObjMap.get(curPlace);
+                do {
+                    Object[] inputObjects = allocateInputObjects(idx);
+                    for (int i = 0; i < argTypes.length; ++i) {
                         inputArgs[i + 1] = inputObjects[i];
-                    } else {
-                        inputArgs[i + 1] = null;
                     }
-                }
-                allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
-            }
+                    allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
+                    idx++;
+                } while (isSinglePlace && idx < rowEnd);
+            } while (idx < rowEnd);
         } catch (Exception e) {
             throw new UdfRuntimeException("UDAF failed to add: ", e);
         }
@@ -158,7 +158,7 @@ public class UdafExecutor {
     /**
      * invoke user create function to get obj.
      */
-    public Object create() throws UdfRuntimeException {
+    public Object createAggState() throws UdfRuntimeException {
         try {
             return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null);
         } catch (Exception e) {
@@ -167,11 +167,14 @@ public class UdafExecutor {
     }
 
     /**
-     * invoke destroy before colse.
+     * invoke destroy before close. Here we destroy all data at once
      */
     public void destroy() throws UdfRuntimeException {
         try {
-            allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, stateObj);
+            for (Object obj : stateObjMap.values()) {
+                allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, obj);
+            }
+            stateObjMap.clear();
         } catch (Exception e) {
             throw new UdfRuntimeException("UDAF failed to destroy: ", e);
         }
@@ -180,11 +183,11 @@ public class UdafExecutor {
     /**
      * invoke serialize function and return byte[] to backends.
      */
-    public byte[] serialize() throws UdfRuntimeException {
+    public byte[] serialize(long place) throws UdfRuntimeException {
         try {
             Object[] args = new Object[2];
             ByteArrayOutputStream baos = new ByteArrayOutputStream();
-            args[0] = stateObj;
+            args[0] = stateObjMap.get((Long) place);
             args[1] = new DataOutputStream(baos);
             allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args);
             return baos.toByteArray();
@@ -197,15 +200,17 @@ public class UdafExecutor {
      * invoke merge function and it's have done deserialze.
      * here call deserialize first, and call merge.
      */
-    public void merge(byte[] data) throws UdfRuntimeException {
+    public void merge(long place, byte[] data) throws UdfRuntimeException {
         try {
             Object[] args = new Object[2];
             ByteArrayInputStream bins = new ByteArrayInputStream(data);
-            args[0] = create();
+            args[0] = createAggState();
             args[1] = new DataInputStream(bins);
             allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args);
             args[1] = args[0];
-            args[0] = stateObj;
+            Long curPlace = place;
+            stateObjMap.putIfAbsent(curPlace, createAggState());
+            args[0] = stateObjMap.get(curPlace);
             allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args);
         } catch (Exception e) {
             throw new UdfRuntimeException("UDAF failed to merge: ", e);
@@ -215,9 +220,10 @@ public class UdafExecutor {
     /**
      * invoke getValue to return finally result.
      */
-    public boolean getValue(long row) throws UdfRuntimeException {
+    public boolean getValue(long row, long place) throws UdfRuntimeException {
         try {
-            return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObj), row);
+            return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObjMap.get((Long) place)),
+                    row);
         } catch (Exception e) {
             throw new UdfRuntimeException("UDAF failed to result", e);
         }
@@ -353,6 +359,12 @@ public class UdafExecutor {
         Object[] inputObjects = new Object[argTypes.length];
 
         for (int i = 0; i < argTypes.length; ++i) {
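+            // A null input at this row becomes a null argument. NOTE(review): unlike the removed add() code, the null-map pointer is not dereferenced here -- confirm.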
+            if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
+                    && UdfUtils.UNSAFE.getByte(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i) + row) == 1) {
+                inputObjects[i] = null;
+                continue;
+            }
             switch (argTypes[i]) {
                 case BOOLEAN:
                     inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 3913f30376..ecf3268de0 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -385,6 +385,10 @@ struct TJavaUdfExecutorCtorParams {
   9: optional i64 output_intermediate_state_ptr
 
   10: optional i64 batch_size_ptr
+  
+  // this is used to pass a place or places to FE, which lets us call JNI
+  // only once to process a whole batch of data in the Java UDAF
+  11: optional i64 input_places_ptr
 }
 
 // Contains all interesting statistics from a single 'memory pool' in the JVM.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org