You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@datasketches.apache.org by jm...@apache.org on 2020/02/14 00:00:23 UTC
[incubator-datasketches-cpp] branch sampling updated: [WIP] add var
opt union tests, fixing bugs uncovered along the way
This is an automated email from the ASF dual-hosted git repository.
jmalkin pushed a commit to branch sampling
in repository https://gitbox.apache.org/repos/asf/incubator-datasketches-cpp.git
The following commit(s) were added to refs/heads/sampling by this push:
new 2f719b8 [WIP] add var opt union tests, fixing bugs uncovered along the way
2f719b8 is described below
commit 2f719b8a891e6465fb5fe0f390ccd886daf4df5d
Author: Jon Malkin <jm...@users.noreply.github.com>
AuthorDate: Thu Feb 13 16:00:03 2020 -0800
[WIP] add var opt union tests, fixing bugs uncovered along the way
---
sampling/include/var_opt_sketch.hpp | 30 +--
sampling/include/var_opt_sketch_impl.hpp | 58 +++--
sampling/include/var_opt_union.hpp | 27 +-
sampling/include/var_opt_union_impl.hpp | 261 ++++++++++++++++++-
sampling/test/CMakeLists.txt | 3 +-
sampling/test/var_opt_sketch_test.cpp | 1 -
sampling/test/var_opt_union_test.cpp | 418 +++++++++++++++++++++++++++++++
7 files changed, 736 insertions(+), 62 deletions(-)
diff --git a/sampling/include/var_opt_sketch.hpp b/sampling/include/var_opt_sketch.hpp
index 79cfd48..7230537 100644
--- a/sampling/include/var_opt_sketch.hpp
+++ b/sampling/include/var_opt_sketch.hpp
@@ -106,7 +106,7 @@ class var_opt_sketch {
static const uint8_t PREAMBLE_LONGS_WARMUP = 3;
static const uint8_t PREAMBLE_LONGS_FULL = 4;
static const uint8_t SER_VER = 2;
- static const uint8_t FAMILY = 13;
+ static const uint8_t FAMILY_ID = 13;
static const uint8_t EMPTY_FLAG_MASK = 4;
static const uint8_t GADGET_FLAG_MASK = 128;
@@ -206,34 +206,6 @@ class var_opt_sketch {
static double next_double_exclude_zero();
};
-/*
-template<typename T, typename S, typename A>
-class var_opt_sketch<T, S, A>::const_iterator: public std::iterator<std::input_iterator_tag, T> {
-public:
- friend class var_opt_sketch<T, S, A>;
- const_iterator(const const_iterator& other);
- const_iterator& operator++();
- const_iterator& operator++(int);
- bool operator==(const const_iterator& other) const;
- bool operator!=(const const_iterator& other) const;
- const std::pair<const T&, const double> operator*() const;
-private:
- const T* items;
- const double* weights;
- const bool* marks;
- const uint32_t h_count;
- const uint32_t r_count;
- const double total_wt_r;
- const double r_item_wt;
- double cum_weight; // used for weight correction in R
- uint32_t final_idx;
- uint32_t index;
- const bool get_mark() const;
- const_iterator(const T* items, const double* weights, const bool* marks_,
- const uint32_t h_count, const uint32_t r_count, const double total_wt_r, bool use_end=false);
-};
-*/
-
template<typename T, typename S, typename A>
class var_opt_sketch<T, S, A>::const_iterator: public std::iterator<std::input_iterator_tag, T> {
public:
diff --git a/sampling/include/var_opt_sketch_impl.hpp b/sampling/include/var_opt_sketch_impl.hpp
index 2e83671..4fea686 100644
--- a/sampling/include/var_opt_sketch_impl.hpp
+++ b/sampling/include/var_opt_sketch_impl.hpp
@@ -59,9 +59,20 @@ var_opt_sketch<T,S,A>::var_opt_sketch(const var_opt_sketch& other) :
marks_(nullptr)
{
data_ = A().allocate(curr_items_alloc_);
- std::copy(&other.data_[0], &other.data_[curr_items_alloc_], data_);
+ if (other.filled_data_) {
+ // copy everything
+ for (size_t i = 0; i < curr_items_alloc_; ++i)
+ data_[i] = other.data_[i];
+ } else {
+ // skip gap or anything unused at the end
+ for (size_t i = 0; i < h_; ++i)
+ data_[i] = other.data_[i];
+ for (size_t i = h_ + 1; i < h_ + r_ + 1; ++i)
+ data_[i] = other.data_[i];
+ }
weights_ = AllocDouble().allocate(curr_items_alloc_);
+ // doubles so can successfully copy regardless of the internal state
std::copy(&other.weights_[0], &other.weights_[curr_items_alloc_], weights_);
if (other.marks_ != nullptr) {
@@ -87,9 +98,20 @@ var_opt_sketch<T,S,A>::var_opt_sketch(const var_opt_sketch& other, bool as_sketc
marks_(nullptr)
{
data_ = A().allocate(curr_items_alloc_);
- std::copy(&other.data_[0], &other.data_[curr_items_alloc_], data_);
+ if (other.filled_data_) {
+ // copy everything
+ for (size_t i = 0; i < curr_items_alloc_; ++i)
+ data_[i] = other.data_[i];
+ } else {
+ // skip gap or anything unused at the end
+ for (size_t i = 0; i < h_; ++i)
+ data_[i] = other.data_[i];
+ for (size_t i = h_ + 1; i < h_ + r_ + 1; ++i)
+ data_[i] = other.data_[i];
+ }
weights_ = AllocDouble().allocate(curr_items_alloc_);
+ // doubles so can successfully copy regardless of the internal state
std::copy(&other.weights_[0], &other.weights_[curr_items_alloc_], weights_);
if (!as_sketch && other.marks_ != nullptr) {
@@ -240,6 +262,8 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
"Found r = " + std::to_string(r_) + ", R region weight = " + std::to_string(total_wt_r_));
}
+ } else {
+ total_wt_r_ = 0.0;
}
allocate_data_arrays(curr_items_alloc_, is_gadget);
@@ -255,15 +279,15 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0);
// read the first h_ marks as packed bytes iff we have a gadget
+ num_marks_in_h_ = 0;
if (is_gadget) {
- uint32_t num_bytes = (h_ >> 3) + ((h_ & 0x7) > 0 ? 1 : 0);
-
uint8_t val = 0;
- for (int i = 0; i < num_bytes; ++i) {
- if ((i & 0x7) == 0x0) { // should trigger on first iteration
+ for (int i = 0; i < h_; ++i) {
+ if ((i & 0x7) == 0x0) { // should trigger on first iteration
is.read((char*)&val, sizeof(val));
}
marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1;
+ num_marks_in_h_ += (marks_[i] ? 1 : 0);
}
}
@@ -292,6 +316,8 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
"Found r = " + std::to_string(r_) + ", R region weight = " + std::to_string(total_wt_r_));
}
+ } else {
+ total_wt_r_ = 0.0;
}
allocate_data_arrays(curr_items_alloc_, is_gadget);
@@ -304,17 +330,17 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
}
}
std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0);
-
+
// read the first h_ marks as packed bytes iff we have a gadget
+ num_marks_in_h_ = 0;
if (is_gadget) {
- uint32_t num_bytes = (h_ >> 3) + ((h_ & 0x7) > 0 ? 1 : 0);
-
uint8_t val = 0;
- for (int i = 0; i < num_bytes; ++i) {
+ for (int i = 0; i < h_; ++i) {
if ((i & 0x7) == 0x0) { // should trigger on first iteration
ptr += copy_from_mem(ptr, &val, sizeof(val));
}
marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1;
+ num_marks_in_h_ += (marks_[i] ? 1 : 0);
}
}
@@ -410,7 +436,7 @@ std::vector<uint8_t, AllocU8<A>> var_opt_sketch<T,S,A>::serialize(unsigned heade
// first prelong
uint8_t ser_ver(SER_VER);
- uint8_t family(FAMILY);
+ uint8_t family(FAMILY_ID);
ptr += copy_to_mem(&first_byte, ptr, sizeof(uint8_t));
ptr += copy_to_mem(&ser_ver, ptr, sizeof(uint8_t));
ptr += copy_to_mem(&family, ptr, sizeof(uint8_t));
@@ -479,7 +505,7 @@ void var_opt_sketch<T,S,A>::serialize(std::ostream& os) const {
// first prelong
uint8_t ser_ver(SER_VER);
- uint8_t family(FAMILY);
+ uint8_t family(FAMILY_ID);
os.write((char*)&first_byte, sizeof(uint8_t));
os.write((char*)&ser_ver, sizeof(uint8_t));
os.write((char*)&family, sizeof(uint8_t));
@@ -900,7 +926,7 @@ void var_opt_sketch<T,S,A>::grow_data_arrays() {
for (int i = 0; i < prev_size; ++i) {
A().construct(&tmp_data[i], std::move(data_[i]));
A().destroy(data_ + i);
- tmp_weights[i] = std::move(weights_[i]); // primitive double, but for consistency
+ tmp_weights[i] = weights_[i];
}
A().deallocate(data_, prev_size);
@@ -912,7 +938,7 @@ void var_opt_sketch<T,S,A>::grow_data_arrays() {
if (marks_ != nullptr) {
bool* tmp_marks = AllocBool().allocate(curr_items_alloc_);
for (int i = 0; i < prev_size; ++i) {
- tmp_marks[i] = std::move(marks_ + i); // primitive bool, again for consisntency
+ tmp_marks[i] = marks_[i];
}
AllocBool().deallocate(marks_, prev_size);
marks_ = std::move(tmp_marks);
@@ -1236,7 +1262,7 @@ void var_opt_sketch<T,S,A>::check_preamble_longs(uint8_t preamble_longs, uint8_t
template<typename T, typename S, typename A>
void var_opt_sketch<T,S,A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
- if (family_id == FAMILY) {
+ if (family_id == FAMILY_ID) {
if (ser_ver != SER_VER) {
throw std::invalid_argument("Possible corruption: VarOpt serialization version must be "
+ std::to_string(SER_VER) + ". Found: " + std::to_string(ser_ver));
@@ -1246,7 +1272,7 @@ void var_opt_sketch<T,S,A>::check_family_and_serialization_version(uint8_t famil
// TODO: extend to handle reservoir sampling
throw std::invalid_argument("Possible corruption: VarOpt family id must be "
- + std::to_string(FAMILY) + ". Found: " + std::to_string(family_id));
+ + std::to_string(FAMILY_ID) + ". Found: " + std::to_string(family_id));
}
template<typename T, typename S, typename A>
diff --git a/sampling/include/var_opt_union.hpp b/sampling/include/var_opt_union.hpp
index a1a24ff..240e119 100644
--- a/sampling/include/var_opt_union.hpp
+++ b/sampling/include/var_opt_union.hpp
@@ -38,11 +38,13 @@ template <typename T, typename S = serde<T>, typename A = std::allocator<T>>
class var_opt_union {
public:
+ static const uint32_t MAX_K = ((uint32_t) 1 << 31) - 2;
+
explicit var_opt_union(uint32_t max_k);
var_opt_union(const var_opt_union& other);
var_opt_union(var_opt_union&& other) noexcept;
- //static var_opt_union deserialize(std::istream& is);
- //static var_opt_union deserialize(const void* bytes, size_t size);
+ static var_opt_union deserialize(std::istream& is);
+ static var_opt_union deserialize(const void* bytes, size_t size);
~var_opt_union();
@@ -50,7 +52,7 @@ public:
var_opt_union& operator=(var_opt_union&& other);
void update(var_opt_sketch<T,S,A>& sk);
- //void update(var_opt_sketch<T,S,A>>& sk);
+ //void update(var_opt_sketch<T,S,A>&& sk);
void reset();
@@ -61,8 +63,9 @@ public:
*/
var_opt_sketch<T,S,A> get_result() const;
- //std::vector<uint8_t, AllocU8<A>> serialize(unsigned header_size_bytes = 0) const;
- //void serialize(std::ostream& os) const;
+ size_t get_serialized_size_bytes() const;
+ void serialize(std::ostream& os) const;
+ std::vector<uint8_t, AllocU8<A>> serialize(unsigned header_size_bytes = 0) const;
std::ostream& to_stream(std::ostream& os) const;
std::string to_string() const;
@@ -71,6 +74,12 @@ public:
private:
typedef typename std::allocator_traits<A>::template rebind_alloc<var_opt_sketch<T,S,A>> AllocSketch;
+ static const uint8_t PREAMBLE_LONGS_EMPTY = 1;
+ static const uint8_t PREAMBLE_LONGS_NON_EMPTY = 4;
+ static const uint8_t SER_VER = 2;
+ static const uint8_t FAMILY_ID = 14;
+ static const uint8_t EMPTY_FLAG_MASK = 4;
+
uint64_t n_; // cumulative over all input sketches
// outer tau is the largest tau of any input sketch
@@ -79,10 +88,13 @@ private:
// total cardinality of the same R-zones, or zero if no input sketch was in estimation mode
uint64_t outer_tau_denom_;
- const uint32_t max_k_;
+ uint32_t max_k_;
var_opt_sketch<T,S,A> gadget_;
+ var_opt_union(uint64_t n, double outer_tau_numer, uint64_t outer_tau_denom,
+ uint32_t max_k, var_opt_sketch<T,S,A>&& gadget);
+
/*
IMPORTANT NOTE: the "gadget" in the union object appears to be a varopt sketch,
but in fact is NOT because it doesn't satisfy the mathematical definition
@@ -147,6 +159,9 @@ private:
bool detect_and_handle_subcase_of_pseudo_exact(var_opt_sketch<T,S,A>& sk) const;
void mark_moving_gadget_coercer(var_opt_sketch<T,S,A>& sk) const;
void migrate_marked_items_by_decreasing_k(var_opt_sketch<T,S,A>& sk) const;
+
+ static void check_preamble_longs(uint8_t preamble_longs, uint8_t flags);
+ static void check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver);
};
}
diff --git a/sampling/include/var_opt_union_impl.hpp b/sampling/include/var_opt_union_impl.hpp
index 77845a0..dd87ca8 100644
--- a/sampling/include/var_opt_union_impl.hpp
+++ b/sampling/include/var_opt_union_impl.hpp
@@ -42,7 +42,7 @@ var_opt_union<T,S,A>::var_opt_union(const var_opt_union& other) :
outer_tau_numer_(other.outer_tau_numer_),
outer_tau_denom_(other.outer_tau_denom_),
max_k_(other.max_k_),
- gadget_(other.gadget)
+ gadget_(other.gadget_)
{}
template<typename T, typename S, typename A>
@@ -50,10 +50,19 @@ var_opt_union<T,S,A>::var_opt_union(var_opt_union&& other) noexcept :
n_(other.n_),
outer_tau_numer_(other.outer_tau_numer_),
outer_tau_denom_(other.outer_tau_denom_),
- max_k_(other.max_k_)
-{
- gadget_ = std::move(other.gadget_);
-}
+ max_k_(other.max_k_),
+ gadget_(std::move(other.gadget_))
+{}
+
+/**
+ * Private constructor used by deserialize() to assemble a union from already-decoded state.
+ *
+ * @param n cumulative item count over all input sketches
+ * @param outer_tau_numer numerator of the outer tau value
+ * @param outer_tau_denom denominator of the outer tau value
+ * @param max_k maximum sample size configured for the union
+ * @param gadget internal gadget sketch; taken by rvalue reference and moved from
+ */
+template<typename T, typename S, typename A>
+var_opt_union<T,S,A>::var_opt_union(uint64_t n, double outer_tau_numer, uint64_t outer_tau_denom,
+                                    uint32_t max_k, var_opt_sketch<T,S,A>&& gadget) :
+  n_(n),
+  outer_tau_numer_(outer_tau_numer),
+  outer_tau_denom_(outer_tau_denom),
+  max_k_(max_k),
+  // a named rvalue reference is an lvalue: without std::move this would copy the gadget
+  gadget_(std::move(gadget))
+{}
template<typename T, typename S, typename A>
var_opt_union<T,S,A>::~var_opt_union() {}
@@ -63,7 +72,7 @@ var_opt_union<T,S,A>& var_opt_union<T,S,A>::operator=(const var_opt_union& other
var_opt_union<T,S,A> union_copy(other);
std::swap(n_, union_copy.n_);
std::swap(outer_tau_numer_, union_copy.outer_tau_numer_);
- std::swap(outer_tau_numer_, union_copy.outer_tau_denom_);
+ std::swap(outer_tau_denom_, union_copy.outer_tau_denom_);
std::swap(max_k_, union_copy.max_k_);
std::swap(gadget_, union_copy.gadget_);
return *this;
@@ -73,16 +82,216 @@ template<typename T, typename S, typename A>
var_opt_union<T,S,A>& var_opt_union<T,S,A>::operator=(var_opt_union&& other) {
std::swap(n_, other.n_);
std::swap(outer_tau_numer_, other.outer_tau_numer_);
- std::swap(outer_tau_numer_, other.outer_tau_denom_);
+ std::swap(outer_tau_denom_, other.outer_tau_denom_);
std::swap(max_k_, other.max_k_);
std::swap(gadget_, other.gadget_);
return *this;
}
+/*
+ * An empty union requires 8 bytes.
+ *
+ * <pre>
+ * Long || Start Byte Adr:
+ * Adr:
+ * || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+ * 0 || Preamble_Longs | SerVer | FamID | Flags |---------Max Res. Size (K)---------|
+ * </pre>
+ *
+ * A non-empty union always requires 32 bytes (4 longs) of preamble, as laid out in the second
+ * diagram below; the serialized gadget sketch follows it.
+ *
+ * The count of items seen is limited to 48 bits (~256 trillion) even though there are adjacent
+ * unused preamble bits. The acceptance probability for an item is a double in the range [0,1),
+ * limiting us to 53 bits of randomness due to details of the IEEE floating point format. To
+ * ensure meaningful probabilities as the items seen count approaches capacity, we intentionally
+ * use slightly fewer bits.
+ *
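+ * Following the union header is the serialized gadget, written by var_opt_sketch::serialize().
+ * Per the sketch format it holds weights for the heavy items, then marks in the event this is a
+ * gadget, with the serialized items last.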
+ *
+ * <pre>
+ * Long || Start Byte Adr:
+ * Adr:
+ * || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+ * 0 || Preamble_Longs | SerVer | FamID | Flags |---------Max Res. Size (K)---------|
+ *
+ * || 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
+ * 1 ||---------------------------Items Seen Count (N)--------------------------------|
+ *
+ * || 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
+ * 2 ||------------------------Outer Tau Numerator (double)---------------------------|
+ *
+ * || 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
+ * 3 ||----------------------Outer Tau Denominator (uint64_t)-------------------------|
+ * </pre>
+ */
+
+/**
+ * Reconstructs a union from a stream written by serialize(std::ostream&).
+ *
+ * @param is stream positioned at the first byte of the union preamble
+ * @return the deserialized union
+ * @throws std::invalid_argument if the stream ends early, or if the preamble longs, family id,
+ *         serialization version or max_k value are invalid
+ */
+template<typename T, typename S, typename A>
+var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(std::istream& is) {
+  // zero-initialize so a short read can never leave indeterminate values behind
+  uint8_t preamble_longs = 0;
+  is.read((char*)&preamble_longs, sizeof(preamble_longs));
+  uint8_t serial_version = 0;
+  is.read((char*)&serial_version, sizeof(serial_version));
+  uint8_t family_id = 0;
+  is.read((char*)&family_id, sizeof(family_id));
+  uint8_t flags = 0;
+  is.read((char*)&flags, sizeof(flags));
+  uint32_t max_k = 0;
+  is.read((char*)&max_k, sizeof(max_k));
+
+  // validate only data that was actually read
+  if (is.fail()) {
+    throw std::invalid_argument("Possible corruption: unable to read the first preamble long from stream");
+  }
+
+  check_preamble_longs(preamble_longs, flags);
+  check_family_and_serialization_version(family_id, serial_version);
+
+  if (max_k == 0 || max_k > MAX_K) {
+    throw std::invalid_argument("k must be at least 1 and less than 2^31 - 1");
+  }
+
+  bool is_empty = flags & EMPTY_FLAG_MASK;
+
+  // an empty union carries nothing beyond the first preamble long
+  if (is_empty) {
+    return var_opt_union<T,S,A>(max_k);
+  }
+
+  uint64_t items_seen = 0;
+  is.read((char*)&items_seen, sizeof(items_seen));
+  double outer_tau_numer = 0.0;
+  is.read((char*)&outer_tau_numer, sizeof(outer_tau_numer));
+  uint64_t outer_tau_denom = 0;
+  is.read((char*)&outer_tau_denom, sizeof(outer_tau_denom));
+
+  if (is.fail()) {
+    throw std::invalid_argument("Possible corruption: unable to read the full union preamble from stream");
+  }
+
+  // the gadget image follows the union preamble in the same stream
+  var_opt_sketch<T,S,A> gadget = var_opt_sketch<T,S,A>::deserialize(is);
+
+  return var_opt_union<T,S,A>(items_seen, outer_tau_numer, outer_tau_denom, max_k, std::move(gadget));
+}
+
+/**
+ * Reconstructs a union from a byte image produced by serialize(unsigned), without any
+ * caller-supplied header.
+ *
+ * @param bytes pointer to the first byte of the union preamble
+ * @param size number of readable bytes at bytes
+ * @return the deserialized union
+ * @throws std::invalid_argument if the buffer is null or too short, or if the preamble longs,
+ *         family id, serialization version or max_k value are invalid
+ */
+template<typename T, typename S, typename A>
+var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(const void* bytes, size_t size) {
+  // every image, even an empty one, holds at least the first preamble long
+  const size_t min_bytes = PREAMBLE_LONGS_EMPTY << 3;
+  if (bytes == nullptr || size < min_bytes) {
+    throw std::invalid_argument("Possible corruption: at least " + std::to_string(min_bytes)
+      + " bytes are required to deserialize a union. Found: " + std::to_string(size));
+  }
+
+  const char* ptr = static_cast<const char*>(bytes);
+  uint8_t preamble_longs;
+  ptr += copy_from_mem(ptr, &preamble_longs, sizeof(preamble_longs));
+  uint8_t serial_version;
+  ptr += copy_from_mem(ptr, &serial_version, sizeof(serial_version));
+  uint8_t family_id;
+  ptr += copy_from_mem(ptr, &family_id, sizeof(family_id));
+  uint8_t flags;
+  ptr += copy_from_mem(ptr, &flags, sizeof(flags));
+  uint32_t max_k;
+  ptr += copy_from_mem(ptr, &max_k, sizeof(max_k));
+
+  check_preamble_longs(preamble_longs, flags);
+  check_family_and_serialization_version(family_id, serial_version);
+
+  if (max_k == 0 || max_k > MAX_K) {
+    throw std::invalid_argument("k must be at least 1 and less than 2^31 - 1");
+  }
+
+  bool is_empty = flags & EMPTY_FLAG_MASK;
+
+  if (is_empty) {
+    return var_opt_union<T,S,A>(max_k);
+  }
+
+  // a non-empty image must hold the whole union preamble; checking here also keeps the
+  // unsigned subtraction below from wrapping around
+  const size_t preamble_bytes = PREAMBLE_LONGS_NON_EMPTY << 3;
+  if (size < preamble_bytes) {
+    throw std::invalid_argument("Possible corruption: at least " + std::to_string(preamble_bytes)
+      + " bytes are required for a non-empty union. Found: " + std::to_string(size));
+  }
+
+  uint64_t items_seen;
+  ptr += copy_from_mem(ptr, &items_seen, sizeof(items_seen));
+  double outer_tau_numer;
+  ptr += copy_from_mem(ptr, &outer_tau_numer, sizeof(outer_tau_numer));
+  uint64_t outer_tau_denom;
+  ptr += copy_from_mem(ptr, &outer_tau_denom, sizeof(outer_tau_denom));
+
+  // whatever remains after the union preamble belongs to the gadget
+  const size_t gadget_size = size - preamble_bytes;
+  var_opt_sketch<T,S,A> gadget = var_opt_sketch<T,S,A>::deserialize(ptr, gadget_size);
+
+  return var_opt_union<T,S,A>(items_seen, outer_tau_numer, outer_tau_denom, max_k, std::move(gadget));
+}
+
+/**
+ * Computes the number of bytes the union occupies when serialized, not counting any extra
+ * header space a caller requests from serialize().
+ *
+ * @return the single preamble long for an empty union (n_ == 0), otherwise the full union
+ *         preamble plus the serialized size of the gadget
+ */
+template<typename T, typename S, typename A>
+size_t var_opt_union<T,S,A>::get_serialized_size_bytes() const {
+  // nothing seen yet: only the first preamble long is written
+  if (n_ == 0) {
+    return PREAMBLE_LONGS_EMPTY << 3;
+  }
+
+  const auto gadget_bytes = gadget_.get_serialized_size_bytes();
+  return (PREAMBLE_LONGS_NON_EMPTY << 3) + gadget_bytes;
+}
+
+/**
+ * Writes the union to a stream: the first preamble long always, and for a non-empty union
+ * the item count, outer tau numerator and denominator followed by the serialized gadget.
+ *
+ * @param os destination stream
+ */
+template<typename T, typename S, typename A>
+void var_opt_union<T,S,A>::serialize(std::ostream& os) const {
+  const bool is_empty = (n_ == 0);
+
+  // start from the non-empty layout and downgrade when nothing has been seen
+  uint8_t prelongs = PREAMBLE_LONGS_NON_EMPTY;
+  uint8_t flag_bits = 0;
+  if (is_empty) {
+    prelongs = PREAMBLE_LONGS_EMPTY;
+    flag_bits = EMPTY_FLAG_MASK;
+  }
+  const uint8_t ser_ver = SER_VER;
+  const uint8_t family = FAMILY_ID;
+
+  // first preamble long: preamble longs | serial version | family | flags | max_k
+  os.write(reinterpret_cast<const char*>(&prelongs), sizeof(prelongs));
+  os.write(reinterpret_cast<const char*>(&ser_ver), sizeof(ser_ver));
+  os.write(reinterpret_cast<const char*>(&family), sizeof(family));
+  os.write(reinterpret_cast<const char*>(&flag_bits), sizeof(flag_bits));
+  os.write(reinterpret_cast<const char*>(&max_k_), sizeof(max_k_));
+
+  if (is_empty) {
+    return;
+  }
+
+  // remaining preamble longs, then the gadget image
+  os.write(reinterpret_cast<const char*>(&n_), sizeof(n_));
+  os.write(reinterpret_cast<const char*>(&outer_tau_numer_), sizeof(outer_tau_numer_));
+  os.write(reinterpret_cast<const char*>(&outer_tau_denom_), sizeof(outer_tau_denom_));
+  gadget_.serialize(os);
+}
+
+/**
+ * Serializes the union into a byte vector, leaving header_size_bytes of untouched space at
+ * the front for the caller.
+ *
+ * @param header_size_bytes number of bytes to reserve ahead of the union image
+ * @return vector holding the reserved header space followed by the serialized union
+ */
+template<typename T, typename S, typename A>
+std::vector<uint8_t, AllocU8<A>> var_opt_union<T,S,A>::serialize(unsigned header_size_bytes) const {
+  const size_t total_bytes = header_size_bytes + get_serialized_size_bytes();
+  std::vector<uint8_t, AllocU8<A>> bytes(total_bytes);
+  uint8_t* ptr = bytes.data() + header_size_bytes; // skip the caller's header region
+
+  const bool is_empty = (n_ == 0);
+
+  // start from the non-empty layout and downgrade when nothing has been seen
+  uint8_t prelongs = PREAMBLE_LONGS_NON_EMPTY;
+  uint8_t flag_bits = 0;
+  if (is_empty) {
+    prelongs = PREAMBLE_LONGS_EMPTY;
+    flag_bits = EMPTY_FLAG_MASK;
+  }
+  const uint8_t ser_ver = SER_VER;
+  const uint8_t family = FAMILY_ID;
+
+  // first prelong
+  ptr += copy_to_mem(&prelongs, ptr, sizeof(prelongs));
+  ptr += copy_to_mem(&ser_ver, ptr, sizeof(ser_ver));
+  ptr += copy_to_mem(&family, ptr, sizeof(family));
+  ptr += copy_to_mem(&flag_bits, ptr, sizeof(flag_bits));
+  ptr += copy_to_mem(&max_k_, ptr, sizeof(max_k_));
+
+  if (is_empty) {
+    return bytes;
+  }
+
+  // remaining preamble longs
+  ptr += copy_to_mem(&n_, ptr, sizeof(n_));
+  ptr += copy_to_mem(&outer_tau_numer_, ptr, sizeof(outer_tau_numer_));
+  ptr += copy_to_mem(&outer_tau_denom_, ptr, sizeof(outer_tau_denom_));
+
+  // append the gadget's own serialized image
+  auto gadget_image = gadget_.serialize();
+  ptr += copy_to_mem(gadget_image.data(), ptr, gadget_image.size() * sizeof(uint8_t));
+
+  return bytes;
+}
+
template<typename T, typename S, typename A>
void var_opt_union<T,S,A>::reset() {
- if (gadget_ != nullptr)
- gadget_.reset();
+ n_ = 0;
+ outer_tau_numer_ = 0.0;
+ outer_tau_denom_ = 0;
+ gadget_.reset();
}
template<typename T, typename S, typename A>
@@ -104,6 +313,7 @@ std::string var_opt_union<T,S,A>::to_string() const {
return ss.str();
}
+
template<typename T, typename S, typename A>
void var_opt_union<T,S,A>::update(var_opt_sketch<T,S,A>& sk) {
merge_into(sk);
@@ -361,6 +571,39 @@ void var_opt_union<T,S,A>::migrate_marked_items_by_decreasing_k(var_opt_sketch<T
gcopy.strip_marks();
}
+/**
+ * Verifies that the preamble-longs byte agrees with the empty flag.
+ *
+ * @param preamble_longs value read from the first byte of the image
+ * @param flags flags byte; only EMPTY_FLAG_MASK is examined
+ * @throws std::invalid_argument if the value does not match the expected preamble size
+ */
+template<typename T, typename S, typename A>
+void var_opt_union<T,S,A>::check_preamble_longs(uint8_t preamble_longs, uint8_t flags) {
+  const bool empty_flag_set = (flags & EMPTY_FLAG_MASK) != 0;
+
+  if (empty_flag_set && preamble_longs != PREAMBLE_LONGS_EMPTY) {
+    std::string msg("Possible corruption: Preamble longs must be ");
+    msg += std::to_string(PREAMBLE_LONGS_EMPTY);
+    msg += " for an empty sketch. Found: ";
+    msg += std::to_string(preamble_longs);
+    throw std::invalid_argument(msg);
+  }
+
+  if (!empty_flag_set && preamble_longs != PREAMBLE_LONGS_NON_EMPTY) {
+    std::string msg("Possible corruption: Preamble longs must be ");
+    msg += std::to_string(PREAMBLE_LONGS_NON_EMPTY);
+    msg += " for a non-empty sketch. Found: ";
+    msg += std::to_string(preamble_longs);
+    throw std::invalid_argument(msg);
+  }
+}
+
+/**
+ * Verifies the family id and serialization version bytes of a serialized union.
+ *
+ * @param family_id family byte read from the image
+ * @param ser_ver serialization version byte read from the image
+ * @throws std::invalid_argument if the family id is not FAMILY_ID, or if the family matches
+ *         but the version is not SER_VER
+ */
+template<typename T, typename S, typename A>
+void var_opt_union<T,S,A>::check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver) {
+  // the family is checked first: a wrong family is reported regardless of the version byte
+  if (family_id != FAMILY_ID) {
+    // TODO: extend to handle reservoir sampling
+    throw std::invalid_argument("Possible corruption: VarOpt Union family id must be "
+      + std::to_string(FAMILY_ID) + ". Found: " + std::to_string(family_id));
+  }
+
+  if (ser_ver != SER_VER) {
+    throw std::invalid_argument("Possible corruption: VarOpt Union serialization version must be "
+      + std::to_string(SER_VER) + ". Found: " + std::to_string(ser_ver));
+  }
+}
} // namespace datasketches
diff --git a/sampling/test/CMakeLists.txt b/sampling/test/CMakeLists.txt
index e291b25..00327e0 100644
--- a/sampling/test/CMakeLists.txt
+++ b/sampling/test/CMakeLists.txt
@@ -37,7 +37,8 @@ add_test(
target_sources(sampling_test
PRIVATE
- var_opt_sketch_test.cpp
+ var_opt_sketch_test.cpp
+ var_opt_union_test.cpp
)
target_include_directories(sampling_test
diff --git a/sampling/test/var_opt_sketch_test.cpp b/sampling/test/var_opt_sketch_test.cpp
index 524454e..a2819f3 100644
--- a/sampling/test/var_opt_sketch_test.cpp
+++ b/sampling/test/var_opt_sketch_test.cpp
@@ -18,7 +18,6 @@
*/
#include <var_opt_sketch.hpp>
-#include <var_opt_union.hpp>
#include <cppunit/TestFixture.h>
#include <cppunit/extensions/HelperMacros.h>
diff --git a/sampling/test/var_opt_union_test.cpp b/sampling/test/var_opt_union_test.cpp
new file mode 100644
index 0000000..f0a519f
--- /dev/null
+++ b/sampling/test/var_opt_union_test.cpp
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <var_opt_sketch.hpp>
+#include <var_opt_union.hpp>
+
+#include <cppunit/TestFixture.h>
+#include <cppunit/extensions/HelperMacros.h>
+
+#include <vector>
+#include <string>
+#include <sstream>
+#include <iostream>
+#include <cmath>
+#include <random>
+
+#ifdef TEST_BINARY_INPUT_PATH
+static std::string testBinaryInputPath = TEST_BINARY_INPUT_PATH;
+#else
+static std::string testBinaryInputPath = "test/";
+#endif
+
+namespace datasketches {
+
+class var_opt_union_test: public CppUnit::TestFixture {
+
+ static constexpr double EPS = 1e-13;
+ CPPUNIT_TEST_SUITE(var_opt_union_test);
+ CPPUNIT_TEST(bad_prelongs);
+ CPPUNIT_TEST(bad_ser_ver);
+ CPPUNIT_TEST(bad_family);
+ CPPUNIT_TEST(invalid_k);
+ CPPUNIT_TEST(empty_union);
+ CPPUNIT_TEST(two_exact_sketches);
+ CPPUNIT_TEST(heavy_sampling_sketch);
+ CPPUNIT_TEST(identical_sampling_sketches);
+ CPPUNIT_TEST(small_sampling_sketch);
+ CPPUNIT_TEST(serialize_empty);
+ CPPUNIT_TEST(serialize_exact);
+ CPPUNIT_TEST(serialize_sampling);
+ // CPPUNIT_TEST(deserialize_exact_from_java);
+ // CPPUNIT_TEST(deserialize_sampling_from_java);
+ CPPUNIT_TEST_SUITE_END();
+
+  // Builds a sketch of size k fed with the items 0..n-1, each with weight 1.0.
+  var_opt_sketch<int> create_unweighted_sketch(uint32_t k, uint64_t n) {
+    var_opt_sketch<int> sketch(k);
+    uint64_t item = 0;
+    while (item < n) {
+      sketch.update(item, 1.0);
+      ++item;
+    }
+    return sketch;
+  }
+
+  // Asserts that two sketches agree on k, n and sample count, and that iterating both in
+  // parallel yields the same (item, weight) pair at every position.
+  template<typename T, typename S, typename A>
+  void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk2) {
+    CPPUNIT_ASSERT_EQUAL_MESSAGE("sketches have different values of k",
+      sk1.get_k(), sk2.get_k());
+    CPPUNIT_ASSERT_EQUAL_MESSAGE("sketches have different values of n",
+      sk1.get_n(), sk2.get_n());
+    CPPUNIT_ASSERT_EQUAL_MESSAGE("sketches have different sample counts",
+      sk1.get_num_samples(), sk2.get_num_samples());
+
+    auto lhs = sk1.begin();
+    auto rhs = sk2.begin();
+
+    // walk both sketches in lock step until either one runs out
+    for (size_t idx = 0; (lhs != sk1.end()) && (rhs != sk2.end()); ++idx, ++lhs, ++rhs) {
+      const std::pair<const T&, const double> left = *lhs;
+      const std::pair<const T&, const double> right = *rhs;
+      CPPUNIT_ASSERT_EQUAL_MESSAGE("data values differ at sample " + std::to_string(idx),
+        left.first, right.first);
+      CPPUNIT_ASSERT_EQUAL_MESSAGE("weight values differ at sample " + std::to_string(idx),
+        left.second, right.second);
+    }
+
+    // both must be exhausted, otherwise one sketch holds extra samples
+    CPPUNIT_ASSERT_MESSAGE("iterators did not end at the same time",
+      (lhs == sk1.end()) && (rhs == sk2.end()));
+  }
+
+  // Corrupts the preamble-longs byte of a serialized non-empty union and checks that both
+  // deserialize overloads reject the image.
+  void bad_prelongs() {
+    var_opt_sketch<int> sk = create_unweighted_sketch(32, 33);
+    var_opt_union<int> u(32);
+    u.update(sk);
+    std::vector<uint8_t> bytes = u.serialize();
+
+    bytes[0] = 0; // corrupt the preamble longs byte to be too small
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(bytes) failed to catch bad preamble longs",
+      var_opt_union<int>::deserialize(bytes.data(), bytes.size()),
+      std::invalid_argument);
+
+    // create a stringstream to check the same
+    std::stringstream ss;
+    std::string str(bytes.begin(), bytes.end());
+    ss.str(str);
+    // message now names the preamble longs; it previously said "serialization version"
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(stream) failed to catch bad preamble longs",
+      var_opt_union<int>::deserialize(ss),
+      std::invalid_argument);
+  }
+
+  // Overwrites the serialization-version byte of a serialized union and checks that both
+  // deserialize overloads reject the image.
+  void bad_ser_ver() {
+    var_opt_union<int> u(32);
+    var_opt_sketch<int> input = create_unweighted_sketch(16, 16);
+    u.update(input);
+
+    std::vector<uint8_t> image = u.serialize();
+    image[1] = 0; // byte 1 holds the serialization version
+
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(bytes) failed to catch bad serialization version",
+      var_opt_union<int>::deserialize(image.data(), image.size()),
+      std::invalid_argument);
+
+    // repeat the check through the stream-based overload
+    const std::string raw(image.begin(), image.end());
+    std::stringstream stream(raw);
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(stream) failed to catch bad serialization version",
+      var_opt_union<int>::deserialize(stream),
+      std::invalid_argument);
+  }
+
+  // Checks that the union constructor rejects k = 0 and k = 2^31.
+  void invalid_k() {
+    CPPUNIT_ASSERT_THROW_MESSAGE("constructor failed to catch invalid k = 0",
+      var_opt_union<int> sk(0),
+      std::invalid_argument);
+
+    // shift an unsigned literal: 1 << 31 on a signed int moves a bit into the sign position
+    CPPUNIT_ASSERT_THROW_MESSAGE("constructor failed to catch invalid k < 0 (aka >= 2^31)",
+      var_opt_union<std::string> sk(1U << 31),
+      std::invalid_argument);
+  }
+
+  // Overwrites the family-id byte of a serialized union and checks that both deserialize
+  // overloads reject the image.
+  void bad_family() {
+    var_opt_union<int> u(15);
+    var_opt_sketch<int> input = create_unweighted_sketch(16, 16);
+    u.update(input);
+
+    std::vector<uint8_t> image = u.serialize();
+    image[2] = 0; // byte 2 holds the family id
+
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(bytes) failed to catch bad family id",
+      var_opt_union<int>::deserialize(image.data(), image.size()),
+      std::invalid_argument);
+
+    // same corrupted image, this time through the stream overload
+    const std::string raw(image.begin(), image.end());
+    std::stringstream stream(raw);
+    CPPUNIT_ASSERT_THROW_MESSAGE("deserialize(stream) failed to catch bad family id",
+      var_opt_union<int>::deserialize(stream),
+      std::invalid_argument);
+  }
+
+  // Feeding an empty sketch into a union must leave the result empty with k unchanged.
+  void empty_union() {
+    uint32_t k = 2048;
+    var_opt_union<std::string> u(k);
+    var_opt_sketch<std::string> empty_input(k);
+    u.update(empty_input);
+
+    var_opt_sketch<std::string> merged = u.get_result();
+    CPPUNIT_ASSERT(merged.is_empty());
+    CPPUNIT_ASSERT_EQUAL((uint64_t) 0, merged.get_n());
+    CPPUNIT_ASSERT_EQUAL((uint32_t) 0, merged.get_num_samples());
+    CPPUNIT_ASSERT_EQUAL(k, merged.get_k());
+  }
+
+  // Merges two sketches that each hold n items and checks n and k of the result.
+  void two_exact_sketches() {
+    uint32_t k = 10;
+    uint64_t n = 4; // 2n < k
+    var_opt_sketch<int> positives(k);
+    var_opt_sketch<int> negatives(k);
+
+    // item i with weight i in one sketch, item -i with weight i in the other
+    for (int item = 1; item <= n; ++item) {
+      positives.update(item, item);
+      negatives.update(-item, item);
+    }
+
+    var_opt_union<int> u(k);
+    u.update(positives);
+    u.update(negatives);
+
+    var_opt_sketch<int> merged = u.get_result();
+    CPPUNIT_ASSERT_EQUAL(2 * n, merged.get_n());
+    CPPUNIT_ASSERT_EQUAL(k, merged.get_k());
+  }
+
+// Merges a sketch of light items with a smaller sketch containing one very heavy item,
+// then checks the result's n and k, and the state reported after reset().
+void heavy_sampling_sketch() {
+  uint64_t n1 = 20;
+  uint32_t k1 = 10;
+  uint64_t n2 = 6;
+  uint32_t k2 = 5;
+  var_opt_sketch<int64_t> sk1(k1), sk2(k2);
+
+  // signed loop counters with explicitly converted bounds: no signed/unsigned comparison
+  for (int64_t i = 1; i <= static_cast<int64_t>(n1); ++i) {
+    sk1.update(i, i);
+  }
+
+  for (int64_t i = 1; i < static_cast<int64_t>(n2); ++i) { // we'll add a very heavy one later
+    sk2.update(-i, i + 1000.0);
+  }
+  // convert before negating: -n2 on a uint64_t wraps around instead of producing -6
+  sk2.update(-static_cast<int64_t>(n2), 1000000.0);
+
+  var_opt_union<int64_t> u(k1);
+  u.update(sk1);
+  u.update(sk2);
+
+  var_opt_sketch<int64_t> result = u.get_result();
+  CPPUNIT_ASSERT_EQUAL(n1 + n2, result.get_n());
+  CPPUNIT_ASSERT_EQUAL(k2, result.get_k()); // heavy enough the result pulls back to k2
+
+  u.reset();
+  result = u.get_result();
+  CPPUNIT_ASSERT_EQUAL((uint64_t) 0, result.get_n());
+  CPPUNIT_ASSERT_EQUAL(k1, result.get_k()); // union reset so empty result reflects max_k
+}
+
+// Feeds the same sampling sketch into a union twice, then a third sketch with a smaller tau,
+// checking n and the total sketch weight after each stage.
+void identical_sampling_sketches() {
+  uint64_t n = 50;
+  uint32_t k = 20;
+  var_opt_sketch<int> input = create_unweighted_sketch(k, n);
+
+  var_opt_union<int> u(k);
+  u.update(input);
+  u.update(input);
+
+  var_opt_sketch<int> merged = u.get_result();
+  subset_summary summary = merged.estimate_subset_sum([](int x){return true;});
+  double expected_wt = 2.0 * n;
+  CPPUNIT_ASSERT_EQUAL(2 * n, merged.get_n());
+  CPPUNIT_ASSERT_DOUBLES_EQUAL(expected_wt, summary.total_sketch_weight, EPS);
+
+  // add another sketch, such that sketch_tau < outer_tau
+  input = create_unweighted_sketch(k, k + 1); // tau = (k + 1) / k
+  u.update(input);
+  merged = u.get_result();
+  summary = merged.estimate_subset_sum([](int x){return true;});
+  expected_wt = (2.0 * n) + k + 1;
+  CPPUNIT_ASSERT_EQUAL((2 * n) + k + 1, merged.get_n());
+  CPPUNIT_ASSERT_DOUBLES_EQUAL(expected_wt, summary.total_sketch_weight, EPS);
+}
+
+// Merges two small-k sampling sketches, one carrying a heavy item, into a union with a larger
+// max k, then checks n, the subset-sum estimates and that the result's k is below k_max.
+void small_sampling_sketch() {
+  uint32_t k_max = 128;
+  uint32_t k_small = 16;
+  uint64_t n1 = 32;
+  uint64_t n2 = 64;
+
+  var_opt_sketch<float> first(k_small);
+  for (int i = 0; i < n1; ++i) {
+    first.update(i);
+  }
+  first.update(-1, n1 * n1); // add a heavy item
+
+  var_opt_union<float> u(k_max);
+  u.update(first);
+
+  // another one, but different n to get a different per-item weight
+  var_opt_sketch<float> second(k_small);
+  for (int i = 0; i < n2; ++i) {
+    second.update(i);
+  }
+  u.update(second);
+
+  // should trigger migrate_marked_items_by_decreasing_k()
+  var_opt_sketch<float> merged = u.get_result();
+  CPPUNIT_ASSERT_EQUAL(n1 + n2 + 1, merged.get_n());
+
+  // n1 + n2 light items, ignore the heavy one
+  double expected_wt = 1.0 * (n1 + n2);
+  subset_summary summary = merged.estimate_subset_sum([](float x){return x >= 0;});
+  CPPUNIT_ASSERT_DOUBLES_EQUAL(expected_wt, summary.estimate, EPS);
+  CPPUNIT_ASSERT_DOUBLES_EQUAL(expected_wt + (n1 * n1), summary.total_sketch_weight, EPS);
+  CPPUNIT_ASSERT_LESS(k_max, merged.get_k());
+}
+
+// Round-trips a union that has received no input through its byte-vector
+// and stream serialized forms, comparing each restored union's result with
+// the original union's result.
+void serialize_empty() {
+  var_opt_union<std::string> src_union(100);
+
+  // byte-array round trip
+  std::vector<uint8_t> serialized = src_union.serialize();
+  var_opt_union<std::string> from_bytes =
+      var_opt_union<std::string>::deserialize(serialized.data(), serialized.size());
+  var_opt_sketch<std::string> expected = src_union.get_result();
+  var_opt_sketch<std::string> restored = from_bytes.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, reading back the bytes produced above
+  std::string raw(serialized.begin(), serialized.end());
+  std::stringstream stream(raw);
+  var_opt_union<std::string> from_stream = var_opt_union<std::string>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, this time writing through the stream overload
+  stream.seekg(0); // didn't put anything so only reset read position
+  src_union.serialize(stream);
+  from_stream = var_opt_union<std::string>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+}
+
+// Round-trips a union updated with a sketch built from k / 2 items through
+// its byte-vector and stream serialized forms, comparing each restored
+// union's result with the original union's result.
+void serialize_exact() {
+  uint32_t k = 100;
+  var_opt_union<int> src_union(k);
+  var_opt_sketch<int> input = create_unweighted_sketch(k, k / 2);
+  src_union.update(input);
+
+  // byte-array round trip
+  std::vector<uint8_t> serialized = src_union.serialize();
+  var_opt_union<int> from_bytes =
+      var_opt_union<int>::deserialize(serialized.data(), serialized.size());
+  var_opt_sketch<int> expected = src_union.get_result();
+  var_opt_sketch<int> restored = from_bytes.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, reading back the bytes produced above
+  std::string raw(serialized.begin(), serialized.end());
+  std::stringstream stream(raw);
+  var_opt_union<int> from_stream = var_opt_union<int>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, this time writing through the stream overload
+  stream.seekg(0); // didn't put anything so only reset read position
+  src_union.serialize(stream);
+  from_stream = var_opt_union<int>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+}
+
+// Round-trips a union updated with a sketch built from 2 * k items through
+// its byte-vector and stream serialized forms, comparing each restored
+// union's result with the original union's result.
+void serialize_sampling() {
+  uint32_t k = 100;
+  var_opt_union<int> src_union(k);
+  var_opt_sketch<int> input = create_unweighted_sketch(k, 2 * k);
+  src_union.update(input);
+
+  // byte-array round trip
+  std::vector<uint8_t> serialized = src_union.serialize();
+  var_opt_union<int> from_bytes =
+      var_opt_union<int>::deserialize(serialized.data(), serialized.size());
+  var_opt_sketch<int> expected = src_union.get_result();
+  var_opt_sketch<int> restored = from_bytes.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, reading back the bytes produced above
+  std::string raw(serialized.begin(), serialized.end());
+  std::stringstream stream(raw);
+  var_opt_union<int> from_stream = var_opt_union<int>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+
+  // stream round trip, this time writing through the stream overload
+  stream.seekg(0); // didn't put anything so only reset read position
+  src_union.serialize(stream);
+  from_stream = var_opt_union<int>::deserialize(stream);
+  restored = from_stream.get_result();
+  check_if_equal(expected, restored);
+}
+
+/**********************************************************/
+
+
+ // Prints the string form of a union that received one sketch, followed by
+ // the size, capacity and emptiness of its serialized bytes. Writes to
+ // stdout only; makes no assertions.
+ void test_union() {
+   var_opt_union<int> sketch_union(10);
+
+   var_opt_sketch<int> input = create_unweighted_sketch(9, 100);
+   sketch_union.update(input);
+   std::cout << sketch_union.to_string() << std::endl;
+
+   auto serialized = sketch_union.serialize();
+   std::cout << serialized.size() << "\t";
+   std::cout << serialized.capacity() << "\t";
+   std::cout << serialized.empty() << std::endl;
+ }
+
+ // Deserializes a string sketch from varopt_string_exact.bin and checks its
+ // k, n, number of samples and total sketch weight.
+ void deserialize_exact_from_java() {
+   std::ifstream input;
+   // enable stream exceptions before opening so a failed open throws
+   input.exceptions(std::ios::failbit | std::ios::badbit);
+   input.open(testBinaryInputPath + "varopt_string_exact.bin", std::ios::binary);
+   var_opt_sketch<std::string> sk = var_opt_sketch<std::string>::deserialize(input);
+
+   CPPUNIT_ASSERT(!sk.is_empty());
+   CPPUNIT_ASSERT_EQUAL((uint32_t) 1024, sk.get_k());
+   CPPUNIT_ASSERT_EQUAL((uint64_t) 200, sk.get_n());
+   CPPUNIT_ASSERT_EQUAL((uint32_t) 200, sk.get_num_samples());
+   subset_summary summary = sk.estimate_subset_sum([](std::string){ return true; });
+
+   // expected total weight: sum of 1000 / i for i = 1..200
+   double expected_total = 0.0;
+   int idx = 1;
+   while (idx <= 200) {
+     expected_total += 1000.0 / idx;
+     ++idx;
+   }
+   CPPUNIT_ASSERT_DOUBLES_EQUAL(expected_total, summary.total_sketch_weight, EPS);
+ }
+
+ // Deserializes an int64 sketch from varopt_long_sampling.bin and checks its
+ // k, n, number of samples, and the subset-sum estimates for all items, for
+ // the negative items and for the non-negative items.
+ void deserialize_sampling_from_java() {
+   std::ifstream input;
+   // enable stream exceptions before opening so a failed open throws
+   input.exceptions(std::ios::failbit | std::ios::badbit);
+   input.open(testBinaryInputPath + "varopt_long_sampling.bin", std::ios::binary);
+   var_opt_sketch<int64_t> sk = var_opt_sketch<int64_t>::deserialize(input);
+
+   CPPUNIT_ASSERT(!sk.is_empty());
+   CPPUNIT_ASSERT_EQUAL((uint32_t) 1024, sk.get_k());
+   CPPUNIT_ASSERT_EQUAL((uint64_t) 2003, sk.get_n());
+   CPPUNIT_ASSERT_EQUAL(sk.get_k(), sk.get_num_samples());
+
+   // every item
+   subset_summary summary = sk.estimate_subset_sum([](int64_t){ return true; });
+   CPPUNIT_ASSERT_DOUBLES_EQUAL(332000.0, summary.estimate, EPS);
+   CPPUNIT_ASSERT_DOUBLES_EQUAL(332000.0, summary.total_sketch_weight, EPS);
+
+   // negative items only; compared with zero tolerance
+   // NOTE(review): presumably these are heavy items whose weight is stored exactly -- confirm
+   summary = sk.estimate_subset_sum([](int64_t x){ return x < 0; });
+   CPPUNIT_ASSERT_DOUBLES_EQUAL(330000.0, summary.estimate, 0.0);
+
+   // non-negative items only
+   summary = sk.estimate_subset_sum([](int64_t x){ return x >= 0; });
+   CPPUNIT_ASSERT_DOUBLES_EQUAL(2000.0, summary.estimate, EPS);
+ }
+
+};
+
+CPPUNIT_TEST_SUITE_REGISTRATION(var_opt_union_test);
+
+} /* namespace datasketches */
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datasketches.apache.org
For additional commands, e-mail: commits-help@datasketches.apache.org