You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/04/21 09:39:07 UTC
[incubator-doris] branch master updated: [UDF] support RPC udaf part 1: support create RPC udaf in fe (#8510)
This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new ae680b4248 [UDF] support RPC udaf part 1: support create RPC udaf in fe (#8510)
ae680b4248 is described below
commit ae680b4248826aa2266811240fc0d526fa41ac16
Author: Zhengguo Yang <ya...@gmail.com>
AuthorDate: Thu Apr 21 17:38:58 2022 +0800
[UDF] support RPC udaf part 1: support create RPC udaf in fe (#8510)
---
be/src/exec/partitioned_aggregation_node.cc | 8 +-
be/src/exprs/CMakeLists.txt | 3 +-
be/src/exprs/{agg_fn.cc => agg_fn.cpp} | 114 +++--
be/src/exprs/agg_fn.h | 69 +--
be/src/exprs/expr_context.h | 2 +-
be/src/exprs/new_agg_fn_evaluator.cc | 23 +-
be/src/exprs/new_agg_fn_evaluator_ir.cc | 31 --
.../function_rpc.cpp => exprs/rpc_fn.cpp} | 402 +++++++++++----
be/src/exprs/rpc_fn.h | 136 ++++++
be/src/exprs/rpc_fn_call.cpp | 273 +----------
be/src/exprs/rpc_fn_call.h | 20 +-
be/src/udf/udf.cpp | 14 -
be/src/udf/udf_internal.h | 4 -
be/src/vec/core/block.cpp | 12 +-
be/src/vec/core/block.h | 2 +-
be/src/vec/exprs/vectorized_fn_call.cpp | 4 +-
be/src/vec/functions/function_rpc.cpp | 540 +--------------------
be/src/vec/functions/function_rpc.h | 25 +-
contrib/udf/CMakeLists.txt | 16 -
.../udf/native-user-defined-function.md | 2 -
.../Data Definition/create-function.md | 11 +-
.../udf/native-user-defined-function.md | 2 -
.../Data Definition/create-function.md | 56 ++-
.../apache/doris/analysis/CreateFunctionStmt.java | 86 ++--
.../apache/doris/catalog/AggregateFunction.java | 4 +
.../java/org/apache/doris/common/util/URI.java | 5 +
gensrc/proto/function_service.proto | 3 +-
gensrc/proto/types.proto | 12 +-
.../cpp_function_service_demo.cpp | 36 +-
.../org/apache/doris/udf/FunctionServiceImpl.java | 2 +-
.../remote-udf-python-demo/function_server_demo.py | 2 +-
31 files changed, 755 insertions(+), 1164 deletions(-)
diff --git a/be/src/exec/partitioned_aggregation_node.cc b/be/src/exec/partitioned_aggregation_node.cc
index 80b8d49764..a99b0da23b 100644
--- a/be/src/exec/partitioned_aggregation_node.cc
+++ b/be/src/exec/partitioned_aggregation_node.cc
@@ -174,10 +174,10 @@ Status PartitionedAggregationNode::init(const TPlanNode& tnode, RuntimeState* st
SlotDescriptor* intermediate_slot_desc = intermediate_tuple_desc_->slots()[j];
SlotDescriptor* output_slot_desc = output_tuple_desc_->slots()[j];
AggFn* agg_fn;
- RETURN_IF_ERROR(AggFn::Create(tnode.agg_node.aggregate_functions[i], row_desc,
+ RETURN_IF_ERROR(AggFn::create(tnode.agg_node.aggregate_functions[i], row_desc,
*intermediate_slot_desc, *output_slot_desc, state, &agg_fn));
agg_fns_.push_back(agg_fn);
- needs_serialize_ |= agg_fn->SupportsSerialize();
+ needs_serialize_ |= agg_fn->supports_serialize();
}
return Status::OK();
}
@@ -719,7 +719,7 @@ Status PartitionedAggregationNode::close(RuntimeState* state) {
}
Expr::close(grouping_exprs_);
Expr::close(build_exprs_);
- AggFn::Close(agg_fns_);
+ AggFn::close(agg_fns_);
return ExecNode::close(state);
}
@@ -1105,7 +1105,7 @@ void PartitionedAggregationNode::DebugString(int indentation_level, stringstream
<< "intermediate_tuple_id=" << intermediate_tuple_id_
<< " output_tuple_id=" << output_tuple_id_ << " needs_finalize=" << needs_finalize_
<< " grouping_exprs=" << Expr::debug_string(grouping_exprs_)
- << " agg_exprs=" << AggFn::DebugString(agg_fns_);
+ << " agg_exprs=" << AggFn::debug_string(agg_fns_);
ExecNode::debug_string(indentation_level, out);
*out << ")";
}
diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt
index 7ac1af85df..ff0a8037bb 100644
--- a/be/src/exprs/CMakeLists.txt
+++ b/be/src/exprs/CMakeLists.txt
@@ -52,6 +52,7 @@ add_library(Exprs
math_functions.cpp
null_literal.cpp
scalar_fn_call.cpp
+ rpc_fn.cpp
rpc_fn_call.cpp
slot_ref.cpp
string_functions.cpp
@@ -64,7 +65,7 @@ add_library(Exprs
json_functions.cpp
operators.cpp
hll_hash_function.cpp
- agg_fn.cc
+ agg_fn.cpp
new_agg_fn_evaluator.cc
bitmap_function.cpp
hll_function.cpp
diff --git a/be/src/exprs/agg_fn.cc b/be/src/exprs/agg_fn.cpp
similarity index 58%
rename from be/src/exprs/agg_fn.cc
rename to be/src/exprs/agg_fn.cpp
index ca6d41f967..a04ef1ba86 100644
--- a/be/src/exprs/agg_fn.cc
+++ b/be/src/exprs/agg_fn.cpp
@@ -21,6 +21,7 @@
#include "exprs/agg_fn.h"
#include "exprs/anyval_util.h"
+#include "exprs/rpc_fn.h"
#include "runtime/descriptors.h"
#include "runtime/runtime_state.h"
#include "runtime/user_function_cache.h"
@@ -67,7 +68,7 @@ AggFn::AggFn(const TExprNode& tnode, const SlotDescriptor& intermediate_slot_des
}
}
-Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) {
+Status AggFn::init(const RowDescriptor& row_desc, RuntimeState* state) {
// TODO chenhao , calling expr's prepare in NewAggFnEvaluator create
// Initialize all children (i.e. input exprs to this aggregate expr).
//for (Expr* input_expr : children()) {
@@ -89,45 +90,74 @@ Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) {
ss << "Function " << _fn.name.function_name << " is not implemented.";
return Status::InternalError(ss.str());
}
-
- RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.init_fn_symbol, _fn.hdfs_location, _fn.checksum, &init_fn_,
- &_cache_entry));
- RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.update_fn_symbol, _fn.hdfs_location, _fn.checksum, &update_fn_,
- &_cache_entry));
-
- // Merge() is not defined for purely analytic function.
- if (!aggregate_fn.is_analytic_only_fn) {
+ if (_fn.binary_type == TFunctionBinaryType::NATIVE ||
+ _fn.binary_type == TFunctionBinaryType::BUILTIN ||
+ _fn.binary_type == TFunctionBinaryType::HIVE) {
RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.merge_fn_symbol, _fn.hdfs_location, _fn.checksum, &merge_fn_,
+ _fn.id, aggregate_fn.init_fn_symbol, _fn.hdfs_location, _fn.checksum, &_init_fn,
&_cache_entry));
- }
- // Serialize(), GetValue(), Remove() and Finalize() are optional
- if (!aggregate_fn.serialize_fn_symbol.empty()) {
RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.serialize_fn_symbol, _fn.hdfs_location, _fn.checksum,
- &serialize_fn_, &_cache_entry));
- }
- if (!aggregate_fn.get_value_fn_symbol.empty()) {
- RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.get_value_fn_symbol, _fn.hdfs_location, _fn.checksum,
- &get_value_fn_, &_cache_entry));
- }
- if (!aggregate_fn.remove_fn_symbol.empty()) {
- RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, aggregate_fn.remove_fn_symbol, _fn.hdfs_location, _fn.checksum, &remove_fn_,
+ _fn.id, aggregate_fn.update_fn_symbol, _fn.hdfs_location, _fn.checksum, &_update_fn,
&_cache_entry));
- }
- if (!aggregate_fn.finalize_fn_symbol.empty()) {
- RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
- _fn.id, _fn.aggregate_fn.finalize_fn_symbol, _fn.hdfs_location, _fn.checksum,
- &finalize_fn_, &_cache_entry));
+
+ // Merge() is not defined for purely analytic function.
+ if (!aggregate_fn.is_analytic_only_fn) {
+ RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
+ _fn.id, aggregate_fn.merge_fn_symbol, _fn.hdfs_location, _fn.checksum,
+ &_merge_fn, &_cache_entry));
+ }
+ // Serialize(), GetValue(), Remove() and Finalize() are optional
+ if (!aggregate_fn.serialize_fn_symbol.empty()) {
+ RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
+ _fn.id, aggregate_fn.serialize_fn_symbol, _fn.hdfs_location, _fn.checksum,
+ &_serialize_fn, &_cache_entry));
+ }
+ if (!aggregate_fn.get_value_fn_symbol.empty()) {
+ RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
+ _fn.id, aggregate_fn.get_value_fn_symbol, _fn.hdfs_location, _fn.checksum,
+ &_get_value_fn, &_cache_entry));
+ }
+ if (!aggregate_fn.remove_fn_symbol.empty()) {
+ RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
+ _fn.id, aggregate_fn.remove_fn_symbol, _fn.hdfs_location, _fn.checksum,
+ &_remove_fn, &_cache_entry));
+ }
+ if (!aggregate_fn.finalize_fn_symbol.empty()) {
+ RETURN_IF_ERROR(UserFunctionCache::instance()->get_function_ptr(
+ _fn.id, _fn.aggregate_fn.finalize_fn_symbol, _fn.hdfs_location, _fn.checksum,
+ &_finalize_fn, &_cache_entry));
+ }
+ } else if (_fn.binary_type == TFunctionBinaryType::RPC) {
+ _rpc_init = std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::INIT, true);
+ _rpc_update = std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::UPDATE, true);
+
+ // Merge() is not defined for purely analytic function.
+ if (!aggregate_fn.is_analytic_only_fn) {
+ _rpc_merge = std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::MERGE, true);
+ }
+ // Serialize(), GetValue(), Remove() and Finalize() are optional
+ if (!aggregate_fn.serialize_fn_symbol.empty()) {
+ _rpc_serialize =
+ std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::SERIALIZE, true);
+ }
+ if (!aggregate_fn.get_value_fn_symbol.empty()) {
+ _rpc_get_value =
+ std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::GET_VALUE, true);
+ }
+ if (!aggregate_fn.remove_fn_symbol.empty()) {
+ _rpc_remove = std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::REMOVE, true);
+ }
+ if (!aggregate_fn.finalize_fn_symbol.empty()) {
+ _rpc_finalize =
+ std::make_unique<RPCFn>(state, _fn, RPCFn::AggregationStep::FINALIZE, true);
+ }
+ } else {
+ return Status::NotSupported(fmt::format("Not supported BinaryType: {}", _fn.binary_type));
}
return Status::OK();
}
-Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc,
+Status AggFn::create(const TExpr& texpr, const RowDescriptor& row_desc,
const SlotDescriptor& intermediate_slot_desc,
const SlotDescriptor& output_slot_desc, RuntimeState* state, AggFn** agg_fn) {
*agg_fn = nullptr;
@@ -140,9 +170,9 @@ Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc,
}
AggFn* new_agg_fn = pool->add(new AggFn(texpr_node, intermediate_slot_desc, output_slot_desc));
RETURN_IF_ERROR(Expr::create_tree(texpr, pool, new_agg_fn));
- Status status = new_agg_fn->Init(row_desc, state);
+ Status status = new_agg_fn->init(row_desc, state);
if (UNLIKELY(!status.ok())) {
- new_agg_fn->Close();
+ new_agg_fn->close();
return status;
}
for (Expr* input_expr : new_agg_fn->children()) {
@@ -153,24 +183,24 @@ Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc,
return Status::OK();
}
-FunctionContext::TypeDesc AggFn::GetIntermediateTypeDesc() const {
+FunctionContext::TypeDesc AggFn::get_intermediate_type_desc() const {
return AnyValUtil::column_type_to_type_desc(intermediate_slot_desc_.type());
}
-FunctionContext::TypeDesc AggFn::GetOutputTypeDesc() const {
+FunctionContext::TypeDesc AggFn::get_output_type_desc() const {
return AnyValUtil::column_type_to_type_desc(output_slot_desc_.type());
}
-void AggFn::Close() {
+void AggFn::close() {
// This also closes all the input expressions.
Expr::close();
}
-void AggFn::Close(const std::vector<AggFn*>& exprs) {
- for (AggFn* expr : exprs) expr->Close();
+void AggFn::close(const std::vector<AggFn*>& exprs) {
+ for (AggFn* expr : exprs) expr->close();
}
-std::string AggFn::DebugString() const {
+std::string AggFn::debug_string() const {
std::stringstream out;
out << "AggFn(op=" << agg_op_;
for (Expr* input_expr : children()) {
@@ -180,11 +210,11 @@ std::string AggFn::DebugString() const {
return out.str();
}
-std::string AggFn::DebugString(const std::vector<AggFn*>& agg_fns) {
+std::string AggFn::debug_string(const std::vector<AggFn*>& agg_fns) {
std::stringstream out;
out << "[";
for (int i = 0; i < agg_fns.size(); ++i) {
- out << (i == 0 ? "" : " ") << agg_fns[i]->DebugString();
+ out << (i == 0 ? "" : " ") << agg_fns[i]->debug_string();
}
out << "]";
return out.str();
diff --git a/be/src/exprs/agg_fn.h b/be/src/exprs/agg_fn.h
index 87083c9484..28342f974a 100644
--- a/be/src/exprs/agg_fn.h
+++ b/be/src/exprs/agg_fn.h
@@ -35,6 +35,7 @@ class RuntimeState;
class Tuple;
class TupleRow;
class TExprNode;
+class RPCFn;
/// --- AggFn overview
///
@@ -52,33 +53,33 @@ class TExprNode;
/// AggFnEvaluator is the interface for evaluating aggregate functions against input
/// tuple rows. It invokes the following functions at different phases of the aggregation:
///
-/// init_fn_ : An initialization function that initializes the aggregate value.
+/// _init_fn : An initialization function that initializes the aggregate value.
///
-/// update_fn_ : An update function that processes the arguments for each row in the
+/// _update_fn : An update function that processes the arguments for each row in the
/// query result set and accumulates an intermediate result. For example,
/// this function might increment a counter, append to a string buffer or
/// add the input to a cumulative sum.
///
-/// merge_fn_ : A merge function that combines multiple intermediate results into a
+/// _merge_fn : A merge function that combines multiple intermediate results into a
/// single value.
///
-/// serialize_fn_: A serialization function that flattens any intermediate values
+/// _serialize_fn: A serialization function that flattens any intermediate values
/// containing pointers, and frees any memory allocated during the init,
/// update and merge phases.
///
-/// finalize_fn_ : A finalize function that either passes through the combined result
+/// _finalize_fn : A finalize function that either passes through the combined result
/// unchanged, or does one final transformation. Also frees the resources
/// allocated during init, update and merge phases.
///
-/// get_value_fn_: Used by AnalyticEval node to obtain the current intermediate value.
+/// _get_value_fn: Used by AnalyticEval node to obtain the current intermediate value.
///
-/// remove_fn_ : Used by AnalyticEval node to undo the update to the intermediate value
+/// _remove_fn : Used by AnalyticEval node to undo the update to the intermediate value
/// by an input row as it falls out of a sliding window.
///
class AggFn : public Expr {
public:
/// Override the base class' implementation.
- virtual bool IsAggFn() const { return true; }
+ virtual bool is_agg_fn() const { return true; }
/// Enum for some built-in aggregation ops.
enum AggregationOp {
@@ -99,7 +100,7 @@ public:
/// the row descriptor of the input tuple row; 'intermediate_slot_desc' is the slot
/// descriptor of the intermediate value; 'output_slot_desc' is the slot descriptor
/// of the output value. On failure, returns error status and sets 'agg_fn' to nullptr.
- static Status Create(const TExpr& texpr, const RowDescriptor& row_desc,
+ static Status create(const TExpr& texpr, const RowDescriptor& row_desc,
const SlotDescriptor& intermediate_slot_desc,
const SlotDescriptor& output_slot_desc, RuntimeState* state,
AggFn** agg_fn) WARN_UNUSED_RESULT;
@@ -115,25 +116,25 @@ public:
const SlotDescriptor& intermediate_slot_desc() const { return intermediate_slot_desc_; }
// Output type is the same as Expr::type().
const SlotDescriptor& output_slot_desc() const { return output_slot_desc_; }
- void* remove_fn() const { return remove_fn_; }
- void* merge_or_update_fn() const { return is_merge_ ? merge_fn_ : update_fn_; }
- void* serialize_fn() const { return serialize_fn_; }
- void* get_value_fn() const { return get_value_fn_; }
- void* finalize_fn() const { return finalize_fn_; }
- bool SupportsRemove() const { return remove_fn_ != nullptr; }
- bool SupportsSerialize() const { return serialize_fn_ != nullptr; }
- FunctionContext::TypeDesc GetIntermediateTypeDesc() const;
- FunctionContext::TypeDesc GetOutputTypeDesc() const;
+ void* remove_fn() const { return _remove_fn; }
+ void* merge_or_update_fn() const { return is_merge_ ? _merge_fn : _update_fn; }
+ void* serialize_fn() const { return _serialize_fn; }
+ void* get_value_fn() const { return _get_value_fn; }
+ void* finalize_fn() const { return _finalize_fn; }
+ bool supports_remove() const { return _remove_fn != nullptr; }
+ bool supports_serialize() const { return _serialize_fn != nullptr; }
+ FunctionContext::TypeDesc get_intermediate_type_desc() const;
+ FunctionContext::TypeDesc get_output_type_desc() const;
const std::vector<FunctionContext::TypeDesc>& arg_type_descs() const { return arg_type_descs_; }
/// Releases all cache entries to libCache for all nodes in the expr tree.
- virtual void Close();
- static void Close(const std::vector<AggFn*>& exprs);
+ virtual void close();
+ static void close(const std::vector<AggFn*>& exprs);
Expr* clone(ObjectPool* pool) const { return nullptr; }
- virtual std::string DebugString() const;
- static std::string DebugString(const std::vector<AggFn*>& exprs);
+ virtual std::string debug_string() const;
+ static std::string debug_string(const std::vector<AggFn*>& exprs);
const int get_vararg_start_idx() const { return _vararg_start_idx; }
@@ -158,22 +159,30 @@ private:
AggregationOp agg_op_;
/// Function pointers for the different phases of the aggregate function.
- void* init_fn_ = nullptr;
- void* update_fn_ = nullptr;
- void* remove_fn_ = nullptr;
- void* merge_fn_ = nullptr;
- void* serialize_fn_ = nullptr;
- void* get_value_fn_ = nullptr;
- void* finalize_fn_ = nullptr;
+ void* _init_fn = nullptr;
+ void* _update_fn = nullptr;
+ void* _remove_fn = nullptr;
+ void* _merge_fn = nullptr;
+ void* _serialize_fn = nullptr;
+ void* _get_value_fn = nullptr;
+ void* _finalize_fn = nullptr;
int _vararg_start_idx;
+ std::unique_ptr<RPCFn> _rpc_init;
+ std::unique_ptr<RPCFn> _rpc_update;
+ std::unique_ptr<RPCFn> _rpc_remove;
+ std::unique_ptr<RPCFn> _rpc_merge;
+ std::unique_ptr<RPCFn> _rpc_serialize;
+ std::unique_ptr<RPCFn> _rpc_get_value;
+ std::unique_ptr<RPCFn> _rpc_finalize;
+
AggFn(const TExprNode& node, const SlotDescriptor& intermediate_slot_desc,
const SlotDescriptor& output_slot_desc);
/// Initializes the AggFn and its input expressions. May load the UDAF from LibCache
/// if necessary.
- virtual Status Init(const RowDescriptor& desc, RuntimeState* state) WARN_UNUSED_RESULT;
+ virtual Status init(const RowDescriptor& desc, RuntimeState* state) WARN_UNUSED_RESULT;
};
} // namespace doris
diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h
index 1b6edc6c31..3a79a1be54 100644
--- a/be/src/exprs/expr_context.h
+++ b/be/src/exprs/expr_context.h
@@ -156,7 +156,7 @@ public:
private:
friend class Expr;
friend class ScalarFnCall;
- friend class RPCFnCall;
+ friend class RPCFn;
friend class InPredicate;
friend class RuntimePredicateWrapper;
friend class BloomFilterPredicate;
diff --git a/be/src/exprs/new_agg_fn_evaluator.cc b/be/src/exprs/new_agg_fn_evaluator.cc
index a81205240b..0c09a78ed5 100644
--- a/be/src/exprs/new_agg_fn_evaluator.cc
+++ b/be/src/exprs/new_agg_fn_evaluator.cc
@@ -116,11 +116,10 @@ Status NewAggFnEvaluator::Create(const AggFn& agg_fn, RuntimeState* state, Objec
*result = nullptr;
// Create a new AggFn evaluator.
- NewAggFnEvaluator* agg_fn_eval =
- pool->add(new NewAggFnEvaluator(agg_fn, mem_pool, false));
+ NewAggFnEvaluator* agg_fn_eval = pool->add(new NewAggFnEvaluator(agg_fn, mem_pool, false));
agg_fn_eval->agg_fn_ctx_.reset(FunctionContextImpl::create_context(
- state, mem_pool, agg_fn.GetIntermediateTypeDesc(), agg_fn.GetOutputTypeDesc(),
+ state, mem_pool, agg_fn.get_intermediate_type_desc(), agg_fn.get_output_type_desc(),
agg_fn.arg_type_descs(), 0, false));
Status status;
@@ -284,7 +283,7 @@ void NewAggFnEvaluator::SetDstSlot(const AnyVal* src, const SlotDescriptor& dst_
// This function would be replaced in codegen.
void NewAggFnEvaluator::Init(Tuple* dst) {
DCHECK(opened_);
- DCHECK(agg_fn_.init_fn_ != nullptr);
+ DCHECK(agg_fn_._init_fn != nullptr);
for (ExprContext* input_eval : input_evals_) {
DCHECK(input_eval->opened());
}
@@ -301,7 +300,7 @@ void NewAggFnEvaluator::Init(Tuple* dst) {
sv->ptr = reinterpret_cast<uint8_t*>(slot);
sv->len = type.len;
}
- reinterpret_cast<InitFn>(agg_fn_.init_fn_)(agg_fn_ctx_.get(), staging_intermediate_val_);
+ reinterpret_cast<InitFn>(agg_fn_._init_fn)(agg_fn_ctx_.get(), staging_intermediate_val_);
SetDstSlot(staging_intermediate_val_, slot_desc, dst);
agg_fn_ctx_->impl()->set_num_updates(0);
agg_fn_ctx_->impl()->set_num_removes(0);
@@ -519,12 +518,12 @@ void NewAggFnEvaluator::Update(const TupleRow* row, Tuple* dst, void* fn) {
}
void NewAggFnEvaluator::Merge(Tuple* src, Tuple* dst) {
- DCHECK(agg_fn_.merge_fn_ != nullptr);
+ DCHECK(agg_fn_._merge_fn != nullptr);
const SlotDescriptor& slot_desc = intermediate_slot_desc();
SetAnyVal(slot_desc, dst, staging_intermediate_val_);
SetAnyVal(slot_desc, src, staging_merge_input_val_);
// The merge fn always takes one input argument.
- reinterpret_cast<UpdateFn1>(agg_fn_.merge_fn_)(agg_fn_ctx_.get(), *staging_merge_input_val_,
+ reinterpret_cast<UpdateFn1>(agg_fn_._merge_fn)(agg_fn_ctx_.get(), *staging_merge_input_val_,
staging_intermediate_val_);
SetDstSlot(staging_intermediate_val_, slot_desc, dst);
}
@@ -650,13 +649,3 @@ void NewAggFnEvaluator::ShallowClone(ObjectPool* pool, MemPool* mem_pool,
cloned_evals->push_back(cloned_eval);
}
}
-
-//
-//void NewAggFnEvaluator::FreeLocalAllocations() {
-// ExprContext::FreeLocalAllocations(input_evals_);
-// agg_fn_ctx_->impl()->FreeLocalAllocations();
-//}
-
-//void NewAggFnEvaluator::FreeLocalAllocations(const vector<NewAggFnEvaluator*>& evals) {
-// for (NewAggFnEvaluator* eval : evals) eval->FreeLocalAllocations();
-//}
diff --git a/be/src/exprs/new_agg_fn_evaluator_ir.cc b/be/src/exprs/new_agg_fn_evaluator_ir.cc
deleted file mode 100644
index 21014e7f40..0000000000
--- a/be/src/exprs/new_agg_fn_evaluator_ir.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-// This file is copied from
-// https://github.com/apache/impala/blob/branch-2.10.0/be/src/exprs/agg-fn-evaluator-ir.cc
-// and modified by Doris
-
-#include "exprs/new_agg_fn_evaluator.h"
-
-using namespace doris;
-
-FunctionContext* NewAggFnEvaluator::agg_fn_ctx() const {
- return agg_fn_ctx_.get();
-}
-
-ExprContext* const* NewAggFnEvaluator::input_evals() const {
- return input_evals_.data();
-}
diff --git a/be/src/vec/functions/function_rpc.cpp b/be/src/exprs/rpc_fn.cpp
similarity index 55%
copy from be/src/vec/functions/function_rpc.cpp
copy to be/src/exprs/rpc_fn.cpp
index 9b2e11d08a..63ddc93dfc 100644
--- a/be/src/vec/functions/function_rpc.cpp
+++ b/be/src/exprs/rpc_fn.cpp
@@ -15,19 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-#include "vec/functions/function_rpc.h"
+#include "exprs/rpc_fn.h"
#include <fmt/format.h>
-#include <memory>
-
-#include "gen_cpp/function_service.pb.h"
-#include "runtime/exec_env.h"
+#include "runtime/fragment_mgr.h"
#include "runtime/user_function_cache.h"
#include "service/brpc.h"
#include "util/brpc_client_cache.h"
+#include "vec/columns/column.h"
#include "vec/columns/column_vector.h"
#include "vec/core/block.h"
+#include "vec/core/column_numbers.h"
#include "vec/data_types/data_type_bitmap.h"
#include "vec/data_types/data_type_date.h"
#include "vec/data_types/data_type_date_time.h"
@@ -36,65 +35,270 @@
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
-namespace doris::vectorized {
-RPCFnCall::RPCFnCall(const std::string& symbol, const std::string& server,
- const DataTypes& argument_types, const DataTypePtr& return_type)
- : _symbol(symbol),
- _server(server),
- _name(fmt::format("{}/{}", server, symbol)),
- _argument_types(argument_types),
- _return_type(return_type) {}
-Status RPCFnCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
- _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server);
+namespace doris {
+
+// All RPC functions are served from fn.hdfs_location; set _server_addr before get_client().
+RPCFn::RPCFn(RuntimeState* state, const TFunction& fn, int fn_ctx_id, bool is_agg)
+ : _state(state), _fn(fn), _fn_ctx_id(fn_ctx_id), _is_agg(is_agg) {
+ _server_addr = _fn.hdfs_location;
+ _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server_addr);
+ if (!_is_agg) {
+ _function_name = _fn.scalar_fn.symbol;
+ _signature = fmt::format("{}: [{}/{}]", _fn.name.function_name, _server_addr, _function_name);
+ }
+}
+
+RPCFn::RPCFn(const TFunction& fn, bool is_agg) : RPCFn(nullptr, fn, -1, is_agg) {}
+
+// Binds one aggregation step to its RPC symbol. `state` must be forwarded: cancel()
+// dereferences _state. NOTE(review): fn_ctx_id is -1 here; confirm the agg path never
+// calls context->fn_context(_fn_ctx_id).
+RPCFn::RPCFn(RuntimeState* state, const TFunction& fn, AggregationStep step, bool is_agg)
+ : RPCFn(state, fn, -1, is_agg) {
+ _step = step;
+ DCHECK(is_agg) << "Only used for agg fns";
+ switch (_step) {
+ case INIT: {
+ _function_name = _fn.aggregate_fn.init_fn_symbol;
+ break;
+ }
+ case UPDATE: {
+ _function_name = _fn.aggregate_fn.update_fn_symbol;
+ break;
+ }
+ case MERGE: {
+ _function_name = _fn.aggregate_fn.merge_fn_symbol;
+ break;
+ }
+ case SERIALIZE: {
+ _function_name = _fn.aggregate_fn.serialize_fn_symbol;
+ break;
+ }
+ case GET_VALUE: {
+ _function_name = _fn.aggregate_fn.get_value_fn_symbol;
+ break;
+ }
+ case FINALIZE: {
+ _function_name = _fn.aggregate_fn.finalize_fn_symbol;
+ break;
+ }
+ case REMOVE: {
+ _function_name = _fn.aggregate_fn.remove_fn_symbol;
+ break;
+ }
+
+ default:
+ CHECK(false) << "invalid AggregationStep: " << _step;
+ break;
+ }
+ // _server_addr was already set from fn.hdfs_location by the delegated constructor.
+ _signature = fmt::format("{}: [{}/{}]", _fn.name.function_name, _server_addr, _function_name);
+}
+
+Status RPCFn::call_internal(ExprContext* context, TupleRow* row, PFunctionCallResponse* response,
+ const std::vector<Expr*>& exprs) {
+ FunctionContext* fn_ctx = context->fn_context(_fn_ctx_id);
+ PFunctionCallRequest request;
+ request.set_function_name(_function_name);
+ for (int i = 0; i < exprs.size(); ++i) {
+ PValues* arg = request.add_args();
+ void* src_slot = context->get_value(exprs[i], row);
+ PGenericType* ptype = arg->mutable_type();
+ if (src_slot == nullptr) {
+ arg->set_has_null(true);
+ arg->add_null_map(true);
+ } else {
+ arg->set_has_null(false);
+ }
+ switch (exprs[i]->type().type) {
+ case TYPE_BOOLEAN: {
+ ptype->set_id(PGenericType::BOOLEAN);
+ arg->add_bool_value(*(bool*)src_slot);
+ break;
+ }
+ case TYPE_TINYINT: {
+ ptype->set_id(PGenericType::INT8);
+ arg->add_int32_value(*(int8_t*)src_slot);
+ break;
+ }
+ case TYPE_SMALLINT: {
+ ptype->set_id(PGenericType::INT16);
+ arg->add_int32_value(*(int16_t*)src_slot);
+ break;
+ }
+ case TYPE_INT: {
+ ptype->set_id(PGenericType::INT32);
+ arg->add_int32_value(*(int*)src_slot);
+ break;
+ }
+ case TYPE_BIGINT: {
+ ptype->set_id(PGenericType::INT64);
+ arg->add_int64_value(*(int64_t*)src_slot);
+ break;
+ }
+ case TYPE_LARGEINT: {
+ ptype->set_id(PGenericType::INT128);
+ char buffer[sizeof(__int128)];
+ memcpy(buffer, src_slot, sizeof(__int128));
+ arg->add_bytes_value(buffer, sizeof(__int128));
+ break;
+ }
+ case TYPE_DOUBLE: {
+ ptype->set_id(PGenericType::DOUBLE);
+ arg->add_double_value(*(double*)src_slot);
+ break;
+ }
+ case TYPE_FLOAT: {
+ ptype->set_id(PGenericType::FLOAT);
+ arg->add_float_value(*(float*)src_slot);
+ break;
+ }
+ case TYPE_VARCHAR:
+ case TYPE_STRING:
+ case TYPE_CHAR: {
+ ptype->set_id(PGenericType::STRING);
+ StringValue value = *reinterpret_cast<StringValue*>(src_slot);
+ arg->add_string_value(value.ptr, value.len);
+ break;
+ }
+ case TYPE_HLL: {
+ ptype->set_id(PGenericType::HLL);
+ StringValue value = *reinterpret_cast<StringValue*>(src_slot);
+ arg->add_string_value(value.ptr, value.len);
+ break;
+ }
+ case TYPE_OBJECT: {
+ ptype->set_id(PGenericType::BITMAP);
+ StringValue value = *reinterpret_cast<StringValue*>(src_slot);
+ arg->add_string_value(value.ptr, value.len);
+ break;
+ }
+ case TYPE_DECIMALV2: {
+ ptype->set_id(PGenericType::DECIMAL128);
+ ptype->mutable_decimal_type()->set_precision(exprs[i]->type().precision);
+ ptype->mutable_decimal_type()->set_scale(exprs[i]->type().scale);
+ char buffer[sizeof(__int128)];
+ memcpy(buffer, src_slot, sizeof(__int128));
+ arg->add_bytes_value(buffer, sizeof(__int128));
+ break;
+ }
+ case TYPE_DATE: {
+ ptype->set_id(PGenericType::DATE);
+ const auto* time_val = (const DateTimeValue*)(src_slot);
+ PDateTime* date_time = arg->add_datetime_value();
+ date_time->set_day(time_val->day());
+ date_time->set_month(time_val->month());
+ date_time->set_year(time_val->year());
+ break;
+ }
+ case TYPE_DATETIME: {
+ ptype->set_id(PGenericType::DATETIME);
+ const auto* time_val = (const DateTimeValue*)(src_slot);
+ PDateTime* date_time = arg->add_datetime_value();
+ date_time->set_day(time_val->day());
+ date_time->set_month(time_val->month());
+ date_time->set_year(time_val->year());
+ date_time->set_hour(time_val->hour());
+ date_time->set_minute(time_val->minute());
+ date_time->set_second(time_val->second());
+ date_time->set_microsecond(time_val->microsecond());
+ break;
+ }
+ case TYPE_TIME: {
+ ptype->set_id(PGenericType::DATETIME);
+ const auto* time_val = (const DateTimeValue*)(src_slot);
+ PDateTime* date_time = arg->add_datetime_value();
+ date_time->set_hour(time_val->hour());
+ date_time->set_minute(time_val->minute());
+ date_time->set_second(time_val->second());
+ date_time->set_microsecond(time_val->microsecond());
+ break;
+ }
+ default: {
+ std::string error_msg =
+ fmt::format("data time not supported: {}", exprs[i]->type().type);
+ fn_ctx->set_error(error_msg.c_str());
+ cancel(error_msg);
+ break;
+ }
+ }
+ }
- if (_client == nullptr) {
- return Status::InternalError("rpc env init error");
+ brpc::Controller cntl;
+ _client->fn_call(&cntl, &request, response, nullptr);
+ if (cntl.Failed()) {
+ std::string error_msg =
+ fmt::format("call rpc function {} failed: {}", _signature, cntl.ErrorText());
+ fn_ctx->set_error(error_msg.c_str());
+ cancel(error_msg);
+ return Status::InternalError(error_msg);
+ }
+ if (!response->has_status() || response->result_size() == 0) {
+ std::string error_msg =
+ fmt::format("call rpc function {} failed: status or result is not set: {}",
+ _signature, response->status().DebugString());
+ fn_ctx->set_error(error_msg.c_str());
+ cancel(error_msg);
+ return Status::InternalError(error_msg);
+ }
+ if (response->status().status_code() != 0) {
+ std::string error_msg = fmt::format("call rpc function {} failed: {}", _signature,
+ response->status().DebugString());
+ fn_ctx->set_error(error_msg.c_str());
+ cancel(error_msg);
+ return Status::InternalError(error_msg);
}
return Status::OK();
}
+void RPCFn::cancel(const std::string& msg) {
+ _state->exec_env()->fragment_mgr()->cancel(_state->fragment_instance_id(),
+ PPlanFragmentCancelReason::CALL_RPC_ERROR, msg);
+}
+
template <bool nullable>
-void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, PValues* arg,
+void convert_col_to_pvalue(const vectorized::ColumnPtr& column,
+ const vectorized::DataTypePtr& data_type, PValues* arg,
size_t row_count) {
PGenericType* ptype = arg->mutable_type();
switch (data_type->get_type_id()) {
- case TypeIndex::UInt8: {
+ case vectorized::TypeIndex::UInt8: {
ptype->set_id(PGenericType::UINT8);
auto* values = arg->mutable_bool_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt8>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnUInt8>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::UInt16: {
+ case vectorized::TypeIndex::UInt16: {
ptype->set_id(PGenericType::UINT16);
auto* values = arg->mutable_uint32_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt16>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnUInt16>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::UInt32: {
+ case vectorized::TypeIndex::UInt32: {
ptype->set_id(PGenericType::UINT32);
auto* values = arg->mutable_uint32_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt32>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnUInt32>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::UInt64: {
+ case vectorized::TypeIndex::UInt64: {
ptype->set_id(PGenericType::UINT64);
auto* values = arg->mutable_uint64_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt64>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnUInt64>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::UInt128: {
+ case vectorized::TypeIndex::UInt128: {
ptype->set_id(PGenericType::UINT128);
arg->mutable_bytes_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
@@ -112,43 +316,43 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::Int8: {
+ case vectorized::TypeIndex::Int8: {
ptype->set_id(PGenericType::INT8);
auto* values = arg->mutable_int32_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt8>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnInt8>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Int16: {
+ case vectorized::TypeIndex::Int16: {
ptype->set_id(PGenericType::INT16);
auto* values = arg->mutable_int32_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt16>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnInt16>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Int32: {
+ case vectorized::TypeIndex::Int32: {
ptype->set_id(PGenericType::INT32);
auto* values = arg->mutable_int32_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt32>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnInt32>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Int64: {
+ case vectorized::TypeIndex::Int64: {
ptype->set_id(PGenericType::INT64);
auto* values = arg->mutable_int64_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt64>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnInt64>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Int128: {
+ case vectorized::TypeIndex::Int128: {
ptype->set_id(PGenericType::INT128);
arg->mutable_bytes_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
@@ -166,28 +370,29 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::Float32: {
+ case vectorized::TypeIndex::Float32: {
ptype->set_id(PGenericType::FLOAT);
auto* values = arg->mutable_float_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnFloat32>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnFloat32>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Float64: {
+ case vectorized::TypeIndex::Float64: {
ptype->set_id(PGenericType::DOUBLE);
auto* values = arg->mutable_double_value();
values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnFloat64>(column);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnFloat64>(column);
auto& data = col->get_data();
values->Add(data.begin(), data.begin() + row_count);
break;
}
- case TypeIndex::Decimal128: {
+ case vectorized::TypeIndex::Decimal128: {
ptype->set_id(PGenericType::DECIMAL128);
- auto dec_type = std::reinterpret_pointer_cast<const DataTypeDecimal<Decimal128>>(data_type);
+ auto dec_type = std::reinterpret_pointer_cast<
+ const vectorized::DataTypeDecimal<vectorized::Decimal128>>(data_type);
ptype->mutable_decimal_type()->set_precision(dec_type->get_precision());
ptype->mutable_decimal_type()->set_scale(dec_type->get_scale());
arg->mutable_bytes_value()->Reserve(row_count);
@@ -206,7 +411,7 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::String: {
+ case vectorized::TypeIndex::String: {
ptype->set_id(PGenericType::STRING);
arg->mutable_bytes_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
@@ -224,23 +429,22 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::Date: {
+ case vectorized::TypeIndex::Date: {
ptype->set_id(PGenericType::DATE);
arg->mutable_datetime_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
PDateTime* date_time = arg->add_datetime_value();
if constexpr (nullable) {
if (!column->is_null_at(row_num)) {
- VecDateTimeValue v =
- binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
+ vectorized::VecDateTimeValue v =
+ vectorized::VecDateTimeValue::create_from_olap_date(column->get_int(row_num));
date_time->set_day(v.day());
date_time->set_month(v.month());
date_time->set_year(v.year());
}
} else {
- VecDateTimeValue v = binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
+ vectorized::VecDateTimeValue v =
+ vectorized::VecDateTimeValue::create_from_olap_date(column->get_int(row_num));
date_time->set_day(v.day());
date_time->set_month(v.month());
date_time->set_year(v.year());
@@ -248,16 +452,15 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::DateTime: {
+ case vectorized::TypeIndex::DateTime: {
ptype->set_id(PGenericType::DATETIME);
arg->mutable_datetime_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
PDateTime* date_time = arg->add_datetime_value();
if constexpr (nullable) {
if (!column->is_null_at(row_num)) {
- VecDateTimeValue v =
- binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
+ vectorized::VecDateTimeValue v =
+ vectorized::VecDateTimeValue::create_from_olap_datetime(column->get_int(row_num));
date_time->set_day(v.day());
date_time->set_month(v.month());
date_time->set_year(v.year());
@@ -266,8 +469,8 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
date_time->set_second(v.second());
}
} else {
- VecDateTimeValue v = binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
+ vectorized::VecDateTimeValue v =
+ vectorized::VecDateTimeValue::create_from_olap_datetime(column->get_int(row_num));
date_time->set_day(v.day());
date_time->set_month(v.month());
date_time->set_year(v.year());
@@ -278,7 +481,7 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::BitMap: {
+ case vectorized::TypeIndex::BitMap: {
ptype->set_id(PGenericType::BITMAP);
arg->mutable_bytes_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
@@ -296,7 +499,7 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
break;
}
- case TypeIndex::HLL: {
+ case vectorized::TypeIndex::HLL: {
ptype->set_id(PGenericType::HLL);
arg->mutable_bytes_value()->Reserve(row_count);
for (size_t row_num = 0; row_num < row_count; ++row_num) {
@@ -321,12 +524,14 @@ void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type
}
}
-void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type,
- const ColumnUInt8& null_col, PValues* arg, size_t row_count) {
+void convert_nullable_col_to_pvalue(const vectorized::ColumnPtr& column,
+ const vectorized::DataTypePtr& data_type,
+ const vectorized::ColumnUInt8& null_col, PValues* arg,
+ size_t row_count) {
if (column->has_null(row_count)) {
auto* null_map = arg->mutable_null_map();
null_map->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt8>(null_col);
+ const auto* col = vectorized::check_and_get_column<vectorized::ColumnUInt8>(null_col);
auto& data = col->get_data();
null_map->Add(data.begin(), data.begin() + row_count);
convert_col_to_pvalue<true>(column, data_type, arg, row_count);
@@ -335,18 +540,20 @@ void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr&
}
}
-void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t input_rows_count,
- PFunctionCallRequest* request) {
+void convert_block_to_proto(vectorized::Block& block, const vectorized::ColumnNumbers& arguments,
+ size_t input_rows_count, PFunctionCallRequest* request) {
size_t row_count = std::min(block.rows(), input_rows_count);
for (size_t col_idx : arguments) {
PValues* arg = request->add_args();
- ColumnWithTypeAndName& column = block.get_by_position(col_idx);
+ vectorized::ColumnWithTypeAndName& column = block.get_by_position(col_idx);
arg->set_has_null(column.column->has_null(row_count));
auto col = column.column->convert_to_full_column_if_const();
- if (auto* nullable = check_and_get_column<const ColumnNullable>(*col)) {
+ if (auto* nullable =
+ vectorized::check_and_get_column<const vectorized::ColumnNullable>(*col)) {
auto data_col = nullable->get_nested_column_ptr();
auto& null_col = nullable->get_null_map_column();
- auto data_type = std::reinterpret_pointer_cast<const DataTypeNullable>(column.type);
+ auto data_type =
+ std::reinterpret_pointer_cast<const vectorized::DataTypeNullable>(column.type);
convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(),
data_type->get_nested_type(), null_col, arg, row_count);
} else {
@@ -356,12 +563,12 @@ void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t
}
template <bool nullable>
-void convert_to_column(MutableColumnPtr& column, const PValues& result) {
+void convert_to_column(vectorized::MutableColumnPtr& column, const PValues& result) {
switch (result.type().id()) {
case PGenericType::UINT8: {
column->reserve(result.uint32_value_size());
column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt8*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnUInt8*>(column.get())->get_data();
for (int i = 0; i < result.uint32_value_size(); ++i) {
data[i] = result.uint32_value(i);
}
@@ -370,7 +577,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::UINT16: {
column->reserve(result.uint32_value_size());
column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt16*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnUInt16*>(column.get())->get_data();
for (int i = 0; i < result.uint32_value_size(); ++i) {
data[i] = result.uint32_value(i);
}
@@ -379,7 +586,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::UINT32: {
column->reserve(result.uint32_value_size());
column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt32*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnUInt32*>(column.get())->get_data();
for (int i = 0; i < result.uint32_value_size(); ++i) {
data[i] = result.uint32_value(i);
}
@@ -388,7 +595,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::UINT64: {
column->reserve(result.uint64_value_size());
column->resize(result.uint64_value_size());
- auto& data = reinterpret_cast<ColumnUInt64*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnUInt64*>(column.get())->get_data();
for (int i = 0; i < result.uint64_value_size(); ++i) {
data[i] = result.uint64_value(i);
}
@@ -397,7 +604,8 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::INT8: {
column->reserve(result.int32_value_size());
column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data();
+ // INT8 maps to ColumnInt8 (as in convert_col_to_pvalue); a ColumnInt16 view overruns it.
+ auto& data = reinterpret_cast<vectorized::ColumnInt8*>(column.get())->get_data();
for (int i = 0; i < result.int32_value_size(); ++i) {
data[i] = result.int32_value(i);
}
@@ -406,7 +613,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::INT16: {
column->reserve(result.int32_value_size());
column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnInt16*>(column.get())->get_data();
for (int i = 0; i < result.int32_value_size(); ++i) {
data[i] = result.int32_value(i);
}
@@ -415,7 +622,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::INT32: {
column->reserve(result.int32_value_size());
column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt32*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnInt32*>(column.get())->get_data();
for (int i = 0; i < result.int32_value_size(); ++i) {
data[i] = result.int32_value(i);
}
@@ -424,7 +631,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::INT64: {
column->reserve(result.int64_value_size());
column->resize(result.int64_value_size());
- auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnInt64*>(column.get())->get_data();
for (int i = 0; i < result.int64_value_size(); ++i) {
data[i] = result.int64_value(i);
}
@@ -434,19 +641,20 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::DATETIME: {
column->reserve(result.datetime_value_size());
column->resize(result.datetime_value_size());
- auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnInt64*>(column.get())->get_data();
for (int i = 0; i < result.datetime_value_size(); ++i) {
- VecDateTimeValue v;
+ vectorized::VecDateTimeValue v;
PDateTime pv = result.datetime_value(i);
- v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.minute());
+ // Positional fields: year, month, day, hour, minute, second.
+ v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.second());
- data[i] = binary_cast<VecDateTimeValue, Int64>(v);
+ data[i] = binary_cast<vectorized::VecDateTimeValue, vectorized::Int64>(v);
}
break;
}
case PGenericType::FLOAT: {
column->reserve(result.float_value_size());
column->resize(result.float_value_size());
- auto& data = reinterpret_cast<ColumnFloat32*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnFloat32*>(column.get())->get_data();
for (int i = 0; i < result.float_value_size(); ++i) {
data[i] = result.float_value(i);
}
@@ -455,7 +662,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::DOUBLE: {
column->reserve(result.double_value_size());
column->resize(result.double_value_size());
- auto& data = reinterpret_cast<ColumnFloat64*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnFloat64*>(column.get())->get_data();
for (int i = 0; i < result.double_value_size(); ++i) {
data[i] = result.double_value(i);
}
@@ -464,7 +671,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::INT128: {
column->reserve(result.bytes_value_size());
column->resize(result.bytes_value_size());
- auto& data = reinterpret_cast<ColumnInt128*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnInt128*>(column.get())->get_data();
for (int i = 0; i < result.bytes_value_size(); ++i) {
data[i] = *(int128_t*)(result.bytes_value(i).c_str());
}
@@ -480,7 +687,7 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
case PGenericType::DECIMAL128: {
column->reserve(result.bytes_value_size());
column->resize(result.bytes_value_size());
- auto& data = reinterpret_cast<ColumnDecimal128*>(column.get())->get_data();
+ auto& data = reinterpret_cast<vectorized::ColumnDecimal128*>(column.get())->get_data();
for (int i = 0; i < result.bytes_value_size(); ++i) {
data[i] = *(int128_t*)(result.bytes_value(i).c_str());
}
@@ -507,13 +714,14 @@ void convert_to_column(MutableColumnPtr& column, const PValues& result) {
}
}
-void convert_to_block(Block& block, const PValues& result, size_t pos) {
+void convert_to_block(vectorized::Block& block, const PValues& result, size_t pos) {
auto data_type = block.get_data_type(pos);
if (data_type->is_nullable()) {
- auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(data_type);
+ auto null_type =
+ std::reinterpret_pointer_cast<const vectorized::DataTypeNullable>(data_type);
auto data_col = null_type->get_nested_type()->create_column();
convert_to_column<true>(data_col, result);
- auto null_col = ColumnUInt8::create(data_col->size(), 0);
+ auto null_col = vectorized::ColumnUInt8::create(data_col->size(), 0);
auto& null_map_data = null_col->get_data();
null_col->reserve(data_col->size());
null_col->resize(data_col->size());
@@ -526,37 +734,37 @@ void convert_to_block(Block& block, const PValues& result, size_t pos) {
null_map_data[i] = false;
}
}
- block.replace_by_position(pos,
- ColumnNullable::create(std::move(data_col), std::move(null_col)));
+ block.replace_by_position(
+ pos, vectorized::ColumnNullable::create(std::move(data_col), std::move(null_col)));
} else {
auto column = data_type->create_column();
convert_to_column<false>(column, result);
block.replace_by_position(pos, std::move(column));
}
}
-
-Status RPCFnCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
- size_t result, size_t input_rows_count, bool dry_run) {
+Status RPCFn::vec_call(FunctionContext* context, vectorized::Block& block,
+ const vectorized::ColumnNumbers& arguments, size_t result,
+ size_t input_rows_count) {
PFunctionCallRequest request;
PFunctionCallResponse response;
- request.set_function_name(_symbol);
+ request.set_function_name(_function_name);
convert_block_to_proto(block, arguments, input_rows_count, &request);
brpc::Controller cntl;
_client->fn_call(&cntl, &request, &response, nullptr);
if (cntl.Failed()) {
return Status::InternalError(
- fmt::format("call to rpc function {} failed: {}", _symbol, cntl.ErrorText())
+ fmt::format("call to rpc function {} failed: {}", _signature, cntl.ErrorText())
.c_str());
}
- if (!response.has_status() || !response.has_result()) {
- return Status::InternalError(
- fmt::format("call rpc function {} failed: status or result is not set.", _symbol));
+ if (!response.has_status() || response.result_size() == 0) {
+ return Status::InternalError(fmt::format(
+ "call rpc function {} failed: status or result is not set.", _signature));
}
if (response.status().status_code() != 0) {
- return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _symbol,
+ return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _signature,
response.status().DebugString()));
}
- convert_to_block(block, response.result(), result);
+ convert_to_block(block, response.result(0), result);
return Status::OK();
}
-} // namespace doris::vectorized
+} // namespace doris
diff --git a/be/src/exprs/rpc_fn.h b/be/src/exprs/rpc_fn.h
new file mode 100644
index 0000000000..154f158640
--- /dev/null
+++ b/be/src/exprs/rpc_fn.h
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "common/status.h"
+#include "exprs/expr.h"
+#include "exprs/expr_context.h"
+#include "gen_cpp/function_service.pb.h"
+#include "runtime/runtime_state.h"
+#include "udf/udf.h"
+
+namespace doris {
+namespace vectorized {
+class Block;
+} // namespace vectorized
+
+// Client for a user-defined function served by a remote function service; requests go out
+// through `_client` (a PFunctionService_Stub). AggregationStep and `is_agg` presumably
+// describe RPC aggregate functions -- TODO confirm against rpc_fn.cpp.
+class RPCFn {
+public:
+    // Aggregation phase of a call; `_step` defaults to INVALID.
+    enum AggregationStep {
+        INIT = 0,
+        UPDATE = 1,
+        MERGE = 2,
+        REMOVE = 3,
+        SERIALIZE = 4,
+        GET_VALUE = 5,
+        FINALIZE = 6,
+        INVALID = 999,
+    };
+
+    RPCFn(RuntimeState* state, const TFunction& fn, int fn_ctx_id, bool is_agg);
+    RPCFn(const TFunction& fn, bool is_agg);
+    RPCFn(RuntimeState* state, const TFunction& fn, AggregationStep step, bool is_agg);
+    ~RPCFn() = default;
+
+    // Row-based call: sends the values of `exprs` for `row` to the remote function and
+    // converts the first returned value to T. Returns T::null() if the call fails or that
+    // value is flagged null.
+    template <typename T>
+    T call(ExprContext* context, TupleRow* row, const std::vector<Expr*>& exprs);
+
+    // Vectorized call: sends the `arguments` columns of `block` and writes the first
+    // returned value set into column `result`.
+    Status vec_call(FunctionContext* context, vectorized::Block& block,
+                    const std::vector<size_t>& arguments, size_t result, size_t input_rows_count);
+
+    // True when `_client` is set. The misspelled name is kept because callers use it.
+    bool avliable() { return _client != nullptr; }
+
+private:
+    Status call_internal(ExprContext* context, TupleRow* row, PFunctionCallResponse* response,
+                         const std::vector<Expr*>& exprs);
+    // Cancels the current fragment instance with reason CALL_RPC_ERROR.
+    void cancel(const std::string& msg);
+
+    std::shared_ptr<PFunctionService_Stub> _client;
+    // Defaults keep these well-defined for constructors that do not set them.
+    RuntimeState* _state = nullptr;
+    std::string _function_name;
+    std::string _server_addr;
+    std::string _signature;
+    TFunction _fn;
+    int _fn_ctx_id = -1;
+    bool _is_agg = false;
+    AggregationStep _step = AggregationStep::INVALID;
+};
+
+template <typename T>
+T RPCFn::call(ExprContext* context, TupleRow* row, const std::vector<Expr*>& exprs) {
+    PFunctionCallResponse response;
+    Status st = call_internal(context, row, &response, exprs);
+    WARN_IF_ERROR(st, "call rpc udf error");
+    if (!st.ok()) {
+        return T::null();
+    }
+    // call_internal returns an error when the response has no result, so result(0) exists.
+    const PValues& result = response.result(0);
+    // has_null may be set while null_map is empty; check its size before reading null_map(0).
+    if (result.has_null() && result.null_map_size() > 0 && result.null_map(0)) {
+        return T::null();
+    }
+    T res_val;
+    // TODO(yangzhg) deal with udtf and udaf
+    if constexpr (std::is_same_v<T, TinyIntVal>) {
+        DCHECK(result.type().id() == PGenericType::INT8);
+        res_val.val = static_cast<int8_t>(result.int32_value(0));
+    } else if constexpr (std::is_same_v<T, SmallIntVal>) {
+        DCHECK(result.type().id() == PGenericType::INT16);
+        res_val.val = static_cast<int16_t>(result.int32_value(0));
+    } else if constexpr (std::is_same_v<T, IntVal>) {
+        DCHECK(result.type().id() == PGenericType::INT32);
+        res_val.val = result.int32_value(0);
+    } else if constexpr (std::is_same_v<T, BigIntVal>) {
+        DCHECK(result.type().id() == PGenericType::INT64);
+        res_val.val = result.int64_value(0);
+    } else if constexpr (std::is_same_v<T, FloatVal>) {
+        DCHECK(result.type().id() == PGenericType::FLOAT);
+        res_val.val = result.float_value(0);
+    } else if constexpr (std::is_same_v<T, DoubleVal>) {
+        DCHECK(result.type().id() == PGenericType::DOUBLE);
+        res_val.val = result.double_value(0);
+    } else if constexpr (std::is_same_v<T, StringVal>) {
+        DCHECK(result.type().id() == PGenericType::STRING);
+        auto* fn_ctx = context->fn_context(_fn_ctx_id);
+        const std::string& str = result.string_value(0);
+        // StringVal::copy_from (udf/udf.h) allocates its own buffer from fn_ctx, so no
+        // StringVal has to be allocated first.
+        res_val = StringVal::copy_from(fn_ctx, reinterpret_cast<const uint8_t*>(str.data()),
+                                       str.size());
+    } else if constexpr (std::is_same_v<T, LargeIntVal>) {
+        DCHECK(result.type().id() == PGenericType::INT128);
+        memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t));
+    } else if constexpr (std::is_same_v<T, DateTimeVal>) {
+        DCHECK(result.type().id() == PGenericType::DATE ||
+               result.type().id() == PGenericType::DATETIME);
+        const PDateTime& dt = result.datetime_value(0);
+        DateTimeValue value;
+        value.set_time(dt.year(), dt.month(), dt.day(), dt.hour(), dt.minute(), dt.second(),
+                       dt.microsecond());
+        if (result.type().id() == PGenericType::DATE) {
+            value.set_type(TimeType::TIME_DATE);
+        } else if (result.type().id() == PGenericType::DATETIME) {
+            // Mirrors the TYPE_TIME argument encoding (DATETIME id, no year): no year => TIME.
+            if (dt.has_year()) {
+                value.set_type(TimeType::TIME_DATETIME);
+            } else {
+                value.set_type(TimeType::TIME_TIME);
+            }
+        }
+        value.to_datetime_val(&res_val);
+    } else if constexpr (std::is_same_v<T, DecimalV2Val>) {
+        DCHECK(result.type().id() == PGenericType::DECIMAL128);
+        memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t));
+    }
+    return res_val;
+}
+} // namespace doris
diff --git a/be/src/exprs/rpc_fn_call.cpp b/be/src/exprs/rpc_fn_call.cpp
index b8939af7d4..8d779f0ce8 100644
--- a/be/src/exprs/rpc_fn_call.cpp
+++ b/be/src/exprs/rpc_fn_call.cpp
@@ -19,20 +19,19 @@
#include "exprs/anyval_util.h"
#include "exprs/expr_context.h"
+#include "exprs/rpc_fn.h"
#include "fmt/format.h"
-#include "gen_cpp/function_service.pb.h"
-#include "runtime/fragment_mgr.h"
#include "runtime/runtime_state.h"
#include "runtime/user_function_cache.h"
-#include "service/brpc.h"
-#include "util/brpc_client_cache.h"
namespace doris {
-RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _fn_context_index(-1) {
+RPCFnCall::RPCFnCall(const TExprNode& node) : Expr(node), _tnode(node) {
DCHECK_EQ(_fn.binary_type, TFunctionBinaryType::RPC);
}
+RPCFnCall::~RPCFnCall() {}
+
Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprContext* context) {
RETURN_IF_ERROR(Expr::prepare(state, desc, context));
DCHECK(!_fn.scalar_fn.symbol.empty());
@@ -44,16 +44,12 @@ Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprCo
arg_types.push_back(AnyValUtil::column_type_to_type_desc(_children[i]->type()));
char_arg = char_arg || (_children[i]->type().type == TYPE_CHAR);
}
- _fn_context_index = context->register_func(state, return_type, arg_types, 0);
-
- // _fn.scalar_fn.symbol
- _rpc_function_symbol = _fn.scalar_fn.symbol;
-
- _client = state->exec_env()->brpc_function_client_cache()->get_client(_fn.hdfs_location);
+ int id = context->register_func(state, return_type, arg_types, 0);
- if (_client == nullptr) {
+ _rpc_fn = std::make_unique<RPCFn>(state, _fn, id, false);
+ if (!_rpc_fn->avliable()) {
return Status::InternalError(
- fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _rpc_function_symbol));
+ fmt::format("rpc env init error: {}/{}", _fn.hdfs_location, _fn.scalar_fn.symbol));
}
return Status::OK();
}
@@ -61,7 +57,6 @@ Status RPCFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, ExprCo
Status RPCFnCall::open(RuntimeState* state, ExprContext* ctx,
FunctionContext::FunctionStateScope scope) {
RETURN_IF_ERROR(Expr::open(state, ctx, scope));
- _state = state;
return Status::OK();
}
@@ -70,276 +65,50 @@ void RPCFnCall::close(RuntimeState* state, ExprContext* context,
Expr::close(state, context, scope);
}
-Status RPCFnCall::call_rpc(ExprContext* context, TupleRow* row, PFunctionCallResponse* response) {
- PFunctionCallRequest request;
- request.set_function_name(_rpc_function_symbol);
- for (int i = 0; i < _children.size(); ++i) {
- PValues* arg = request.add_args();
- void* src_slot = context->get_value(_children[i], row);
- PGenericType* ptype = arg->mutable_type();
- if (src_slot == nullptr) {
- arg->set_has_null(true);
- arg->add_null_map(true);
- } else {
- arg->set_has_null(false);
- }
- switch (_children[i]->type().type) {
- case TYPE_BOOLEAN: {
- ptype->set_id(PGenericType::BOOLEAN);
- arg->add_bool_value(*(bool*)src_slot);
- break;
- }
- case TYPE_TINYINT: {
- ptype->set_id(PGenericType::INT8);
- arg->add_int32_value(*(int8_t*)src_slot);
- break;
- }
- case TYPE_SMALLINT: {
- ptype->set_id(PGenericType::INT16);
- arg->add_int32_value(*(int16_t*)src_slot);
- break;
- }
- case TYPE_INT: {
- ptype->set_id(PGenericType::INT32);
- arg->add_int32_value(*(int*)src_slot);
- break;
- }
- case TYPE_BIGINT: {
- ptype->set_id(PGenericType::INT64);
- arg->add_int64_value(*(int64_t*)src_slot);
- break;
- }
- case TYPE_LARGEINT: {
- ptype->set_id(PGenericType::INT128);
- char buffer[sizeof(__int128)];
- memcpy(buffer, src_slot, sizeof(__int128));
- arg->add_bytes_value(buffer, sizeof(__int128));
- break;
- }
- case TYPE_DOUBLE: {
- ptype->set_id(PGenericType::DOUBLE);
- arg->add_double_value(*(double*)src_slot);
- break;
- }
- case TYPE_FLOAT: {
- ptype->set_id(PGenericType::FLOAT);
- arg->add_float_value(*(float*)src_slot);
- break;
- }
- case TYPE_VARCHAR:
- case TYPE_STRING:
- case TYPE_CHAR: {
- ptype->set_id(PGenericType::STRING);
- StringValue value = *reinterpret_cast<StringValue*>(src_slot);
- arg->add_string_value(value.ptr, value.len);
- break;
- }
- case TYPE_HLL: {
- ptype->set_id(PGenericType::HLL);
- StringValue value = *reinterpret_cast<StringValue*>(src_slot);
- arg->add_string_value(value.ptr, value.len);
- break;
- }
- case TYPE_OBJECT: {
- ptype->set_id(PGenericType::BITMAP);
- StringValue value = *reinterpret_cast<StringValue*>(src_slot);
- arg->add_string_value(value.ptr, value.len);
- break;
- }
- case TYPE_DECIMALV2: {
- ptype->set_id(PGenericType::DECIMAL128);
- ptype->mutable_decimal_type()->set_precision(_children[i]->type().precision);
- ptype->mutable_decimal_type()->set_scale(_children[i]->type().scale);
- char buffer[sizeof(__int128)];
- memcpy(buffer, src_slot, sizeof(__int128));
- arg->add_bytes_value(buffer, sizeof(__int128));
- break;
- }
- case TYPE_DATE: {
- ptype->set_id(PGenericType::DATE);
- const auto* time_val = (const DateTimeValue*)(src_slot);
- PDateTime* date_time = arg->add_datetime_value();
- date_time->set_day(time_val->day());
- date_time->set_month(time_val->month());
- date_time->set_year(time_val->year());
- break;
- }
- case TYPE_DATETIME: {
- ptype->set_id(PGenericType::DATETIME);
- const auto* time_val = (const DateTimeValue*)(src_slot);
- PDateTime* date_time = arg->add_datetime_value();
- date_time->set_day(time_val->day());
- date_time->set_month(time_val->month());
- date_time->set_year(time_val->year());
- date_time->set_hour(time_val->hour());
- date_time->set_minute(time_val->minute());
- date_time->set_second(time_val->second());
- date_time->set_microsecond(time_val->microsecond());
- break;
- }
- case TYPE_TIME: {
- ptype->set_id(PGenericType::DATETIME);
- const auto* time_val = (const DateTimeValue*)(src_slot);
- PDateTime* date_time = arg->add_datetime_value();
- date_time->set_hour(time_val->hour());
- date_time->set_minute(time_val->minute());
- date_time->set_second(time_val->second());
- date_time->set_microsecond(time_val->microsecond());
- break;
- }
- default: {
- FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
- std::string error_msg =
- fmt::format("data time not supported: {}", _children[i]->type().type);
- fn_ctx->set_error(error_msg.c_str());
- cancel(error_msg);
- break;
- }
- }
- }
-
- brpc::Controller cntl;
- _client->fn_call(&cntl, &request, response, nullptr);
- if (cntl.Failed()) {
- FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
- std::string error_msg = fmt::format("call rpc function {} failed: {}", _rpc_function_symbol,
- cntl.ErrorText());
- fn_ctx->set_error(error_msg.c_str());
- cancel(error_msg);
- return Status::InternalError(error_msg);
- }
- if (!response->has_status() || !response->has_result()) {
- FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
- std::string error_msg =
- fmt::format("call rpc function {} failed: status or result is not set: {}",
- _rpc_function_symbol, response->status().DebugString());
- fn_ctx->set_error(error_msg.c_str());
- cancel(error_msg);
- return Status::InternalError(error_msg);
- }
- if (response->status().status_code() != 0) {
- FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
- std::string error_msg = fmt::format("call rpc function {} failed: {}", _rpc_function_symbol,
- response->status().DebugString());
- fn_ctx->set_error(error_msg.c_str());
- cancel(error_msg);
- return Status::InternalError(error_msg);
- }
- return Status::OK();
-}
-
-template <typename T>
-T RPCFnCall::interpret_eval(ExprContext* context, TupleRow* row) {
- PFunctionCallResponse response;
- Status st = call_rpc(context, row, &response);
- WARN_IF_ERROR(st, "call rpc udf error");
- if (!st.ok() || (response.result().has_null() && response.result().null_map(0))) {
- return T::null();
- }
- T res_val;
- // TODO(yangzhg) deal with udtf and udaf
- const PValues& result = response.result();
- if constexpr (std::is_same_v<T, TinyIntVal>) {
- DCHECK(result.type().id() == PGenericType::INT8);
- res_val.val = static_cast<int8_t>(result.int32_value(0));
- } else if constexpr (std::is_same_v<T, SmallIntVal>) {
- DCHECK(result.type().id() == PGenericType::INT16);
- res_val.val = static_cast<int16_t>(result.int32_value(0));
- } else if constexpr (std::is_same_v<T, IntVal>) {
- DCHECK(result.type().id() == PGenericType::INT32);
- res_val.val = result.int32_value(0);
- } else if constexpr (std::is_same_v<T, BigIntVal>) {
- DCHECK(result.type().id() == PGenericType::INT64);
- res_val.val = result.int64_value(0);
- } else if constexpr (std::is_same_v<T, FloatVal>) {
- DCHECK(result.type().id() == PGenericType::FLOAT);
- res_val.val = result.float_value(0);
- } else if constexpr (std::is_same_v<T, DoubleVal>) {
- DCHECK(result.type().id() == PGenericType::DOUBLE);
- res_val.val = result.double_value(0);
- } else if constexpr (std::is_same_v<T, StringVal>) {
- DCHECK(result.type().id() == PGenericType::STRING);
- FunctionContext* fn_ctx = context->fn_context(_fn_context_index);
- StringVal val(fn_ctx, result.string_value(0).size());
- res_val = val.copy_from(fn_ctx,
- reinterpret_cast<const uint8_t*>(result.string_value(0).c_str()),
- result.string_value(0).size());
- } else if constexpr (std::is_same_v<T, LargeIntVal>) {
- DCHECK(result.type().id() == PGenericType::INT128);
- memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t));
- } else if constexpr (std::is_same_v<T, DateTimeVal>) {
- DCHECK(result.type().id() == PGenericType::DATE ||
- result.type().id() == PGenericType::DATETIME);
- DateTimeValue value;
- value.set_time(result.datetime_value(0).year(), result.datetime_value(0).month(),
- result.datetime_value(0).day(), result.datetime_value(0).hour(),
- result.datetime_value(0).minute(), result.datetime_value(0).second(),
- result.datetime_value(0).microsecond());
- if (result.type().id() == PGenericType::DATE) {
- value.set_type(TimeType::TIME_DATE);
- } else if (result.type().id() == PGenericType::DATETIME) {
- if (result.datetime_value(0).has_year()) {
- value.set_type(TimeType::TIME_DATETIME);
- } else
- value.set_type(TimeType::TIME_TIME);
- }
- value.to_datetime_val(&res_val);
- } else if constexpr (std::is_same_v<T, DecimalV2Val>) {
- DCHECK(result.type().id() == PGenericType::DECIMAL128);
- memcpy(&(res_val.val), result.bytes_value(0).data(), sizeof(__int128_t));
- }
- return res_val;
-} // namespace doris
-
doris_udf::IntVal RPCFnCall::get_int_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<IntVal>(context, row);
+    return _rpc_fn->call<IntVal>(context, row, _children); // every getter delegates to RPCFn::call<T>
}
doris_udf::BooleanVal RPCFnCall::get_boolean_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<BooleanVal>(context, row);
+    return _rpc_fn->call<BooleanVal>(context, row, _children); // NOTE(review): old interpret_eval had no BooleanVal branch; confirm call<> decodes it
}
doris_udf::TinyIntVal RPCFnCall::get_tiny_int_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<TinyIntVal>(context, row);
+    return _rpc_fn->call<TinyIntVal>(context, row, _children);
}
doris_udf::SmallIntVal RPCFnCall::get_small_int_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<SmallIntVal>(context, row);
+    return _rpc_fn->call<SmallIntVal>(context, row, _children);
}
doris_udf::BigIntVal RPCFnCall::get_big_int_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<BigIntVal>(context, row);
+    return _rpc_fn->call<BigIntVal>(context, row, _children);
}
doris_udf::FloatVal RPCFnCall::get_float_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<FloatVal>(context, row);
+    return _rpc_fn->call<FloatVal>(context, row, _children);
}
doris_udf::DoubleVal RPCFnCall::get_double_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<DoubleVal>(context, row);
+    return _rpc_fn->call<DoubleVal>(context, row, _children);
}
doris_udf::StringVal RPCFnCall::get_string_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<StringVal>(context, row);
+    return _rpc_fn->call<StringVal>(context, row, _children);
}
doris_udf::LargeIntVal RPCFnCall::get_large_int_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<LargeIntVal>(context, row);
+    return _rpc_fn->call<LargeIntVal>(context, row, _children);
}
doris_udf::DateTimeVal RPCFnCall::get_datetime_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<DateTimeVal>(context, row);
+    return _rpc_fn->call<DateTimeVal>(context, row, _children);
}
doris_udf::DecimalV2Val RPCFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<DecimalV2Val>(context, row);
+    return _rpc_fn->call<DecimalV2Val>(context, row, _children);
}
doris_udf::CollectionVal RPCFnCall::get_array_val(ExprContext* context, TupleRow* row) {
-    return interpret_eval<CollectionVal>(context, row);
+    return _rpc_fn->call<CollectionVal>(context, row, _children); // NOTE(review): old interpret_eval had no CollectionVal branch; confirm call<> decodes it
}
-void RPCFnCall::cancel(const std::string& msg) {
-    _state->exec_env()->fragment_mgr()->cancel(_state->fragment_instance_id(),
-                                               PPlanFragmentCancelReason::CALL_RPC_ERROR, msg);
-}
-
} // namespace doris
diff --git a/be/src/exprs/rpc_fn_call.h b/be/src/exprs/rpc_fn_call.h
index b534c0c68b..d63fb2db0e 100644
--- a/be/src/exprs/rpc_fn_call.h
+++ b/be/src/exprs/rpc_fn_call.h
@@ -23,13 +23,13 @@
+#include "gen_cpp/Exprs_types.h" // _tnode is held by value, so TExprNode must be complete here
namespace doris {
class TExprNode;
-class PFunctionService_Stub;
-class PFunctionCallResponse;
+class RPCFn;
class RPCFnCall : public Expr {
public:
    RPCFnCall(const TExprNode& node);
-    ~RPCFnCall() = default;
+    ~RPCFnCall();
    virtual Status prepare(RuntimeState* state, const RowDescriptor& desc,
                           ExprContext* context) override;
@@ -37,7 +37,9 @@ public:
                        FunctionContext::FunctionStateScope scope) override;
    virtual void close(RuntimeState* state, ExprContext* context,
                       FunctionContext::FunctionStateScope scope) override;
-    virtual Expr* clone(ObjectPool* pool) const override { return pool->add(new RPCFnCall(*this)); }
+    virtual Expr* clone(ObjectPool* pool) const override {
+        return pool->add(new RPCFnCall(_tnode));
+    }
    virtual doris_udf::BooleanVal get_boolean_val(ExprContext* context, TupleRow*) override;
    virtual doris_udf::TinyIntVal get_tiny_int_val(ExprContext* context, TupleRow*) override;
@@ -53,14 +55,9 @@ public:
    virtual doris_udf::CollectionVal get_array_val(ExprContext* context, TupleRow*) override;
private:
-    Status call_rpc(ExprContext* context, TupleRow* row, PFunctionCallResponse* response);
-    template <typename RETURN_TYPE>
-    RETURN_TYPE interpret_eval(ExprContext* context, TupleRow* row);
-    void cancel(const std::string& msg);
-
-    std::shared_ptr<PFunctionService_Stub> _client = nullptr;
-    int _fn_context_index;
-    std::string _rpc_function_symbol;
-    RuntimeState* _state;
+    std::unique_ptr<RPCFn> _rpc_fn;
+    // Stored by value: the constructor's TExprNode argument belongs to the caller, and a
+    // reference member would dangle if that node is freed before clone() reads it.
+    TExprNode _tnode;
};
} // namespace doris
diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp
index f3c11ee6cc..b343155cff 100644
--- a/be/src/udf/udf.cpp
+++ b/be/src/udf/udf.cpp
@@ -201,20 +201,6 @@ FunctionContext* FunctionContextImpl::clone(MemPool* pool) {
return new_context;
}
-// TODO: to be implemented
-void FunctionContextImpl::serialize(PFunctionContext* pcontext) const {
- // pcontext->set_string_result(_string_result);
- // pcontext->set_num_updates(_num_updates);
- // pcontext->set_num_removes(_num_removes);
- // pcontext->set_num_warnings(_num_warnings);
- // pcontext->set_error_msg(_error_msg);
- // PUniqueId* query_id = pcontext->mutable_query_id();
- // query_id->set_hi(_context->query_id().hi);
- // query_id->set_lo(_context->query_id().lo);
-}
-
-void FunctionContextImpl::derialize(const PFunctionContext& pcontext) {}
-
} // namespace doris
namespace doris_udf {
diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h
index 903ac069dc..708e45dda6 100644
--- a/be/src/udf/udf_internal.h
+++ b/be/src/udf/udf_internal.h
@@ -36,7 +36,6 @@ class FreePool;
class MemPool;
class RuntimeState;
struct ColumnPtrWrapper;
-class PFunctionContext;
// This class actually implements the interface of FunctionContext. This is split to
// hide the details from the external header.
@@ -111,9 +110,6 @@ public:
const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; }
- void serialize(PFunctionContext* pcontext) const;
- void derialize(const PFunctionContext& pcontext);
-
private:
friend class doris_udf::FunctionContext;
friend class ExprContext;
diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp
index e304a1973e..36be1d61c3 100644
--- a/be/src/vec/core/block.cpp
+++ b/be/src/vec/core/block.cpp
@@ -351,9 +351,6 @@ std::string Block::dump_names() const {
}
std::string Block::dump_data(size_t begin, size_t row_limit) const {
-    if (rows() == 0) {
-        return "empty block.";
-    }
    std::vector<std::string> headers;
    std::vector<size_t> headers_size;
    for (auto it = data.begin(); it != data.end(); ++it) {
@@ -379,6 +376,9 @@ std::string Block::dump_data(size_t begin, size_t row_limit) const {
    out << std::setw(1) << "|" << std::endl;
    // header bottom line
    line();
+    if (rows() == 0) { // no rows: return just the header table rendered above
+        return out.str();
+    }
    // content
    for (size_t row_num = begin; row_num < rows() && row_num < row_limit + begin; ++row_num) {
        for (size_t i = 0; i < columns(); ++i) {
@@ -875,9 +875,6 @@ Block MutableBlock::to_block(int start_column, int end_column) {
}
std::string MutableBlock::dump_data(size_t row_limit) const {
-    if (rows() == 0) {
-        return "empty block.";
-    }
    std::vector<std::string> headers;
    std::vector<size_t> headers_size;
    for (size_t i = 0; i < columns(); ++i) {
@@ -903,6 +900,9 @@ std::string MutableBlock::dump_data(size_t row_limit) const {
    out << std::setw(1) << "|" << std::endl;
    // header bottom line
    line();
+    if (rows() == 0) { // no rows: return just the header table rendered above
+        return out.str();
+    }
    // content
    for (size_t row_num = 0; row_num < rows() && row_num < row_limit; ++row_num) {
        for (size_t i = 0; i < columns(); ++i) {
diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h
index ee032900a3..94dd5ec8e9 100644
--- a/be/src/vec/core/block.h
+++ b/be/src/vec/core/block.h
@@ -259,7 +259,7 @@ public:
std::unique_ptr<Block> create_same_struct_block(size_t size) const;
- /** Compares (*this) n-th row and rhs m-th row.
+ /** Compares (*this) n-th row and rhs m-th row.
* Returns negative number, 0, or positive number (*this) n-th row is less, equal, greater than rhs m-th row respectively.
* Is used in sortings.
*
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp
index 379b892f03..9afacc748b 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -20,6 +20,7 @@
#include <string_view>
#include "exprs/anyval_util.h"
+#include "exprs/rpc_fn.h"
#include "fmt/format.h"
#include "fmt/ranges.h"
#include "udf/udf_internal.h"
@@ -45,8 +46,7 @@ doris::Status VectorizedFnCall::prepare(doris::RuntimeState* state,
child_expr_name.emplace_back(child->expr_name());
}
if (_fn.binary_type == TFunctionBinaryType::RPC) {
- _function = RPCFnCall::create(_fn.name.function_name, _fn.hdfs_location, argument_template,
- _data_type);
+ _function = FunctionRPC::create(_fn, argument_template, _data_type);
} else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
#ifdef LIBJVM
_function = JavaFunctionCall::create(_fn, argument_template, _data_type);
diff --git a/be/src/vec/functions/function_rpc.cpp b/be/src/vec/functions/function_rpc.cpp
index 9b2e11d08a..97d31710ca 100644
--- a/be/src/vec/functions/function_rpc.cpp
+++ b/be/src/vec/functions/function_rpc.cpp
@@ -21,542 +21,26 @@
#include <memory>
-#include "gen_cpp/function_service.pb.h"
-#include "runtime/exec_env.h"
-#include "runtime/user_function_cache.h"
-#include "service/brpc.h"
-#include "util/brpc_client_cache.h"
-#include "vec/columns/column_vector.h"
-#include "vec/core/block.h"
-#include "vec/data_types/data_type_bitmap.h"
-#include "vec/data_types/data_type_date.h"
-#include "vec/data_types/data_type_date_time.h"
-#include "vec/data_types/data_type_decimal.h"
-#include "vec/data_types/data_type_nullable.h"
-#include "vec/data_types/data_type_number.h"
-#include "vec/data_types/data_type_string.h"
+#include "exprs/rpc_fn.h"
namespace doris::vectorized {
-RPCFnCall::RPCFnCall(const std::string& symbol, const std::string& server,
-                     const DataTypes& argument_types, const DataTypePtr& return_type)
-        : _symbol(symbol),
-          _server(server),
-          _name(fmt::format("{}/{}", server, symbol)),
-          _argument_types(argument_types),
-          _return_type(return_type) {}
-Status RPCFnCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
-    _client = ExecEnv::GetInstance()->brpc_function_client_cache()->get_client(_server);
+FunctionRPC::FunctionRPC(const TFunction& fn, const DataTypes& argument_types,
+                         const DataTypePtr& return_type)
+        : _argument_types(argument_types), _return_type(return_type), _tfn(fn) {}
-    if (_client == nullptr) {
+Status FunctionRPC::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
+    _fn = std::make_unique<RPCFn>(_tfn, false);
+
+    if (!_fn->avliable()) {
-        return Status::InternalError("rpc env init error");
+        // Report which server/symbol failed, as exprs/rpc_fn_call.cpp does for the row path.
+        return Status::InternalError("rpc env init error: " + _tfn.hdfs_location + "/" +
+                                     _tfn.scalar_fn.symbol);
    }
    return Status::OK();
}
-template <bool nullable>
-void convert_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type, PValues* arg,
- size_t row_count) {
- PGenericType* ptype = arg->mutable_type();
- switch (data_type->get_type_id()) {
- case TypeIndex::UInt8: {
- ptype->set_id(PGenericType::UINT8);
- auto* values = arg->mutable_bool_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt8>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::UInt16: {
- ptype->set_id(PGenericType::UINT16);
- auto* values = arg->mutable_uint32_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt16>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::UInt32: {
- ptype->set_id(PGenericType::UINT32);
- auto* values = arg->mutable_uint32_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt32>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::UInt64: {
- ptype->set_id(PGenericType::UINT64);
- auto* values = arg->mutable_uint64_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt64>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::UInt128: {
- ptype->set_id(PGenericType::UINT128);
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_bytes_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- }
- break;
- }
- case TypeIndex::Int8: {
- ptype->set_id(PGenericType::INT8);
- auto* values = arg->mutable_int32_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt8>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::Int16: {
- ptype->set_id(PGenericType::INT16);
- auto* values = arg->mutable_int32_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt16>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::Int32: {
- ptype->set_id(PGenericType::INT32);
- auto* values = arg->mutable_int32_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt32>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::Int64: {
- ptype->set_id(PGenericType::INT64);
- auto* values = arg->mutable_int64_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnInt64>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::Int128: {
- ptype->set_id(PGenericType::INT128);
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_bytes_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- }
- break;
- }
- case TypeIndex::Float32: {
- ptype->set_id(PGenericType::FLOAT);
- auto* values = arg->mutable_float_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnFloat32>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
-
- case TypeIndex::Float64: {
- ptype->set_id(PGenericType::DOUBLE);
- auto* values = arg->mutable_double_value();
- values->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnFloat64>(column);
- auto& data = col->get_data();
- values->Add(data.begin(), data.begin() + row_count);
- break;
- }
- case TypeIndex::Decimal128: {
- ptype->set_id(PGenericType::DECIMAL128);
- auto dec_type = std::reinterpret_pointer_cast<const DataTypeDecimal<Decimal128>>(data_type);
- ptype->mutable_decimal_type()->set_precision(dec_type->get_precision());
- ptype->mutable_decimal_type()->set_scale(dec_type->get_scale());
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_bytes_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- }
- break;
- }
- case TypeIndex::String: {
- ptype->set_id(PGenericType::STRING);
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_string_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_string_value(data.to_string());
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_string_value(data.to_string());
- }
- }
- break;
- }
- case TypeIndex::Date: {
- ptype->set_id(PGenericType::DATE);
- arg->mutable_datetime_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- PDateTime* date_time = arg->add_datetime_value();
- if constexpr (nullable) {
- if (!column->is_null_at(row_num)) {
- VecDateTimeValue v =
- binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
- date_time->set_day(v.day());
- date_time->set_month(v.month());
- date_time->set_year(v.year());
- }
- } else {
- VecDateTimeValue v = binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
- date_time->set_day(v.day());
- date_time->set_month(v.month());
- date_time->set_year(v.year());
- }
- }
- break;
- }
- case TypeIndex::DateTime: {
- ptype->set_id(PGenericType::DATETIME);
- arg->mutable_datetime_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- PDateTime* date_time = arg->add_datetime_value();
- if constexpr (nullable) {
- if (!column->is_null_at(row_num)) {
- VecDateTimeValue v =
- binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
- date_time->set_day(v.day());
- date_time->set_month(v.month());
- date_time->set_year(v.year());
- date_time->set_hour(v.hour());
- date_time->set_minute(v.minute());
- date_time->set_second(v.second());
- }
- } else {
- VecDateTimeValue v = binary_cast<vectorized::Int64, vectorized::VecDateTimeValue>(
- column->get_int(row_num));
- date_time->set_day(v.day());
- date_time->set_month(v.month());
- date_time->set_year(v.year());
- date_time->set_hour(v.hour());
- date_time->set_minute(v.minute());
- date_time->set_second(v.second());
- }
- }
- break;
- }
- case TypeIndex::BitMap: {
- ptype->set_id(PGenericType::BITMAP);
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_bytes_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- }
- break;
- }
- case TypeIndex::HLL: {
- ptype->set_id(PGenericType::HLL);
- arg->mutable_bytes_value()->Reserve(row_count);
- for (size_t row_num = 0; row_num < row_count; ++row_num) {
- if constexpr (nullable) {
- if (column->is_null_at(row_num)) {
- arg->add_bytes_value(nullptr);
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- } else {
- StringRef data = column->get_data_at(row_num);
- arg->add_bytes_value(data.data, data.size);
- }
- }
- break;
- }
- default:
- LOG(INFO) << "unknown type: " << data_type->get_name();
- ptype->set_id(PGenericType::UNKNOWN);
- break;
- }
-}
-
-void convert_nullable_col_to_pvalue(const ColumnPtr& column, const DataTypePtr& data_type,
- const ColumnUInt8& null_col, PValues* arg, size_t row_count) {
- if (column->has_null(row_count)) {
- auto* null_map = arg->mutable_null_map();
- null_map->Reserve(row_count);
- const auto* col = check_and_get_column<ColumnUInt8>(null_col);
- auto& data = col->get_data();
- null_map->Add(data.begin(), data.begin() + row_count);
- convert_col_to_pvalue<true>(column, data_type, arg, row_count);
- } else {
- convert_col_to_pvalue<false>(column, data_type, arg, row_count);
- }
-}
-
-void convert_block_to_proto(Block& block, const ColumnNumbers& arguments, size_t input_rows_count,
- PFunctionCallRequest* request) {
- size_t row_count = std::min(block.rows(), input_rows_count);
- for (size_t col_idx : arguments) {
- PValues* arg = request->add_args();
- ColumnWithTypeAndName& column = block.get_by_position(col_idx);
- arg->set_has_null(column.column->has_null(row_count));
- auto col = column.column->convert_to_full_column_if_const();
- if (auto* nullable = check_and_get_column<const ColumnNullable>(*col)) {
- auto data_col = nullable->get_nested_column_ptr();
- auto& null_col = nullable->get_null_map_column();
- auto data_type = std::reinterpret_pointer_cast<const DataTypeNullable>(column.type);
- convert_nullable_col_to_pvalue(data_col->convert_to_full_column_if_const(),
- data_type->get_nested_type(), null_col, arg, row_count);
- } else {
- convert_col_to_pvalue<false>(col, column.type, arg, row_count);
- }
- }
-}
-
-template <bool nullable>
-void convert_to_column(MutableColumnPtr& column, const PValues& result) {
- switch (result.type().id()) {
- case PGenericType::UINT8: {
- column->reserve(result.uint32_value_size());
- column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt8*>(column.get())->get_data();
- for (int i = 0; i < result.uint32_value_size(); ++i) {
- data[i] = result.uint32_value(i);
- }
- break;
- }
- case PGenericType::UINT16: {
- column->reserve(result.uint32_value_size());
- column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt16*>(column.get())->get_data();
- for (int i = 0; i < result.uint32_value_size(); ++i) {
- data[i] = result.uint32_value(i);
- }
- break;
- }
- case PGenericType::UINT32: {
- column->reserve(result.uint32_value_size());
- column->resize(result.uint32_value_size());
- auto& data = reinterpret_cast<ColumnUInt32*>(column.get())->get_data();
- for (int i = 0; i < result.uint32_value_size(); ++i) {
- data[i] = result.uint32_value(i);
- }
- break;
- }
- case PGenericType::UINT64: {
- column->reserve(result.uint64_value_size());
- column->resize(result.uint64_value_size());
- auto& data = reinterpret_cast<ColumnUInt64*>(column.get())->get_data();
- for (int i = 0; i < result.uint64_value_size(); ++i) {
- data[i] = result.uint64_value(i);
- }
- break;
- }
- case PGenericType::INT8: {
- column->reserve(result.int32_value_size());
- column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data();
- for (int i = 0; i < result.int32_value_size(); ++i) {
- data[i] = result.int32_value(i);
- }
- break;
- }
- case PGenericType::INT16: {
- column->reserve(result.int32_value_size());
- column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt16*>(column.get())->get_data();
- for (int i = 0; i < result.int32_value_size(); ++i) {
- data[i] = result.int32_value(i);
- }
- break;
- }
- case PGenericType::INT32: {
- column->reserve(result.int32_value_size());
- column->resize(result.int32_value_size());
- auto& data = reinterpret_cast<ColumnInt32*>(column.get())->get_data();
- for (int i = 0; i < result.int32_value_size(); ++i) {
- data[i] = result.int32_value(i);
- }
- break;
- }
- case PGenericType::INT64: {
- column->reserve(result.int64_value_size());
- column->resize(result.int64_value_size());
- auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data();
- for (int i = 0; i < result.int64_value_size(); ++i) {
- data[i] = result.int64_value(i);
- }
- break;
- }
- case PGenericType::DATE:
- case PGenericType::DATETIME: {
- column->reserve(result.datetime_value_size());
- column->resize(result.datetime_value_size());
- auto& data = reinterpret_cast<ColumnInt64*>(column.get())->get_data();
- for (int i = 0; i < result.datetime_value_size(); ++i) {
- VecDateTimeValue v;
- PDateTime pv = result.datetime_value(i);
- v.set_time(pv.year(), pv.month(), pv.day(), pv.hour(), pv.minute(), pv.minute());
- data[i] = binary_cast<VecDateTimeValue, Int64>(v);
- }
- break;
- }
- case PGenericType::FLOAT: {
- column->reserve(result.float_value_size());
- column->resize(result.float_value_size());
- auto& data = reinterpret_cast<ColumnFloat32*>(column.get())->get_data();
- for (int i = 0; i < result.float_value_size(); ++i) {
- data[i] = result.float_value(i);
- }
- break;
- }
- case PGenericType::DOUBLE: {
- column->reserve(result.double_value_size());
- column->resize(result.double_value_size());
- auto& data = reinterpret_cast<ColumnFloat64*>(column.get())->get_data();
- for (int i = 0; i < result.double_value_size(); ++i) {
- data[i] = result.double_value(i);
- }
- break;
- }
- case PGenericType::INT128: {
- column->reserve(result.bytes_value_size());
- column->resize(result.bytes_value_size());
- auto& data = reinterpret_cast<ColumnInt128*>(column.get())->get_data();
- for (int i = 0; i < result.bytes_value_size(); ++i) {
- data[i] = *(int128_t*)(result.bytes_value(i).c_str());
- }
- break;
- }
- case PGenericType::STRING: {
- column->reserve(result.string_value_size());
- for (int i = 0; i < result.string_value_size(); ++i) {
- column->insert_data(result.string_value(i).c_str(), result.string_value(i).size());
- }
- break;
- }
- case PGenericType::DECIMAL128: {
- column->reserve(result.bytes_value_size());
- column->resize(result.bytes_value_size());
- auto& data = reinterpret_cast<ColumnDecimal128*>(column.get())->get_data();
- for (int i = 0; i < result.bytes_value_size(); ++i) {
- data[i] = *(int128_t*)(result.bytes_value(i).c_str());
- }
- break;
- }
- case PGenericType::BITMAP: {
- column->reserve(result.bytes_value_size());
- for (int i = 0; i < result.bytes_value_size(); ++i) {
- column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size());
- }
- break;
- }
- case PGenericType::HLL: {
- column->reserve(result.bytes_value_size());
- for (int i = 0; i < result.bytes_value_size(); ++i) {
- column->insert_data(result.bytes_value(i).c_str(), result.bytes_value(i).size());
- }
- break;
- }
- default: {
- LOG(WARNING) << "unknown PGenericType: " << result.type().DebugString();
- break;
- }
- }
-}
-
-void convert_to_block(Block& block, const PValues& result, size_t pos) {
- auto data_type = block.get_data_type(pos);
- if (data_type->is_nullable()) {
- auto null_type = std::reinterpret_pointer_cast<const DataTypeNullable>(data_type);
- auto data_col = null_type->get_nested_type()->create_column();
- convert_to_column<true>(data_col, result);
- auto null_col = ColumnUInt8::create(data_col->size(), 0);
- auto& null_map_data = null_col->get_data();
- null_col->reserve(data_col->size());
- null_col->resize(data_col->size());
- if (result.has_null()) {
- for (int i = 0; i < data_col->size(); ++i) {
- null_map_data[i] = result.null_map(i);
- }
- } else {
- for (int i = 0; i < data_col->size(); ++i) {
- null_map_data[i] = false;
- }
- }
- block.replace_by_position(pos,
- ColumnNullable::create(std::move(data_col), std::move(null_col)));
- } else {
- auto column = data_type->create_column();
- convert_to_column<false>(column, result);
- block.replace_by_position(pos, std::move(column));
- }
-}
-
-Status RPCFnCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
- size_t result, size_t input_rows_count, bool dry_run) {
- PFunctionCallRequest request;
- PFunctionCallResponse response;
- request.set_function_name(_symbol);
- convert_block_to_proto(block, arguments, input_rows_count, &request);
- brpc::Controller cntl;
- _client->fn_call(&cntl, &request, &response, nullptr);
- if (cntl.Failed()) {
- return Status::InternalError(
- fmt::format("call to rpc function {} failed: {}", _symbol, cntl.ErrorText())
- .c_str());
- }
- if (!response.has_status() || !response.has_result()) {
- return Status::InternalError(
- fmt::format("call rpc function {} failed: status or result is not set.", _symbol));
- }
- if (response.status().status_code() != 0) {
- return Status::InternalError(fmt::format("call to rpc function {} failed: {}", _symbol,
- response.status().DebugString()));
- }
- convert_to_block(block, response.result(), result);
- return Status::OK();
+// Runs the RPC UDF by forwarding the entire vectorized call to the RPCFn held in _fn
+// (declared as std::unique_ptr<RPCFn> in function_rpc.h). `dry_run` is not used here.
+// NOTE(review): _fn is dereferenced without a null check; this assumes the constructor
+// (not visible here) always initialises it — confirm.
+Status FunctionRPC::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+ size_t result, size_t input_rows_count, bool dry_run) {
+ return _fn->vec_call(context, block, arguments, result, input_rows_count);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/functions/function_rpc.h b/be/src/vec/functions/function_rpc.h
index 43bfe3acc2..a4037958dd 100644
--- a/be/src/vec/functions/function_rpc.h
+++ b/be/src/vec/functions/function_rpc.h
@@ -20,25 +20,28 @@
#include "vec/functions/function.h"
namespace doris {
-class PFunctionService_Stub;
+class RPCFn;
namespace vectorized {
-class RPCFnCall : public IFunctionBase {
+class FunctionRPC : public IFunctionBase {
public:
- RPCFnCall(const std::string& symbol, const std::string& server, const DataTypes& argument_types,
- const DataTypePtr& return_type);
- static FunctionBasePtr create(const std::string& symbol, const std::string& server,
- const ColumnsWithTypeAndName& argument_types,
+ FunctionRPC(const TFunction& fn, const DataTypes& argument_types,
+ const DataTypePtr& return_type);
+
+ static FunctionBasePtr create(const TFunction& fn, const ColumnsWithTypeAndName& argument_types,
const DataTypePtr& return_type) {
DataTypes data_types(argument_types.size());
for (size_t i = 0; i < argument_types.size(); ++i) {
data_types[i] = argument_types[i].type;
}
- return std::make_shared<RPCFnCall>(symbol, server, data_types, return_type);
+ return std::make_shared<FunctionRPC>(fn, data_types, return_type);
}
/// Get the main function name.
- String get_name() const override { return _name; };
+ // Built from the stored TFunction as "<function_name>: [<hdfs_location>/<scalar_fn.symbol>]".
+ // NOTE(review): reads scalar_fn.symbol, which may be unset when the TFunction describes an
+ // aggregate function — confirm this is only used for scalar RPC functions.
+ String get_name() const override {
+ return fmt::format("{}: [{}/{}]", _tfn.name.function_name, _tfn.hdfs_location,
+ _tfn.scalar_fn.symbol);
+ };
const DataTypes& get_argument_types() const override { return _argument_types; };
const DataTypePtr& get_return_type() const override { return _return_type; };
@@ -58,12 +61,10 @@ public:
bool is_deterministic_in_scope_of_query() const override { return false; }
private:
- std::string _symbol;
- std::string _server;
- std::string _name;
DataTypes _argument_types;
DataTypePtr _return_type;
- std::shared_ptr<PFunctionService_Stub> _client = nullptr;
+ TFunction _tfn;
+ std::unique_ptr<RPCFn> _fn;
};
} // namespace vectorized
diff --git a/contrib/udf/CMakeLists.txt b/contrib/udf/CMakeLists.txt
index 6b347f8c6c..66fd4f32cc 100644
--- a/contrib/udf/CMakeLists.txt
+++ b/contrib/udf/CMakeLists.txt
@@ -34,22 +34,6 @@ set(BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}")
set(SRC_DIR "${BASE_DIR}/src/")
set(OUTPUT_DIR "${BASE_DIR}/output")
-# Check gcc
-if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
- if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "7.3.0")
- message(FATAL_ERROR "Need GCC version at least 7.3.0")
- endif()
-
- if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "7.3.0")
- message(STATUS "GCC version is greater than 7.3.0, disable -Werror. Be careful with compile warnings.")
- else()
- # -Werror: compile warnings should be errors when using the toolchain compiler.
- set(CXX_GCC_FLAGS "${CXX_GCC_FLAGS} -Werror")
- endif()
-elseif (NOT APPLE)
- message(FATAL_ERROR "Compiler should be GNU")
-endif()
-
# Just for clang-tidy: -Wno-expansion-to-defined -Wno-deprecated-declaration
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g -ggdb -std=c++11 -Wall -Werror -Wno-unused-variable -Wno-expansion-to-defined -Wno-deprecated-declarations -O3")
diff --git a/docs/en/extending-doris/udf/native-user-defined-function.md b/docs/en/extending-doris/udf/native-user-defined-function.md
index 32311bf57a..c32f17549c 100644
--- a/docs/en/extending-doris/udf/native-user-defined-function.md
+++ b/docs/en/extending-doris/udf/native-user-defined-function.md
@@ -34,8 +34,6 @@ There are two types of analysis requirements that UDF can meet: UDF and UDAF. UD
This document mainly describes how to write a custom UDF function and how to use it in Doris.
-If users use the UDF function and extend Doris' function analysis, and want to contribute their own UDF functions back to the Doris community for other users, please see the document [Contribute UDF](./contribute_udf.md).
-
## Writing UDF functions
Before using UDF, users need to write their own UDF functions under Doris' UDF framework. In the `contrib/udf/src/udf_samples/udf_sample.h|cpp` file is a simple UDF Demo.
diff --git a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
index ca581ddf58..11f6bf9681 100644
--- a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
+++ b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
@@ -79,6 +79,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
> "prepare_fn": Function signature of the prepare function for finding the entry from the dynamic library. This option is optional for custom functions
>
> "close_fn": Function signature of the close function for finding the entry from the dynamic library. This option is optional for custom functions
+>
+> "type": Function type: RPC for a remote UDF, NATIVE for a C++ native UDF. Defaults to NATIVE.
+
This statement creates a custom function. Executing this command requires that the user have `ADMIN` privileges.
@@ -138,6 +140,13 @@ If the `function_name` contains the database name, the custom function will be c
CREATE ALIAS FUNCTION string(ALL, INT) WITH PARAMETER(col, length)
AS CAST(col AS varchar(length));
```
-
+6. Create a remote UDF
+ ```
+ CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES (
+ "SYMBOL"="add_int",
+ "OBJECT_FILE"="127.0.0.1:9999",
+ "TYPE"="RPC"
+ );
+ ```
## keyword
CREATE,FUNCTION
diff --git a/docs/zh-CN/extending-doris/udf/native-user-defined-function.md b/docs/zh-CN/extending-doris/udf/native-user-defined-function.md
index 8f7e56c869..fff1ddbd5d 100644
--- a/docs/zh-CN/extending-doris/udf/native-user-defined-function.md
+++ b/docs/zh-CN/extending-doris/udf/native-user-defined-function.md
@@ -35,8 +35,6 @@ UDF 能满足的分析需求分为两种:UDF 和 UDAF。本文中的 UDF 指
这篇文档主要讲述了,如何编写自定义的 UDF 函数,以及如何在 Doris 中使用它。
-如果用户使用 UDF 功能并扩展了 Doris 的函数分析,并且希望将自己实现的 UDF 函数贡献回 Doris 社区给其他用户使用,这时候请看文档 [Contribute UDF](./contribute_udf.md)。
-
## 编写 UDF 函数
在使用UDF之前,用户需要先在 Doris 的 UDF 框架下,编写自己的UDF函数。在`contrib/udf/src/udf_samples/udf_sample.h|cpp`文件中是一个简单的 UDF Demo。
diff --git a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
index f6b2c7b990..902462664a 100644
--- a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
+++ b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
@@ -79,6 +79,7 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
> "prepare_fn": 自定义函数的prepare函数的函数签名,用于从动态库里面找到prepare函数入口。此选项对于自定义函数是可选项
>
> "close_fn": 自定义函数的close函数的函数签名,用于从动态库里面找到close函数入口。此选项对于自定义函数是可选项
+>
+> "type": 自定义函数的类型,如果是远程函数则填 RPC,C++ 的原生 UDF 填 NATIVE,默认 NATIVE
此语句创建一个自定义函数。执行此命令需要用户拥有 `ADMIN` 权限。
@@ -89,35 +90,35 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
1. 创建一个自定义标量函数
- ```
- CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES (
- "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_",
- "object_file" = "http://host:port/libmyadd.so"
- );
- ```
+ ```
+ CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES (
+ "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_",
+ "object_file" = "http://host:port/libmyadd.so"
+ );
+ ```
2. 创建一个有prepare/close函数的自定义标量函数
- ```
- CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES (
- "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_",
- "prepare_fn" = "_ZN9doris_udf14AddUdf_prepareEPNS_15FunctionContextENS0_18FunctionStateScopeE",
- "close_fn" = "_ZN9doris_udf12AddUdf_closeEPNS_15FunctionContextENS0_18FunctionStateScopeE",
- "object_file" = "http://host:port/libmyadd.so"
- );
- ```
+ ```
+ CREATE FUNCTION my_add(INT, INT) RETURNS INT PROPERTIES (
+ "symbol" = "_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_6IntValES4_",
+ "prepare_fn" = "_ZN9doris_udf14AddUdf_prepareEPNS_15FunctionContextENS0_18FunctionStateScopeE",
+ "close_fn" = "_ZN9doris_udf12AddUdf_closeEPNS_15FunctionContextENS0_18FunctionStateScopeE",
+ "object_file" = "http://host:port/libmyadd.so"
+ );
+ ```
3. 创建一个自定义聚合函数
- ```
- CREATE AGGREGATE FUNCTION my_count (BIGINT) RETURNS BIGINT PROPERTIES (
- "init_fn"="_ZN9doris_udf9CountInitEPNS_15FunctionContextEPNS_9BigIntValE",
- "update_fn"="_ZN9doris_udf11CountUpdateEPNS_15FunctionContextERKNS_6IntValEPNS_9BigIntValE",
- "merge_fn"="_ZN9doris_udf10CountMergeEPNS_15FunctionContextERKNS_9BigIntValEPS2_",
- "finalize_fn"="_ZN9doris_udf13CountFinalizeEPNS_15FunctionContextERKNS_9BigIntValE",
- "object_file"="http://host:port/libudasample.so"
- );
- ```
+ ```
+ CREATE AGGREGATE FUNCTION my_count (BIGINT) RETURNS BIGINT PROPERTIES (
+ "init_fn"="_ZN9doris_udf9CountInitEPNS_15FunctionContextEPNS_9BigIntValE",
+ "update_fn"="_ZN9doris_udf11CountUpdateEPNS_15FunctionContextERKNS_6IntValEPNS_9BigIntValE",
+ "merge_fn"="_ZN9doris_udf10CountMergeEPNS_15FunctionContextERKNS_9BigIntValEPS2_",
+ "finalize_fn"="_ZN9doris_udf13CountFinalizeEPNS_15FunctionContextERKNS_9BigIntValE",
+ "object_file"="http://host:port/libudasample.so"
+ );
+ ```
4. 创建一个变长参数的标量函数
@@ -139,7 +140,14 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
CREATE ALIAS FUNCTION string(ALL, INT) WITH PARAMETER(col, length)
AS CAST(col AS varchar(length));
```
-
+6. 创建一个远程自定义函数
+ ```
+ CREATE FUNCTION rpc_add(INT, INT) RETURNS INT PROPERTIES (
+ "SYMBOL"="add_int",
+ "OBJECT_FILE"="127.0.0.1:9999",
+ "TYPE"="RPC"
+ );
+ ```
## keyword
CREATE,FUNCTION
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
index d2516954f7..2446fb8249 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
@@ -236,9 +236,6 @@ public class CreateFunctionStmt extends DdlStmt {
}
private void analyzeUda() throws AnalysisException {
- if (binaryType == TFunctionBinaryType.RPC) {
- throw new AnalysisException("RPC UDAF is not supported.");
- }
AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder();
builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()).
@@ -255,10 +252,31 @@ public class CreateFunctionStmt extends DdlStmt {
if (mergeFnSymbol == null) {
throw new AnalysisException("No 'merge_fn' in properties");
}
+ String serializeFnSymbol = properties.get(SERIALIZE_KEY);
+ String finalizeFnSymbol = properties.get(FINALIZE_KEY);
+ String getValueFnSymbol = properties.get(GET_VALUE_KEY);
+ String removeFnSymbol = properties.get(REMOVE_KEY);
+ if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
+ checkRPCUdf(initFnSymbol);
+ checkRPCUdf(updateFnSymbol);
+ checkRPCUdf(mergeFnSymbol);
+ if (serializeFnSymbol != null) {
+ checkRPCUdf(serializeFnSymbol);
+ }
+ if (finalizeFnSymbol != null) {
+ checkRPCUdf(finalizeFnSymbol);
+ }
+ if (getValueFnSymbol != null) {
+ checkRPCUdf(getValueFnSymbol);
+ }
+ if (removeFnSymbol != null) {
+ checkRPCUdf(removeFnSymbol);
+ }
+ }
function = builder.initFnSymbol(initFnSymbol)
.updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol)
- .serializeFnSymbol(properties.get(SERIALIZE_KEY)).finalizeFnSymbol(properties.get(FINALIZE_KEY))
- .getValueFnSymbol(properties.get(GET_VALUE_KEY)).removeFnSymbol(properties.get(REMOVE_KEY))
+ .serializeFnSymbol(serializeFnSymbol).finalizeFnSymbol(finalizeFnSymbol)
+ .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol)
.build();
function.setChecksum(checksum);
}
@@ -274,33 +292,9 @@ public class CreateFunctionStmt extends DdlStmt {
// the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) {
if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) {
- throw new AnalysisException(" prepare and close in RPC UDF are not supported.");
- }
- String[] url = userFile.split(":");
- if (url.length != 2) {
- throw new AnalysisException("function server address invalid.");
- }
- String host = url[0];
- int port = Integer.valueOf(url[1]);
- ManagedChannel channel = NettyChannelBuilder.forAddress(host, port)
- .flowControlWindow(Config.grpc_max_message_size_bytes)
- .maxInboundMessageSize(Config.grpc_max_message_size_bytes)
- .enableRetry().maxRetryAttempts(3)
- .usePlaintext().build();
- PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel);
- FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder();
- builder.getFunctionBuilder().setFunctionName(symbol);
- for (Type arg : argsDef.getArgTypes()) {
- builder.getFunctionBuilder().addInputs(convertToPParameterType(arg));
- }
- builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType()));
- FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build());
- if (response == null || !response.hasStatus()) {
- throw new AnalysisException("cannot access function server");
- }
- if (response.getStatus().getStatusCode() != 0) {
- throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus());
+ throw new AnalysisException("prepare and close in RPC UDF are not supported.");
}
+ checkRPCUdf(symbol);
} else if (binaryType == TFunctionBinaryType.JAVA_UDF) {
analyzeJavaUdf(symbol);
}
@@ -399,6 +393,36 @@ public class CreateFunctionStmt extends DdlStmt {
}
}
+ /**
+ * Asks the function server at {@code userFile} (expected as "host:port") whether it exposes
+ * {@code symbol} with this statement's argument types and return type.
+ *
+ * @param symbol function name sent in the PCheckFunctionRequest
+ * @throws AnalysisException if the address is malformed, the server cannot be reached,
+ * or the server reports a non-zero status for the function
+ */
+ private void checkRPCUdf(String symbol) throws AnalysisException {
+ // TODO(yangzhg) support check function in FE when function service behind load balancer
+ // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster
+ String[] url = userFile.split(":");
+ if (url.length != 2 || StringUtils.isBlank(url[0])) {
+ throw new AnalysisException("function server address invalid.");
+ }
+ String host = url[0];
+ int port;
+ try {
+ port = Integer.parseInt(url[1].trim());
+ } catch (NumberFormatException e) {
+ // A non-numeric port must surface as an analysis error, not an unchecked exception.
+ throw new AnalysisException("function server address invalid.");
+ }
+ if (port <= 0 || port > 65535) {
+ throw new AnalysisException("function server address invalid.");
+ }
+ ManagedChannel channel = NettyChannelBuilder.forAddress(host, port)
+ .flowControlWindow(Config.grpc_max_message_size_bytes)
+ .maxInboundMessageSize(Config.grpc_max_message_size_bytes)
+ .enableRetry().maxRetryAttempts(3)
+ .usePlaintext().build();
+ try {
+ PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel);
+ FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder();
+ builder.getFunctionBuilder().setFunctionName(symbol);
+ for (Type arg : argsDef.getArgTypes()) {
+ builder.getFunctionBuilder().addInputs(convertToPParameterType(arg));
+ }
+ builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.getType()));
+ FunctionService.PCheckFunctionResponse response;
+ try {
+ response = stub.checkFn(builder.build());
+ } catch (io.grpc.StatusRuntimeException e) {
+ // Blocking stubs report transport failures (unreachable server, timeout) this way.
+ throw new AnalysisException("cannot access function server: " + e.getStatus());
+ }
+ if (response == null || !response.hasStatus()) {
+ throw new AnalysisException("cannot access function server");
+ }
+ if (response.getStatus().getStatusCode() != 0) {
+ throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus());
+ }
+ } finally {
+ // A fresh channel is built per check (a UDAF checks up to seven symbols), so release it
+ // here instead of leaking one channel per call.
+ channel.shutdown();
+ }
+ }
+
private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException {
Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder();
switch (arg.getPrimitiveType()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index 8f35805523..1820fba43b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -392,6 +392,10 @@ public class AggregateFunction extends Function {
this.removeFnSymbol = symbol;
return this;
}
+ /**
+ * Stores the function's binary type (e.g. TFunctionBinaryType.RPC) on this builder.
+ *
+ * @param binaryType binary type to record
+ * @return this builder, for chaining
+ */
+ public AggregateFunctionBuilder binaryType(TFunctionBinaryType binaryType) {
+ this.binaryType = binaryType;
+ return this;
+ }
public AggregateFunction build() {
AggregateFunction fn = new AggregateFunction(name, argTypes, retType, hasVarArgs, intermediateType,
diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java b/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java
index b0c0f19f4d..4f4409e2a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/common/util/URI.java
@@ -201,4 +201,9 @@ public class URI {
throw new AnalysisException("Invalid host port: " + hostPort);
}
}
+
+ /** Returns the {@code location} string held by this URI. */
+ @Override
+ public String toString() {
+ return location;
+ }
}
diff --git a/gensrc/proto/function_service.proto b/gensrc/proto/function_service.proto
index 561be9f887..298957b556 100644
--- a/gensrc/proto/function_service.proto
+++ b/gensrc/proto/function_service.proto
@@ -35,8 +35,9 @@ message PFunctionCallRequest {
}
message PFunctionCallResponse {
- optional PValues result = 1;
+ // Values returned by the function server. Now repeated, so one response may carry several
+ // PValues (the demo servers in this change each add exactly one via add_result/addResult/append).
+ repeated PValues result = 1;
optional PStatus status = 2;
+ // NOTE(review): presumably returns per-call request state to the caller — confirm against PRequestContext.
+ optional PRequestContext context = 3;
}
message PCheckFunctionRequest {
diff --git a/gensrc/proto/types.proto b/gensrc/proto/types.proto
index e1f8445620..a3ac3dc190 100644
--- a/gensrc/proto/types.proto
+++ b/gensrc/proto/types.proto
@@ -193,17 +193,7 @@ message PFunction {
}
message PFunctionContext {
- optional string version = 1 [default = "V2_0"];
- repeated PValue staging_input_vals = 2;
- repeated PValue constant_args = 3;
- optional string error_msg = 4;
- optional PUniqueId query_id = 5;
- optional bytes thread_local_fn_state = 6;
- optional bytes fragment_local_fn_state = 7;
- optional string string_result = 8;
- optional int64 num_updates = 9;
- optional int64 num_removes = 10;
- optional int64 num_warnings = 11;
+ // Tags 1-11 belonged to the removed V2 context fields. Reserve them and give `data` a fresh
+ // tag, so a peer still built from the old schema (which sends `version` = "V2_0" on tag 1)
+ // is never decoded as `data`.
+ reserved 1 to 11;
+ // NOTE(review): contents of `data` are defined by the RPC function implementation — confirm.
+ optional bytes data = 12;
}
message PHandShakeRequest {
diff --git a/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp b/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp
index 3e0193f1ed..7e141394d0 100644
--- a/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp
+++ b/samples/doris-demo/remote-udf-cpp-demo/cpp_function_service_demo.cpp
@@ -32,20 +32,21 @@ public:
::google::protobuf::Closure* done) override {
brpc::ClosureGuard closure_guard(done);
std::string fun_name = request->function_name();
+ auto* result = response->add_result();
if (fun_name == "int32_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::INT32);
+ result->mutable_type()->set_id(PGenericType::INT32);
for (size_t i = 0; i < request->args(0).int32_value_size(); ++i) {
- response->mutable_result()->add_int32_value(request->args(0).int32_value(i) +
- request->args(1).int32_value(i));
+ result->add_int32_value(request->args(0).int32_value(i) +
+ request->args(1).int32_value(i));
}
} else if (fun_name == "int64_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::INT64);
+ result->mutable_type()->set_id(PGenericType::INT64);
for (size_t i = 0; i < request->args(0).int64_value_size(); ++i) {
- response->mutable_result()->add_int64_value(request->args(0).int64_value(i) +
- request->args(1).int64_value(i));
+ result->add_int64_value(request->args(0).int64_value(i) +
+ request->args(1).int64_value(i));
}
} else if (fun_name == "int128_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::INT128);
+ result->mutable_type()->set_id(PGenericType::INT128);
for (size_t i = 0; i < request->args(0).bytes_value_size(); ++i) {
__int128 v1;
memcpy(&v1, request->args(0).bytes_value(i).data(), sizeof(__int128));
@@ -54,26 +55,25 @@ public:
__int128 v = v1 + v2;
char buffer[sizeof(__int128)];
memcpy(buffer, &v, sizeof(__int128));
- response->mutable_result()->add_bytes_value(buffer, sizeof(__int128));
+ result->add_bytes_value(buffer, sizeof(__int128));
}
} else if (fun_name == "float_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::FLOAT);
+ result->mutable_type()->set_id(PGenericType::FLOAT);
for (size_t i = 0; i < request->args(0).float_value_size(); ++i) {
- response->mutable_result()->add_float_value(request->args(0).float_value(i) +
- request->args(1).float_value(i));
+ result->add_float_value(request->args(0).float_value(i) +
+ request->args(1).float_value(i));
}
} else if (fun_name == "double_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::DOUBLE);
+ result->mutable_type()->set_id(PGenericType::DOUBLE);
for (size_t i = 0; i < request->args(0).double_value_size(); ++i) {
- response->mutable_result()->add_double_value(request->args(0).double_value(i) +
- request->args(1).double_value(i));
+ result->add_double_value(request->args(0).double_value(i) +
+ request->args(1).double_value(i));
}
} else if (fun_name == "str_add") {
- response->mutable_result()->mutable_type()->set_id(PGenericType::STRING);
+ result->mutable_type()->set_id(PGenericType::STRING);
for (size_t i = 0; i < request->args(0).string_value_size(); ++i) {
- response->mutable_result()->add_string_value(request->args(0).string_value(i) +
- " + " +
- request->args(1).string_value(i));
+ result->add_string_value(request->args(0).string_value(i) + " + " +
+ request->args(1).string_value(i));
}
}
response->mutable_status()->set_status_code(0);
diff --git a/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java b/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java
index 40558193ad..fded80505f 100644
--- a/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java
+++ b/samples/doris-demo/remote-udf-java-demo/src/main/java/org/apache/doris/udf/FunctionServiceImpl.java
@@ -45,7 +45,7 @@ public class FunctionServiceImpl extends PFunctionServiceGrpc.PFunctionServiceIm
if ("add_int".equals(functionName)) {
res = FunctionService.PFunctionCallResponse.newBuilder()
.setStatus(Types.PStatus.newBuilder().setStatusCode(0).build())
- .setResult(Types.PValues.newBuilder().setHasNull(false)
+ .addResult(Types.PValues.newBuilder().setHasNull(false)
.addAllInt32Value(IntStream.range(0, Math.min(request.getArgs(0)
.getInt32ValueCount(), request.getArgs(1).getInt32ValueCount()))
.mapToObj(i -> request.getArgs(0).getInt32Value(i) + request.getArgs(1)
diff --git a/samples/doris-demo/remote-udf-python-demo/function_server_demo.py b/samples/doris-demo/remote-udf-python-demo/function_server_demo.py
index 60d6d939c9..d1f2160013 100644
--- a/samples/doris-demo/remote-udf-python-demo/function_server_demo.py
+++ b/samples/doris-demo/remote-udf-python-demo/function_server_demo.py
@@ -43,7 +43,7 @@ class FunctionServerDemo(function_service_pb2_grpc.PFunctionServiceServicer):
result_type.id = types_pb2.PGenericType.INT32
result.type.CopyFrom(result_type)
result.int32_value.extend([x + y for x, y in zip(request.args[0].int32_value, request.args[1].int32_value)])
- response.result.CopyFrom(result)
+ response.result.append(result)
return response
def check_fn(self, request, context):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org