You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by pa...@apache.org on 2022/07/14 03:23:48 UTC
[doris] branch master updated: [vectorized][udf] improvement java-udaf with group by clause (#10296)
This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new e361eb385e [vectorized][udf] improvement java-udaf with group by clause (#10296)
e361eb385e is described below
commit e361eb385e443c1ed616965b6dcc98f3245c2f42
Author: zhangstar333 <87...@users.noreply.github.com>
AuthorDate: Thu Jul 14 11:23:42 2022 +0800
[vectorized][udf] improvement java-udaf with group by clause (#10296)
save for file about udaf
add bool _destory_deserialize
update some code according to reviewer comments
change to destroy all data at once
---
.../aggregate_function_java_udaf.h | 166 ++++++++++++++-------
.../java/org/apache/doris/udf/UdafExecutor.java | 70 +++++----
gensrc/thrift/Types.thrift | 4 +
3 files changed, 154 insertions(+), 86 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index 8594cd30bf..09ff047998 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -45,10 +45,11 @@ namespace doris::vectorized {
const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor";
const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V";
const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V";
-const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(JJ)V";
-const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "()[B";
-const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "([B)V";
-const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(J)Z";
+const char* UDAF_EXECUTOR_DESTROY_SIGNATURE = "()V";
+const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V";
+const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
+const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
+const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z";
// Calling Java method about those signture means: "(argument-types)return-type"
// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
@@ -60,6 +61,7 @@ public:
input_values_buffer_ptr.reset(new int64_t[num_args]);
input_nulls_buffer_ptr.reset(new int64_t[num_args]);
input_offsets_ptrs.reset(new int64_t[num_args]);
+ input_place_ptrs.reset(new int64_t);
output_value_buffer.reset(new int64_t);
output_null_value.reset(new int64_t);
output_offsets_ptr.reset(new int64_t);
@@ -96,6 +98,7 @@ public:
ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get());
ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get());
ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get());
+ ctor_params.__set_input_places_ptr((int64_t)input_place_ptrs.get());
ctor_params.__set_output_null_ptr((int64_t)output_null_value.get());
ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get());
@@ -114,8 +117,8 @@ public:
return Status::OK();
}
- Status add(const IColumn** columns, size_t row_num_start, size_t row_num_end,
- const DataTypes& argument_types) {
+ Status add(const int64_t places_address[], bool is_single_place, const IColumn** columns,
+ size_t row_num_start, size_t row_num_end, const DataTypes& argument_types) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function");
for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) {
@@ -144,29 +147,30 @@ public:
argument_types[arg_idx]->get_name()));
}
}
- env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, row_num_start,
- row_num_end);
+ *input_place_ptrs = reinterpret_cast<int64_t>(places_address);
+ env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, is_single_place,
+ row_num_start, row_num_end);
return JniUtil::GetJniExceptionMsg(env);
}
- Status merge(const AggregateJavaUdafData& rhs) {
+ Status merge(const AggregateJavaUdafData& rhs, int64_t place) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf merge function");
serialize_data = rhs.serialize_data;
long len = serialize_data.length();
jbyteArray arr = env->NewByteArray(len);
env->SetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data()));
- env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, arr);
+ env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, place, arr);
return JniUtil::GetJniExceptionMsg(env);
}
- Status write(BufferWritable& buf) {
+ Status write(BufferWritable& buf, int64_t place) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf write function");
// TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to
// save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again.
- jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod(executor_obj, executor_cl,
- executor_serialize_id));
+ jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod(
+ executor_obj, executor_cl, executor_serialize_id, place));
int len = env->GetArrayLength(arr);
serialize_data.resize(len);
env->GetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data()));
@@ -176,7 +180,14 @@ public:
void read(BufferReadable& buf) { read_binary(serialize_data, buf); }
- Status get(IColumn& to, const DataTypePtr& result_type) const {
+ Status destroy() {
+ JNIEnv* env = nullptr;
+ RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf destroy function");
+ env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_destroy_id);
+ return JniUtil::GetJniExceptionMsg(env);
+ }
+
+ Status get(IColumn& to, const DataTypePtr& result_type, int64_t place) const {
to.insert_default();
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function");
@@ -187,34 +198,34 @@ public:
auto& data_col = nullable.get_nested_column();
#ifndef EVALUATE_JAVA_UDAF
-#define EVALUATE_JAVA_UDAF \
- if (data_col.is_column_string()) { \
- const ColumnString* str_col = check_and_get_column<ColumnString>(data_col); \
- ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \
- ColumnString::Offsets& offsets = \
- const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
- int increase_buffer_size = 0; \
- *output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
- *output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \
- *output_intermediate_state_ptr = chars.size(); \
- jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \
- executor_result_id, to.size() - 1); \
- while (res != JNI_TRUE) { \
- int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
- increase_buffer_size++; \
- chars.reserve(chars.size() + buffer_size); \
- chars.resize(chars.size() + buffer_size); \
- *output_intermediate_state_ptr = chars.size(); \
- res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
- to.size() - 1); \
- } \
- } else if (data_col.is_numeric() || data_col.is_column_decimal()) { \
- *output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data); \
- env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
- to.size() - 1); \
- } else { \
- return Status::InvalidArgument(strings::Substitute( \
- "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \
+#define EVALUATE_JAVA_UDAF \
+ if (data_col.is_column_string()) { \
+ const ColumnString* str_col = check_and_get_column<ColumnString>(data_col); \
+ ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \
+ ColumnString::Offsets& offsets = \
+ const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
+ int increase_buffer_size = 0; \
+ *output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \
+ *output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \
+ *output_intermediate_state_ptr = chars.size(); \
+ jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \
+ executor_result_id, to.size() - 1, place); \
+ while (res != JNI_TRUE) { \
+ int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \
+ increase_buffer_size++; \
+ chars.reserve(chars.size() + buffer_size); \
+ chars.resize(chars.size() + buffer_size); \
+ *output_intermediate_state_ptr = chars.size(); \
+ res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
+ to.size() - 1, place); \
+ } \
+ } else if (data_col.is_numeric() || data_col.is_column_decimal()) { \
+ *output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data); \
+ env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \
+ to.size() - 1, place); \
+ } else { \
+ return Status::InvalidArgument(strings::Substitute( \
+ "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \
}
#endif
EVALUATE_JAVA_UDAF;
@@ -224,7 +235,7 @@ public:
auto& data_col = to;
EVALUATE_JAVA_UDAF;
env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id,
- to.size() - 1);
+ to.size() - 1, place);
}
return JniUtil::GetJniExceptionMsg(env);
}
@@ -250,10 +261,14 @@ private:
register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id));
RETURN_IF_ERROR(
register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id));
+ RETURN_IF_ERROR(
+ register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id));
return Status::OK();
}
private:
+ // TODO: too many member variables are held here, which wastes a lot of memory;
+ // it's time to refactor this.
jclass executor_cl;
jobject executor_obj;
jmethodID executor_ctor_id;
@@ -263,10 +278,12 @@ private:
jmethodID executor_serialize_id;
jmethodID executor_result_id;
jmethodID executor_close_id;
+ jmethodID executor_destroy_id;
std::unique_ptr<int64_t[]> input_values_buffer_ptr;
std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
std::unique_ptr<int64_t[]> input_offsets_ptrs;
+ std::unique_ptr<int64_t> input_place_ptrs;
std::unique_ptr<int64_t> output_value_buffer;
std::unique_ptr<int64_t> output_null_value;
std::unique_ptr<int64_t> output_offsets_ptr;
@@ -283,7 +300,9 @@ public:
const DataTypePtr& return_type)
: IAggregateFunctionDataHelper(argument_types, parameters),
_fn(fn),
- _return_type(return_type) {}
+ _return_type(return_type),
+ _first_created(true),
+ _exec_place(nullptr) {}
~AggregateJavaUdaf() = default;
static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types,
@@ -292,54 +311,87 @@ public:
}
void create(AggregateDataPtr __restrict place) const override {
- new (place) Data(argument_types.size());
- Status status = Status::OK();
- RETURN_IF_STATUS_ERROR(status, data(place).init_udaf(_fn));
+ if (_first_created) {
+ new (place) Data(argument_types.size());
+ Status status = Status::OK();
+ RETURN_IF_STATUS_ERROR(status, this->data(place).init_udaf(_fn));
+ _first_created = false;
+ _exec_place = place;
+ }
+ }
+
+ // To avoid one JNI call per place, all UDAF states are destroyed at once, only when _exec_place is destroyed
+ void destroy(AggregateDataPtr __restrict place) const noexcept override {
+ if (place == _exec_place) {
+ this->data(_exec_place).destroy();
+ this->data(_exec_place).~Data();
+ }
}
String get_name() const override { return _fn.name.function_name; }
DataTypePtr get_return_type() const override { return _return_type; }
- // TODO: here calling add operator maybe only hava done one row, this performance may be poorly
- // so it's possible to maintain a hashtable in FE, the key is place address, value is the object
- // then we can calling add_bacth function and calculate the whole batch at once,
- // and avoid calling jni multiple times.
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
- this->data(place).add(columns, row_num, row_num + 1, argument_types);
+ LOG(WARNING) << " shouldn't going add function, there maybe some error about function "
+ << _fn.name.function_name;
+ }
+
+ void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
+ const IColumn** columns, Arena* arena) const override {
+ int64_t places_address[batch_size];
+ for (size_t i = 0; i < batch_size; ++i) {
+ places_address[i] = reinterpret_cast<int64_t>(places[i]);
+ }
+ this->data(_exec_place).add(places_address, false, columns, 0, batch_size, argument_types);
}
// TODO: Here we calling method by jni, And if we get a thrown from FE,
// But can't let user known the error, only return directly and output error to log file.
void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
- this->data(place).add(columns, 0, batch_size, argument_types);
+ int64_t places_address[1];
+ places_address[0] = reinterpret_cast<int64_t>(place);
+ this->data(_exec_place).add(places_address, true, columns, 0, batch_size, argument_types);
}
- void reset(AggregateDataPtr place) const override {}
+ // TODO: reset function should be implement also in struct data
+ void reset(AggregateDataPtr place) const override {
+ LOG(WARNING) << " shouldn't going reset function, there maybe some error about function "
+ << _fn.name.function_name;
+ }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
- this->data(place).merge(this->data(rhs));
+ this->data(_exec_place).merge(this->data(rhs), reinterpret_cast<int64_t>(place));
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
- this->data(const_cast<AggregateDataPtr&>(place)).write(buf);
+ this->data(const_cast<AggregateDataPtr&>(_exec_place))
+ .write(buf, reinterpret_cast<int64_t>(place));
}
+ // During the merge-finalize phase, deserialize and merge run first:
+ // create --- deserialize --- merge --- destroy is called for each row,
+ // so deserialize must placement-new the Data itself and read the buffer before merge.
+ // init_udaf is never run on these deserialized Data objects, so destroy() does not
+ // call ~Data on them; only the Data at _exec_place is torn down.
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ new (place) Data(argument_types.size());
this->data(place).read(buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
- this->data(place).get(to, _return_type);
+ this->data(_exec_place).get(to, _return_type, reinterpret_cast<int64_t>(place));
}
private:
TFunction _fn;
DataTypePtr _return_type;
+ mutable bool _first_created;
+ mutable AggregateDataPtr _exec_place;
};
} // namespace doris::vectorized
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index c3632a15ad..f944cfe2d1 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -57,21 +57,22 @@ public class UdafExecutor {
public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
public static final String UDAF_MERGE_FUNCTION = "merge";
public static final String UDAF_RESULT_FUNCTION = "getValue";
- private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
+ private static final Logger LOG = Logger.getLogger(UdafExecutor.class);
private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
private final long inputBufferPtrs;
private final long inputNullsPtrs;
private final long inputOffsetsPtrs;
+ private final long inputPlacesPtr;
private final long outputBufferPtr;
private final long outputNullPtr;
private final long outputOffsetsPtr;
private final long outputIntermediateStatePtr;
private Object udaf;
private HashMap<String, Method> allMethods;
+ private HashMap<Long, Object> stateObjMap;
private URLClassLoader classLoader;
private JavaUdfDataType[] argTypes;
private JavaUdfDataType retType;
- private Object stateObj;
/**
* Constructor to create an object.
@@ -91,17 +92,18 @@ public class UdafExecutor {
inputBufferPtrs = request.input_buffer_ptrs;
inputNullsPtrs = request.input_nulls_ptrs;
inputOffsetsPtrs = request.input_offsets_ptrs;
+ inputPlacesPtr = request.input_places_ptr;
outputBufferPtr = request.output_buffer_ptr;
outputNullPtr = request.output_null_ptr;
outputOffsetsPtr = request.output_offsets_ptr;
outputIntermediateStatePtr = request.output_intermediate_state_ptr;
allMethods = new HashMap<>();
+ stateObjMap = new HashMap<>();
String className = request.fn.aggregate_fn.symbol;
String jarFile = request.location;
Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
init(jarFile, className, funcRetType, parameterTypes);
- stateObj = create();
}
/**
@@ -110,7 +112,6 @@ public class UdafExecutor {
public void close() {
if (classLoader != null) {
try {
- destroy();
classLoader.close();
} catch (Exception e) {
// Log and ignore.
@@ -132,24 +133,23 @@ public class UdafExecutor {
/**
* invoke add function, add row in loop [rowStart, rowEnd).
*/
- public void add(long rowStart, long rowEnd) throws UdfRuntimeException {
+ public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException {
try {
- Object[] inputArgs = new Object[argTypes.length + 1];
- inputArgs[0] = stateObj;
- for (long row = rowStart; row < rowEnd; ++row) {
- Object[] inputObjects = allocateInputObjects(row);
- for (int i = 0; i < argTypes.length; ++i) {
- if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1
- || UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row)
- == 0) {
+ long idx = rowStart;
+ do {
+ Long curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
+ Object[] inputArgs = new Object[argTypes.length + 1];
+ stateObjMap.putIfAbsent(curPlace, createAggState());
+ inputArgs[0] = stateObjMap.get(curPlace);
+ do {
+ Object[] inputObjects = allocateInputObjects(idx);
+ for (int i = 0; i < argTypes.length; ++i) {
inputArgs[i + 1] = inputObjects[i];
- } else {
- inputArgs[i + 1] = null;
}
- }
- allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
- }
+ allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
+ idx++;
+ } while (isSinglePlace && idx < rowEnd);
+ } while (idx < rowEnd);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to add: ", e);
}
@@ -158,7 +158,7 @@ public class UdafExecutor {
/**
* invoke user create function to get obj.
*/
- public Object create() throws UdfRuntimeException {
+ public Object createAggState() throws UdfRuntimeException {
try {
return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null);
} catch (Exception e) {
@@ -167,11 +167,14 @@ public class UdafExecutor {
}
/**
- * invoke destroy before colse.
+ * invoke the UDAF destroy function on every stored aggregate state at once, then clear them all.
*/
public void destroy() throws UdfRuntimeException {
try {
- allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, stateObj);
+ for (Object obj : stateObjMap.values()) {
+ allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, obj);
+ }
+ stateObjMap.clear();
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to destroy: ", e);
}
@@ -180,11 +183,11 @@ public class UdafExecutor {
/**
* invoke serialize function and return byte[] to backends.
*/
- public byte[] serialize() throws UdfRuntimeException {
+ public byte[] serialize(long place) throws UdfRuntimeException {
try {
Object[] args = new Object[2];
ByteArrayOutputStream baos = new ByteArrayOutputStream();
- args[0] = stateObj;
+ args[0] = stateObjMap.get((Long) place);
args[1] = new DataOutputStream(baos);
allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args);
return baos.toByteArray();
@@ -197,15 +200,17 @@ public class UdafExecutor {
* invoke merge function and it's have done deserialze.
* here call deserialize first, and call merge.
*/
- public void merge(byte[] data) throws UdfRuntimeException {
+ public void merge(long place, byte[] data) throws UdfRuntimeException {
try {
Object[] args = new Object[2];
ByteArrayInputStream bins = new ByteArrayInputStream(data);
- args[0] = create();
+ args[0] = createAggState();
args[1] = new DataInputStream(bins);
allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args);
args[1] = args[0];
- args[0] = stateObj;
+ Long curPlace = place;
+ stateObjMap.putIfAbsent(curPlace, createAggState());
+ args[0] = stateObjMap.get(curPlace);
allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to merge: ", e);
@@ -215,9 +220,10 @@ public class UdafExecutor {
/**
* invoke getValue to return finally result.
*/
- public boolean getValue(long row) throws UdfRuntimeException {
+ public boolean getValue(long row, long place) throws UdfRuntimeException {
try {
- return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObj), row);
+ return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObjMap.get((Long) place)),
+ row);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to result", e);
}
@@ -353,6 +359,12 @@ public class UdafExecutor {
Object[] inputObjects = new Object[argTypes.length];
for (int i = 0; i < argTypes.length; ++i) {
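+ // if this argument is null at the current row, pass null and skip conversion. NOTE(review): the byte read below does not dereference the null-map pointer the way the -1 check does — verify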
+ if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
+ && UdfUtils.UNSAFE.getByte(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i) + row) == 1) {
+ inputObjects[i] = null;
+ continue;
+ }
switch (argTypes[i]) {
case BOOLEAN:
inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 3913f30376..ecf3268de0 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -385,6 +385,10 @@ struct TJavaUdfExecutorCtorParams {
9: optional i64 output_intermediate_state_ptr
10: optional i64 batch_size_ptr
+
+ // address of the aggregate place (or array of places) passed to the Java UDAF executor,
+ // so that a single JNI call can process a whole batch of rows in the Java UDAF
+ 11: optional i64 input_places_ptr
}
// Contains all interesting statistics from a single 'memory pool' in the JVM.
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org