You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@impala.apache.org by st...@apache.org on 2021/12/15 08:38:39 UTC

[impala] 03/04: IMPALA-10956: datasketches UDFs: memory leak and merge overhead

This is an automated email from the ASF dual-hosted git repository.

stigahuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 151dce86ca8810252c0331cb321b6e77e09327cb
Author: AlexanderSaydakov <al...@apache.org>
AuthorDate: Fri Sep 24 14:05:43 2021 -0700

    IMPALA-10956: datasketches UDFs: memory leak and merge overhead
    
    - call destructors of sketch and union objects
    - avoid overhead of constructing union and getting result from it every time
    
    Change-Id: I8dd0e6736f4266f74f5f265f58d40a4e4707287f
    Reviewed-on: http://gerrit.cloudera.org:8080/17869
    Reviewed-by: Impala Public Jenkins <im...@cloudera.com>
    Tested-by: Impala Public Jenkins <im...@cloudera.com>
---
 be/src/exprs/aggregate-functions-ir.cc | 468 +++++++++++++++++++--------------
 1 file changed, 273 insertions(+), 195 deletions(-)

diff --git a/be/src/exprs/aggregate-functions-ir.cc b/be/src/exprs/aggregate-functions-ir.cc
index 45998e4..a658e39 100644
--- a/be/src/exprs/aggregate-functions-ir.cc
+++ b/be/src/exprs/aggregate-functions-ir.cc
@@ -1674,18 +1674,20 @@ StringVal SerializeDsThetaIntersection(
   return StringVal::null();
 }
 
+// This is for functions with different state during update and merge.
+enum agg_phase { UPDATE, MERGE };
+using agg_state = std::pair<agg_phase, void*>;
+
 void AggregateFunctions::DsHllInit(FunctionContext* ctx, StringVal* dst) {
-  AllocBuffer(ctx, dst, sizeof(datasketches::hll_sketch));
+  AllocBuffer(ctx, dst, sizeof(agg_state));
   if (UNLIKELY(dst->is_null)) {
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  // Note, that hll_sketch will always have the same size regardless of the amount of data
-  // it keeps track. This is because it's a wrapper class that holds all the inserted data
-  // on heap. Here, we put only the wrapper class into a StringVal.
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(dst->ptr);
-  *sketch_ptr = datasketches::hll_sketch(DS_SKETCH_CONFIG, DS_HLL_TYPE);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  agg_state_ptr->first = agg_phase::UPDATE;
+  agg_state_ptr->second = new (ctx->Allocate<datasketches::hll_sketch>())
+      datasketches::hll_sketch(DS_SKETCH_CONFIG, DS_HLL_TYPE);
 }
 
 template <typename T>
@@ -1693,9 +1695,10 @@ void AggregateFunctions::DsHllUpdate(FunctionContext* ctx, const T& src,
     StringVal* dst) {
   if (src.is_null) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK_EQ(agg_state_ptr->first, agg_phase::UPDATE);
+  auto sketch_ptr = reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(src.val);
 }
 
@@ -1705,19 +1708,29 @@ void AggregateFunctions::DsHllUpdate(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   if (src.is_null || src.len == 0) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK_EQ(agg_state_ptr->first, agg_phase::UPDATE);
+  auto sketch_ptr = reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(reinterpret_cast<char*>(src.ptr), src.len);
 }
 
 StringVal AggregateFunctions::DsHllSerialize(FunctionContext* ctx,
     const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(src.ptr);
-  StringVal dst = SerializeCompactDsHllSketch(ctx, *sketch_ptr);
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  StringVal dst;
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
+    dst = SerializeCompactDsHllSketch(ctx, *sketch_ptr);
+    sketch_ptr->~hll_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::hll_union*>(agg_state_ptr->second);
+    dst = SerializeDsHllUnion(ctx, *union_ptr);
+    union_ptr->~hll_union_alloc();
+  }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return dst;
 }
@@ -1726,26 +1739,44 @@ void AggregateFunctions::DsHllMerge(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   DCHECK(!src.is_null);
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch src_sketch =
-      datasketches::hll_sketch::deserialize((void*)src.ptr, src.len);
-
-  datasketches::hll_sketch* dst_sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  if (agg_state_ptr->first == agg_phase::MERGE) { // was already switched to union
+    auto dst_union_ptr =
+        reinterpret_cast<datasketches::hll_union*>(agg_state_ptr->second);
+    dst_union_ptr->update(datasketches::hll_sketch::deserialize(src.ptr, src.len));
+  } else { // must be the first call. the state is still a sketch
+    auto dst_sketch_ptr =
+        reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
 
-  datasketches::hll_union union_sketch(DS_SKETCH_CONFIG);
-  union_sketch.update(src_sketch);
-  union_sketch.update(*dst_sketch_ptr);
+    datasketches::hll_union u(DS_SKETCH_CONFIG);
+    u.update(*dst_sketch_ptr);
+    u.update(datasketches::hll_sketch::deserialize(src.ptr, src.len));
 
-  *dst_sketch_ptr = union_sketch.get_result(DS_HLL_TYPE);
+    // switch to union
+    dst_sketch_ptr->~hll_sketch_alloc();
+    ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
+    agg_state_ptr->second = new (ctx->Allocate<datasketches::hll_union>())
+        datasketches::hll_union(std::move(u));
+    agg_state_ptr->first = agg_phase::MERGE;
+  }
 }
 
 BigIntVal AggregateFunctions::DsHllFinalize(FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(src.ptr);
-  BigIntVal estimate = sketch_ptr->get_estimate();
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  BigIntVal estimate;
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
+    estimate = sketch_ptr->get_estimate();
+    sketch_ptr->~hll_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::hll_union*>(agg_state_ptr->second);
+    estimate = union_ptr->get_result().get_estimate();
+    union_ptr->~hll_union_alloc();
+  }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return (estimate == 0) ? BigIntVal::null() : estimate;
 }
@@ -1753,15 +1784,26 @@ BigIntVal AggregateFunctions::DsHllFinalize(FunctionContext* ctx, const StringVa
 StringVal AggregateFunctions::DsHllFinalizeSketch(FunctionContext* ctx,
     const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::hll_sketch));
-  datasketches::hll_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::hll_sketch*>(src.ptr);
-  StringVal result_str = StringVal::null();
-  if (sketch_ptr->get_estimate() > 0.0) {
-    result_str = SerializeCompactDsHllSketch(ctx, *sketch_ptr);
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  StringVal result = StringVal::null();
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::hll_sketch*>(agg_state_ptr->second);
+    if (!sketch_ptr->is_empty()) {
+      result = SerializeCompactDsHllSketch(ctx, *sketch_ptr);
+    }
+    sketch_ptr->~hll_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::hll_union*>(agg_state_ptr->second);
+    auto sketch = union_ptr->get_result(DS_HLL_TYPE);
+    if (!sketch.is_empty()) {
+      result = SerializeCompactDsHllSketch(ctx, sketch);
+    }
+    union_ptr->~hll_union_alloc();
   }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
-  return result_str;
+  return result;
 }
 
 void AggregateFunctions::DsHllUnionInit(FunctionContext* ctx, StringVal* slot) {
@@ -1770,9 +1812,8 @@ void AggregateFunctions::DsHllUnionInit(FunctionContext* ctx, StringVal* slot) {
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  datasketches::hll_union* union_ptr =
-      reinterpret_cast<datasketches::hll_union*>(slot->ptr);
-  *union_ptr = datasketches::hll_union(DS_SKETCH_CONFIG);
+  auto union_ptr = reinterpret_cast<datasketches::hll_union*>(slot->ptr);
+  new (union_ptr) datasketches::hll_union(DS_SKETCH_CONFIG);
 }
 
 void AggregateFunctions::DsHllUnionUpdate(FunctionContext* ctx, const StringVal& src,
@@ -1795,6 +1836,7 @@ StringVal AggregateFunctions::DsHllUnionSerialize(FunctionContext* ctx,
   datasketches::hll_union* union_ptr =
       reinterpret_cast<datasketches::hll_union*>(src.ptr);
   StringVal dst = SerializeDsHllUnion(ctx, *union_ptr);
+  union_ptr->~hll_union_alloc();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -1817,39 +1859,37 @@ StringVal AggregateFunctions::DsHllUnionFinalize(FunctionContext* ctx,
     const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::hll_union));
-  datasketches::hll_union* union_ptr =
-      reinterpret_cast<datasketches::hll_union*>(src.ptr);
-  datasketches::hll_sketch sketch = union_ptr->get_result(DS_HLL_TYPE);
-  if (sketch.get_estimate() == 0.0) {
-    ctx->Free(src.ptr);
-    return StringVal::null();
+  auto union_ptr = reinterpret_cast<datasketches::hll_union*>(src.ptr);
+  auto sketch = union_ptr->get_result(DS_HLL_TYPE);
+  StringVal result = StringVal::null();
+  if (!sketch.is_empty()) {
+    result = SerializeCompactDsHllSketch(ctx, sketch);
   }
-  StringVal result = SerializeCompactDsHllSketch(ctx, sketch);
+  union_ptr->~hll_union_alloc();
   ctx->Free(src.ptr);
   return result;
 }
 
 void AggregateFunctions::DsCpcInit(FunctionContext* ctx, StringVal* dst) {
-  AllocBuffer(ctx, dst, sizeof(datasketches::cpc_sketch));
+  AllocBuffer(ctx, dst, sizeof(agg_state));
   if (UNLIKELY(dst->is_null)) {
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  // Note, that cpc_sketch will always have the same size regardless of the amount of data
-  // it keeps track. This is because it's a wrapper class that holds all the inserted data
-  // on heap. Here, we put only the wrapper class into a StringVal.
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(dst->ptr);
-  *sketch_ptr = datasketches::cpc_sketch(DS_CPC_SKETCH_CONFIG);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  agg_state_ptr->first = agg_phase::UPDATE;
+  agg_state_ptr->second = new (ctx->Allocate<datasketches::cpc_sketch>())
+      datasketches::cpc_sketch(DS_CPC_SKETCH_CONFIG);
 }
 
 template <typename T>
 void AggregateFunctions::DsCpcUpdate(FunctionContext* ctx, const T& src, StringVal* dst) {
   if (src.is_null) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::cpc_sketch));
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK_EQ(agg_state_ptr->first, agg_phase::UPDATE);
+  auto sketch_ptr = reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(src.val);
 }
 
@@ -1859,18 +1899,28 @@ void AggregateFunctions::DsCpcUpdate(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   if (src.is_null || src.len == 0) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::cpc_sketch));
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK(agg_state_ptr->first == agg_phase::UPDATE);
+  auto sketch_ptr = reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(reinterpret_cast<char*>(src.ptr), src.len);
 }
 
 StringVal AggregateFunctions::DsCpcSerialize(FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::cpc_sketch));
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(src.ptr);
-  StringVal dst = SerializeDsSketch(ctx, *sketch_ptr);
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  StringVal dst;
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
+    dst = SerializeDsSketch(ctx, *sketch_ptr);
+    sketch_ptr->~cpc_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(agg_state_ptr->second);
+    dst = SerializeDsCpcUnion(ctx, *union_ptr);
+    union_ptr->~cpc_union_alloc();
+  }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return dst;
 }
@@ -1879,27 +1929,49 @@ void AggregateFunctions::DsCpcMerge(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   DCHECK(!src.is_null);
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::cpc_sketch));
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  if (agg_state_ptr->first == agg_phase::MERGE) { // was already switched to union
+    auto dst_union_ptr =
+        reinterpret_cast<datasketches::cpc_union*>(agg_state_ptr->second);
+    dst_union_ptr->update(datasketches::cpc_sketch::deserialize(src.ptr, src.len));
+  } else { // must be the first call. the state is still a sketch
+    auto dst_sketch_ptr =
+        reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
+
+    datasketches::cpc_union u(DS_CPC_SKETCH_CONFIG);
+    u.update(*dst_sketch_ptr);
+    try {
+      u.update(datasketches::cpc_sketch::deserialize(src.ptr, src.len));
+    } catch (const std::exception& e) {
+      LogSketchDeserializationError(ctx, e);
+      return;
+    }
 
-  datasketches::cpc_union union_sketch(DS_CPC_SKETCH_CONFIG);
-  try {
-    union_sketch.update(datasketches::cpc_sketch::deserialize(src.ptr, src.len));
-  } catch (const std::exception& e) {
-    LogSketchDeserializationError(ctx, e);
-    return;
+    // switch to union
+    dst_sketch_ptr->~cpc_sketch_alloc();
+    ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
+    agg_state_ptr->second = new (ctx->Allocate<datasketches::cpc_union>())
+        datasketches::cpc_union(std::move(u));
+    agg_state_ptr->first = agg_phase::MERGE;
   }
-  datasketches::cpc_sketch* dst_sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(dst->ptr);
-  union_sketch.update(*dst_sketch_ptr);
-  *dst_sketch_ptr = union_sketch.get_result();
 }
 
 BigIntVal AggregateFunctions::DsCpcFinalize(FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::cpc_sketch));
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(src.ptr);
-  BigIntVal estimate = sketch_ptr->get_estimate();
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  BigIntVal estimate;
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
+    estimate = sketch_ptr->get_estimate();
+    sketch_ptr->~cpc_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(agg_state_ptr->second);
+    estimate = union_ptr->get_result().get_estimate();
+    union_ptr->~cpc_union_alloc();
+  }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return (estimate == 0) ? BigIntVal::null() : estimate;
 }
@@ -1907,15 +1979,26 @@ BigIntVal AggregateFunctions::DsCpcFinalize(FunctionContext* ctx, const StringVa
 StringVal AggregateFunctions::DsCpcFinalizeSketch(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK_EQ(src.len, sizeof(datasketches::cpc_sketch));
-  datasketches::cpc_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::cpc_sketch*>(src.ptr);
-  StringVal result_str = StringVal::null();
-  if (sketch_ptr->get_estimate() > 0.0) {
-    result_str = SerializeDsSketch(ctx, *sketch_ptr);
+  DCHECK_EQ(src.len, sizeof(agg_state));
+  StringVal result = StringVal::null();
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr = reinterpret_cast<datasketches::cpc_sketch*>(agg_state_ptr->second);
+    if (!sketch_ptr->is_empty()) {
+      result = SerializeDsSketch(ctx, *sketch_ptr);
+    }
+    sketch_ptr->~cpc_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(agg_state_ptr->second);
+    auto sketch = union_ptr->get_result();
+    if (!sketch.is_empty()) {
+      result = SerializeDsSketch(ctx, sketch);
+    }
+    union_ptr->~cpc_union_alloc();
   }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
-  return result_str;
+  return result;
 }
 
 void AggregateFunctions::DsCpcUnionInit(FunctionContext* ctx, StringVal* slot) {
@@ -1924,9 +2007,8 @@ void AggregateFunctions::DsCpcUnionInit(FunctionContext* ctx, StringVal* slot) {
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  datasketches::cpc_union* union_ptr =
-      reinterpret_cast<datasketches::cpc_union*>(slot->ptr);
-  *union_ptr = datasketches::cpc_union(DS_SKETCH_CONFIG);
+  auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(slot->ptr);
+  new (union_ptr) datasketches::cpc_union(DS_CPC_SKETCH_CONFIG);
 }
 
 void AggregateFunctions::DsCpcUnionUpdate(
@@ -1946,9 +2028,9 @@ StringVal AggregateFunctions::DsCpcUnionSerialize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::cpc_union));
-  datasketches::cpc_union* union_ptr =
-      reinterpret_cast<datasketches::cpc_union*>(src.ptr);
+  auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(src.ptr);
   StringVal dst = SerializeDsCpcUnion(ctx, *union_ptr);
+  union_ptr->~cpc_union_alloc();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -1971,32 +2053,28 @@ StringVal AggregateFunctions::DsCpcUnionFinalize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::cpc_union));
-  datasketches::cpc_union* union_ptr =
-      reinterpret_cast<datasketches::cpc_union*>(src.ptr);
-  datasketches::cpc_sketch sketch = union_ptr->get_result();
-  if (sketch.get_estimate() == 0.0) {
-    ctx->Free(src.ptr);
-    return StringVal::null();
+  auto union_ptr = reinterpret_cast<datasketches::cpc_union*>(src.ptr);
+  auto sketch = union_ptr->get_result();
+  StringVal result = StringVal::null();
+  if (!sketch.is_empty()) {
+    result = SerializeDsSketch(ctx, sketch);
   }
-  StringVal result = SerializeDsSketch(ctx, sketch);
+  union_ptr->~cpc_union_alloc();
   ctx->Free(src.ptr);
   return result;
 }
 
 void AggregateFunctions::DsThetaInit(FunctionContext* ctx, StringVal* dst) {
-  AllocBuffer(ctx, dst, sizeof(datasketches::update_theta_sketch));
+  AllocBuffer(ctx, dst, sizeof(agg_state));
   if (UNLIKELY(dst->is_null)) {
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  // Note, that update_theta_sketch will always have the same size regardless of the
-  // amount of data it keeps track. This is because it's a wrapper class that holds all
-  // the inserted data on heap. Here, we put only the wrapper class into a StringVal.
-  datasketches::update_theta_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::update_theta_sketch*>(dst->ptr);
-  datasketches::update_theta_sketch sketch =
-      datasketches::update_theta_sketch::builder().build();
-  std::uninitialized_fill_n(sketch_ptr, 1, sketch);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  agg_state_ptr->first = agg_phase::UPDATE;
+  agg_state_ptr->second = new (ctx->Allocate<datasketches::update_theta_sketch>())
+      datasketches::update_theta_sketch(
+          datasketches::update_theta_sketch::builder().build());
 }
 
 template <typename T>
@@ -2004,9 +2082,11 @@ void AggregateFunctions::DsThetaUpdate(
     FunctionContext* ctx, const T& src, StringVal* dst) {
   if (src.is_null) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::update_theta_sketch));
-  datasketches::update_theta_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::update_theta_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK_EQ(agg_state_ptr->first, agg_phase::UPDATE);
+  auto sketch_ptr =
+      reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(src.val);
 }
 
@@ -2016,25 +2096,31 @@ void AggregateFunctions::DsThetaUpdate(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   if (src.is_null || src.len == 0) return;
   DCHECK(!dst->is_null);
-  DCHECK_EQ(dst->len, sizeof(datasketches::update_theta_sketch));
-  datasketches::update_theta_sketch* sketch_ptr =
-      reinterpret_cast<datasketches::update_theta_sketch*>(dst->ptr);
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
+  DCHECK_EQ(agg_state_ptr->first, agg_phase::UPDATE);
+  auto sketch_ptr =
+      reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
   sketch_ptr->update(reinterpret_cast<char*>(src.ptr), src.len);
 }
 
 StringVal AggregateFunctions::DsThetaSerialize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK(src.len == sizeof(datasketches::update_theta_sketch)
-      || src.len == sizeof(datasketches::theta_union));
+  DCHECK_EQ(src.len, sizeof(agg_state));
   StringVal dst;
-  if (src.len == sizeof(datasketches::update_theta_sketch)) {
-    auto sketch_ptr = reinterpret_cast<datasketches::update_theta_sketch*>(src.ptr);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr =
+        reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
     dst = SerializeDsSketch(ctx, sketch_ptr->compact());
-  } else {
-    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(src.ptr);
+    sketch_ptr->~update_theta_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(agg_state_ptr->second);
     dst = SerializeDsThetaUnion(ctx, *union_ptr);
+    union_ptr->~theta_union_alloc();
   }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return dst;
 }
@@ -2043,12 +2129,13 @@ void AggregateFunctions::DsThetaMerge(
     FunctionContext* ctx, const StringVal& src, StringVal* dst) {
   DCHECK(!src.is_null);
   DCHECK(!dst->is_null);
-  DCHECK(dst->len == sizeof(datasketches::update_theta_sketch)
-      or dst->len == sizeof(datasketches::theta_union));
+  DCHECK_EQ(dst->len, sizeof(agg_state));
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(dst->ptr);
 
   // Note, 'src' is a serialized compact_theta_sketch.
-  if (dst->len == sizeof(datasketches::theta_union)) {
-    auto dst_union_ptr = reinterpret_cast<datasketches::theta_union*>(dst->ptr);
+  if (agg_state_ptr->first == agg_phase::MERGE) { // was already switched to union
+    auto dst_union_ptr =
+        reinterpret_cast<datasketches::theta_union*>(agg_state_ptr->second);
     try {
       dst_union_ptr->update(datasketches::compact_theta_sketch::deserialize(src.ptr,
           src.len));
@@ -2056,47 +2143,45 @@ void AggregateFunctions::DsThetaMerge(
       LogSketchDeserializationError(ctx, e);
       return;
     }
-  } else if (dst->len == sizeof(datasketches::update_theta_sketch)) {
-    auto dst_sketch_ptr = reinterpret_cast<datasketches::update_theta_sketch*>(dst->ptr);
+  } else { // must be the first call. the state is still a sketch
+    auto dst_sketch_ptr =
+        reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
 
-    datasketches::theta_union union_sketch = datasketches::theta_union::builder().build();
-    union_sketch.update(*dst_sketch_ptr);
+    auto u = datasketches::theta_union::builder().build();
+    u.update(*dst_sketch_ptr);
     try {
-      union_sketch.update(datasketches::compact_theta_sketch::deserialize(src.ptr,
-          src.len));
+      u.update(datasketches::compact_theta_sketch::deserialize(src.ptr, src.len));
     } catch (const std::exception& e) {
       LogSketchDeserializationError(ctx, e);
       return;
     }
 
-    // theta_union.get_result() returns a compact sketch, does not support updating, and
-    // is inconsistent with the initial underlying type of dst. This is different from
-    // the HLL sketch. Here use theta_union as the underlying type of dst.
-    ctx->Free(dst->ptr);
-    AllocBuffer(ctx, dst, sizeof(datasketches::theta_union));
-    if (UNLIKELY(dst->is_null)) {
-      DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
-      return;
-    }
-    datasketches::theta_union* union_ptr =
-        reinterpret_cast<datasketches::theta_union*>(dst->ptr);
-    std::uninitialized_fill_n(union_ptr, 1, union_sketch);
+    // switch to union
+    dst_sketch_ptr->~update_theta_sketch_alloc();
+    ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
+    agg_state_ptr->second = new (ctx->Allocate<datasketches::theta_union>())
+        datasketches::theta_union(std::move(u));
+    agg_state_ptr->first = agg_phase::MERGE;
   }
 }
 
 BigIntVal AggregateFunctions::DsThetaFinalize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK(src.len == sizeof(datasketches::update_theta_sketch)
-      or src.len == sizeof(datasketches::theta_union));
+  DCHECK_EQ(src.len, sizeof(agg_state));
   BigIntVal estimate;
-  if (src.len == sizeof(datasketches::update_theta_sketch)) {
-    auto sketch_ptr = reinterpret_cast<datasketches::update_theta_sketch*>(src.ptr);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr =
+        reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
     estimate = sketch_ptr->get_estimate();
-  } else {
-    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(src.ptr);
+    sketch_ptr->~update_theta_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(agg_state_ptr->second);
     estimate = union_ptr->get_result().get_estimate();
+    union_ptr->~theta_union_alloc();
   }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return estimate;
 }
@@ -2104,16 +2189,20 @@ BigIntVal AggregateFunctions::DsThetaFinalize(
 StringVal AggregateFunctions::DsThetaFinalizeSketch(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
-  DCHECK(src.len == sizeof(datasketches::update_theta_sketch)
-      or src.len == sizeof(datasketches::theta_union));
+  DCHECK_EQ(src.len, sizeof(agg_state));
   StringVal result;
-  if (src.len == sizeof(datasketches::update_theta_sketch)) {
-    auto sketch_ptr = reinterpret_cast<datasketches::update_theta_sketch*>(src.ptr);
+  auto agg_state_ptr = reinterpret_cast<agg_state*>(src.ptr);
+  if (agg_state_ptr->first == agg_phase::UPDATE) { // the agg state is a sketch
+    auto sketch_ptr =
+        reinterpret_cast<datasketches::update_theta_sketch*>(agg_state_ptr->second);
     result = SerializeDsSketch(ctx, sketch_ptr->compact());
-  } else {
-    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(src.ptr);
+    sketch_ptr->~update_theta_sketch_alloc();
+  } else { // the agg state is a union
+    auto union_ptr = reinterpret_cast<datasketches::theta_union*>(agg_state_ptr->second);
     result = SerializeDsThetaUnion(ctx, *union_ptr);
+    union_ptr->~theta_union_alloc();
   }
+  ctx->Free(reinterpret_cast<uint8_t*>(agg_state_ptr->second));
   ctx->Free(src.ptr);
   return result;
 }
@@ -2124,10 +2213,8 @@ void AggregateFunctions::DsThetaUnionInit(FunctionContext* ctx, StringVal* dst)
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  datasketches::theta_union* union_ptr =
-      reinterpret_cast<datasketches::theta_union*>(dst->ptr);
-  datasketches::theta_union union_sketch = datasketches::theta_union::builder().build();
-  std::uninitialized_fill_n(union_ptr, 1, union_sketch);
+  auto union_ptr = reinterpret_cast<datasketches::theta_union*>(dst->ptr);
+  new (union_ptr) datasketches::theta_union(datasketches::theta_union::builder().build());
 }
 
 void AggregateFunctions::DsThetaUnionUpdate(
@@ -2147,9 +2234,9 @@ StringVal AggregateFunctions::DsThetaUnionSerialize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::theta_union));
-  datasketches::theta_union* union_ptr =
-      reinterpret_cast<datasketches::theta_union*>(src.ptr);
+  auto union_ptr = reinterpret_cast<datasketches::theta_union*>(src.ptr);
   StringVal dst = SerializeDsThetaUnion(ctx, *union_ptr);
+  union_ptr->~theta_union_alloc();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -2172,14 +2259,13 @@ StringVal AggregateFunctions::DsThetaUnionFinalize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::theta_union));
-  datasketches::theta_union* union_ptr =
-      reinterpret_cast<datasketches::theta_union*>(src.ptr);
+  auto union_ptr = reinterpret_cast<datasketches::theta_union*>(src.ptr);
   auto sketch = union_ptr->get_result();
-  if (sketch.is_empty()) {
-    ctx->Free(src.ptr);
-    return StringVal::null();
+  StringVal result = StringVal::null();
+  if (!sketch.is_empty()) {
+    result = SerializeDsSketch(ctx, sketch);
   }
-  StringVal result = SerializeDsSketch(ctx, sketch);
+  union_ptr->~theta_union_alloc();
   ctx->Free(src.ptr);
   return result;
 }
@@ -2190,9 +2276,8 @@ void AggregateFunctions::DsThetaIntersectInit(FunctionContext* ctx, StringVal* s
     DCHECK(!ctx->impl()->state()->GetQueryStatus().ok());
     return;
   }
-  datasketches::theta_intersection* intersection_ptr =
-      reinterpret_cast<datasketches::theta_intersection*>(slot->ptr);
-  *intersection_ptr = datasketches::theta_intersection();
+  auto intersection_ptr = reinterpret_cast<datasketches::theta_intersection*>(slot->ptr);
+  new (intersection_ptr) datasketches::theta_intersection();
 }
 
 void AggregateFunctions::DsThetaIntersectUpdate(
@@ -2212,9 +2297,9 @@ StringVal AggregateFunctions::DsThetaIntersectSerialize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::theta_intersection));
-  datasketches::theta_intersection* intersection_ptr =
-      reinterpret_cast<datasketches::theta_intersection*>(src.ptr);
+  auto intersection_ptr = reinterpret_cast<datasketches::theta_intersection*>(src.ptr);
   StringVal dst = SerializeDsThetaIntersection(ctx, *intersection_ptr);
+  intersection_ptr->~theta_intersection_alloc();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -2238,14 +2323,12 @@ StringVal AggregateFunctions::DsThetaIntersectFinalize(
     FunctionContext* ctx, const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::theta_intersection));
-  datasketches::theta_intersection* intersection_ptr =
-      reinterpret_cast<datasketches::theta_intersection*>(src.ptr);
-  if (!intersection_ptr->has_result()) {
-    ctx->Free(src.ptr);
-    return StringVal::null();
+  auto intersection_ptr = reinterpret_cast<datasketches::theta_intersection*>(src.ptr);
+  StringVal result = StringVal::null();
+  if (intersection_ptr->has_result()) {
+    result = SerializeDsSketch(ctx, intersection_ptr->get_result());
   }
-  auto sketch = intersection_ptr->get_result();
-  StringVal result = SerializeDsSketch(ctx, sketch);
+  intersection_ptr->~theta_intersection_alloc();
   ctx->Free(src.ptr);
   return result;
 }
@@ -2259,18 +2342,17 @@ void AggregateFunctions::DsKllInitHelper(FunctionContext* ctx, StringVal* slot)
   // Note, that kll_sketch will always have the same size regardless of the amount of
   // data it keeps track of. This is because it's a wrapper class that holds all the
   // inserted data on heap. Here, we put only the wrapper class into a StringVal.
-  datasketches::kll_sketch<float>* sketch_ptr =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(slot->ptr);
-  *sketch_ptr = datasketches::kll_sketch<float>();
+  auto sketch_ptr = reinterpret_cast<datasketches::kll_sketch<float>*>(slot->ptr);
+  new (sketch_ptr) datasketches::kll_sketch<float>();
 }
 
 StringVal AggregateFunctions::DsKllSerializeHelper(FunctionContext* ctx,
     const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::kll_sketch<float>));
-  datasketches::kll_sketch<float>* sketch_ptr =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(src.ptr);
+  auto sketch_ptr = reinterpret_cast<datasketches::kll_sketch<float>*>(src.ptr);
   StringVal dst = SerializeDsSketch(ctx, *sketch_ptr);
+  sketch_ptr->~kll_sketch();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -2281,8 +2363,7 @@ void AggregateFunctions::DsKllMergeHelper(FunctionContext* ctx, const StringVal&
   DCHECK(!dst->is_null);
   DCHECK_EQ(dst->len, sizeof(datasketches::kll_sketch<float>));
 
-  datasketches::kll_sketch<float>* dst_sketch_ptr =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
+  auto dst_sketch_ptr = reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
   try {
     dst_sketch_ptr->merge(datasketches::kll_sketch<float>::deserialize(src.ptr, src.len));
   } catch (const std::exception& e) {
@@ -2296,13 +2377,12 @@ StringVal AggregateFunctions::DsKllFinalizeHelper(FunctionContext* ctx,
     const StringVal& src) {
   DCHECK(!src.is_null);
   DCHECK_EQ(src.len, sizeof(datasketches::kll_sketch<float>));
-  datasketches::kll_sketch<float>* sketch_ptr =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(src.ptr);
-  if (sketch_ptr->get_n() == 0) {
-    ctx->Free(src.ptr);
-    return StringVal::null();
+  auto sketch_ptr = reinterpret_cast<datasketches::kll_sketch<float>*>(src.ptr);
+  StringVal dst = StringVal::null();
+  if (!sketch_ptr->is_empty()) {
+    dst = SerializeDsSketch(ctx, *sketch_ptr);
   }
-  StringVal dst = SerializeDsSketch(ctx, *sketch_ptr);
+  sketch_ptr->~kll_sketch();
   ctx->Free(src.ptr);
   return dst;
 }
@@ -2316,8 +2396,7 @@ void AggregateFunctions::DsKllUpdate(FunctionContext* ctx, const FloatVal& src,
   if (src.is_null) return;
   DCHECK(!dst->is_null);
   DCHECK_EQ(dst->len, sizeof(datasketches::kll_sketch<float>));
-  datasketches::kll_sketch<float>* sketch_ptr =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
+  auto sketch_ptr = reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
   sketch_ptr->update(src.val);
 }
 
@@ -2348,8 +2427,7 @@ void AggregateFunctions::DsKllUnionUpdate(FunctionContext* ctx, const StringVal&
   if (src.is_null) return;
   DCHECK(!dst->is_null);
   DCHECK_EQ(dst->len, sizeof(datasketches::kll_sketch<float>));
-  datasketches::kll_sketch<float>* dst_sketch =
-      reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
+  auto dst_sketch = reinterpret_cast<datasketches::kll_sketch<float>*>(dst->ptr);
   try {
     dst_sketch->merge(datasketches::kll_sketch<float>::deserialize(src.ptr, src.len));
   } catch (const std::exception& e) {