You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@datasketches.apache.org by jm...@apache.org on 2023/02/10 08:26:31 UTC

[datasketches-cpp] 02/02: allow creation of frequent_items_sketch using py::object in python

This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch fi_equals
in repository https://gitbox.apache.org/repos/asf/datasketches-cpp.git

commit 71f22835bd80f687c34915eb5bf7a09216b80539
Author: Jon Malkin <78...@users.noreply.github.com>
AuthorDate: Fri Feb 10 00:26:13 2023 -0800

    allow creation of frequent_items_sketch using py::object in python
---
 python/src/fi_wrapper.cpp | 125 ++++++++++++++++++++++++++++++++++++----------
 1 file changed, 99 insertions(+), 26 deletions(-)

diff --git a/python/src/fi_wrapper.cpp b/python/src/fi_wrapper.cpp
index bdb49a4..2abeb84 100644
--- a/python/src/fi_wrapper.cpp
+++ b/python/src/fi_wrapper.cpp
@@ -17,45 +17,66 @@
  * under the License.
  */
 
-#include <pybind11/pybind11.h>
 
+#include "py_serde.hpp"
 #include "frequent_items_sketch.hpp"
 
+#include <pybind11/pybind11.h>
+
+#include <ostream>
+
 namespace py = pybind11;
 
-template<typename T>
+namespace pybind11 {  // same namespace as py::object, so argument-dependent lookup finds this operator<<
+static std::ostream& operator<<(std::ostream& os, const object& obj) {
+  const auto text = static_cast<std::string>(str(obj));  // text from the object's __str__
+  return os << text;
+}
+}
+
+// forward declarations of add_serialization(); the enable_if conditions are complementary, so exactly one applies per T
+// std::string and arithmetic types: serialization uses the sketch's default serde
+template<typename T, typename W, typename H, typename E, typename std::enable_if<std::is_arithmetic<T>::value || std::is_same<std::string, T>::value, bool>::type = 0>
+void add_serialization(py::class_<datasketches::frequent_items_sketch<T, W, H, E>>& clazz);
+
+// all other types (currently only py::object): the caller must pass a py_object_serde
+template<typename T, typename W, typename H, typename E, typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<std::string, T>::value, bool>::type = 0>
+void add_serialization(py::class_<datasketches::frequent_items_sketch<T, W, H, E>>& clazz);
+
+
+template<typename T, typename W, typename H, typename E>  // item, weight, hash functor, equality functor
 void bind_fi_sketch(py::module &m, const char* name) {
   using namespace datasketches;
 
-  py::class_<frequent_items_sketch<T>>(m, name)
-    .def(py::init<uint8_t>(), py::arg("lg_max_k"))
-    .def("__str__", &frequent_items_sketch<T>::to_string, py::arg("print_items")=false,
+  auto fi_class = py::class_<frequent_items_sketch<T, W, H, E>>(m, name)
+    .def("__str__", &frequent_items_sketch<T, W, H, E>::to_string, py::arg("print_items")=false,
          "Produces a string summary of the sketch")
-    .def("to_string", &frequent_items_sketch<T>::to_string, py::arg("print_items")=false,
+    .def("to_string", &frequent_items_sketch<T, W, H, E>::to_string, py::arg("print_items")=false,
          "Produces a string summary of the sketch")
-    .def("update", (void (frequent_items_sketch<T>::*)(const T&, uint64_t)) &frequent_items_sketch<T>::update, py::arg("item"), py::arg("weight")=1,
+    .def(py::init<uint8_t>(), py::arg("lg_max_k"))
+    .def("update", (void (frequent_items_sketch<T, W, H, E>::*)(const T&, W)) &frequent_items_sketch<T, W, H, E>::update, py::arg("item"), py::arg("weight")=1,
          "Updates the sketch with the given string and, optionally, a weight")
-    .def("merge", (void (frequent_items_sketch<T>::*)(const frequent_items_sketch<T>&)) &frequent_items_sketch<T>::merge,
+    .def("merge", (void (frequent_items_sketch<T, W, H, E>::*)(const frequent_items_sketch<T, W, H, E>&)) &frequent_items_sketch<T, W, H, E>::merge,
          "Merges the given sketch into this one")
-    .def("is_empty", &frequent_items_sketch<T>::is_empty,
+    .def("is_empty", &frequent_items_sketch<T, W, H, E>::is_empty,
          "Returns True if the sketch is empty, otherwise False")
-    .def("get_num_active_items", &frequent_items_sketch<T>::get_num_active_items,
+    .def("get_num_active_items", &frequent_items_sketch<T, W, H, E>::get_num_active_items,
          "Returns the number of active items in the sketch")
-    .def("get_total_weight", &frequent_items_sketch<T>::get_total_weight,
+    .def("get_total_weight", &frequent_items_sketch<T, W, H, E>::get_total_weight,
          "Returns the sum of the weights (frequencies) in the stream seen so far by the sketch")
-    .def("get_estimate", &frequent_items_sketch<T>::get_estimate, py::arg("item"),
+    .def("get_estimate", &frequent_items_sketch<T, W, H, E>::get_estimate, py::arg("item"),
          "Returns the estimate of the weight (frequency) of the given item.\n"
          "Note: The true frequency of a item would be the sum of the counts as a result of the "
          "two update functions.")
-    .def("get_lower_bound", &frequent_items_sketch<T>::get_lower_bound, py::arg("item"),
+    .def("get_lower_bound", &frequent_items_sketch<T, W, H, E>::get_lower_bound, py::arg("item"),
          "Returns the guaranteed lower bound weight (frequency) of the given item.")
-    .def("get_upper_bound", &frequent_items_sketch<T>::get_upper_bound, py::arg("item"),
+    .def("get_upper_bound", &frequent_items_sketch<T, W, H, E>::get_upper_bound, py::arg("item"),
          "Returns the guaranteed upper bound weight (frequency) of the given item.")
-    .def("get_sketch_epsilon", (double (frequent_items_sketch<T>::*)(void) const) &frequent_items_sketch<T>::get_epsilon,
+    .def("get_sketch_epsilon", (double (frequent_items_sketch<T, W, H, E>::*)(void) const) &frequent_items_sketch<T, W, H, E>::get_epsilon,
          "Returns the epsilon value used by the sketch to compute error")
     .def(
         "get_frequent_items",
-        [](const frequent_items_sketch<T>& sk, frequent_items_error_type err_type, uint64_t threshold) {
+        [](const frequent_items_sketch<T, W, H, E>& sk, frequent_items_error_type err_type, W threshold) {
           if (threshold == 0) threshold = sk.get_maximum_error();
           py::list list;
           auto rows = sk.get_frequent_items(err_type, threshold);
@@ -73,37 +94,88 @@ void bind_fi_sketch(py::module &m, const char* name) {
     )
     .def_static(
         "get_epsilon_for_lg_size",
-        [](uint8_t lg_max_map_size) { return frequent_items_sketch<T>::get_epsilon(lg_max_map_size); },
+        [](uint8_t lg_max_map_size) { return frequent_items_sketch<T, W, H, E>::get_epsilon(lg_max_map_size); },
         py::arg("lg_max_map_size"),
         "Returns the epsilon value used to compute a priori error for a given log2(max_map_size)"
     )
     .def_static(
         "get_apriori_error",
-        &frequent_items_sketch<T>::get_apriori_error,
+        &frequent_items_sketch<T, W, H, E>::get_apriori_error,
         py::arg("lg_max_map_size"), py::arg("estimated_total_weight"),
         "Returns the estimated a priori error given the max_map_size for the sketch and the estimated_total_stream_weight."
-    )
-    .def(
+    );
+
+    // serialization may need a caller-provided serde depending on the sketch type, so
+    // add_serialization() is overloaded on T and picks the right bindings at compile time.
+    add_serialization(fi_class);
+}
+
+// overload for std::string and arithmetic T: the sketch's default serde is sufficient
+template<typename T, typename W, typename H, typename E, typename std::enable_if<std::is_arithmetic<T>::value || std::is_same<std::string, T>::value, bool>::type>
+void add_serialization(py::class_<datasketches::frequent_items_sketch<T, W, H, E>>& clazz) {
+    using sketch_t = datasketches::frequent_items_sketch<T, W, H, E>;
+    clazz.def(
         "get_serialized_size_bytes",
-        [](const frequent_items_sketch<T>& sk) { return sk.get_serialized_size_bytes(); },
+        [](const sketch_t& sk) { return sk.get_serialized_size_bytes(); },
         "Computes the size needed to serialize the current state of the sketch. This can be expensive since every item needs to be looked at."
     )
     .def(
         "serialize",
-        [](const frequent_items_sketch<T>& sk) {
+        [](const sketch_t& sk) {
           auto bytes = sk.serialize();
           return py::bytes(reinterpret_cast<const char*>(bytes.data()), bytes.size());
         },
-        "Serializes the sketch into a bytes object"
+        "Serializes the sketch into a bytes object."
     )
     .def_static(
         "deserialize",
-        [](const std::string& bytes) { return frequent_items_sketch<T>::deserialize(bytes.data(), bytes.size()); },
+        [](const std::string& bytes) { return sketch_t::deserialize(bytes.data(), bytes.size()); },
         py::arg("bytes"),
-        "Reads a bytes object and returns the corresponding frequent_strings_sketch"
+        "Reads a bytes object and returns the corresponding frequent_strings_sketch."
+    );
+}
+
+// overload for all other T (currently only py::object): the caller must supply a py_object_serde
+template<typename T, typename W, typename H, typename E, typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<std::string, T>::value, bool>::type>
+void add_serialization(py::class_<datasketches::frequent_items_sketch<T, W, H, E>>& clazz) {
+    using namespace datasketches;
+    clazz.def(
+        "get_serialized_size_bytes",
+        [](const frequent_items_sketch<T, W, H, E>& sk, py_object_serde& serde) { return sk.get_serialized_size_bytes(serde); },
+        py::arg("serde"),
+        "Computes the size needed to serialize the current state of the sketch using the provided serde. This can be expensive since every item needs to be looked at."
+    )
+    .def(
+        "serialize",
+        [](const frequent_items_sketch<T, W, H, E>& sk, py_object_serde& serde) {
+          auto bytes = sk.serialize(0, serde);  // presumably 0 = no reserved header bytes; confirm against frequent_items_sketch::serialize
+          return py::bytes(reinterpret_cast<const char*>(bytes.data()), bytes.size());
+        }, py::arg("serde"),
+        "Serializes the sketch into a bytes object using the provided serde."
+    )
+    .def_static(
+        "deserialize",
+        [](const std::string& bytes, py_object_serde& serde) {
+          return frequent_items_sketch<T, W, H, E>::deserialize(bytes.data(), bytes.size(), serde);
+        }, py::arg("bytes"), py::arg("serde"),
+        "Reads a bytes object using the provided serde and returns the corresponding frequent_items_sketch."
     );
 }
 
+// hash functor for py::object items: delegates to the object's __hash__ via py::hash (which throws if unhashable)
+struct py_hash_caller {
+  size_t operator()(const py::object& a) const {
+    return static_cast<size_t>(py::hash(a));
+  }
+};
+
+// equality functor for py::object items: delegates to Python == (the object's __eq__) via object::equal
+struct py_equal_caller {
+  bool operator()(const py::object& a, const py::object& b) const {
+    return a.equal(b);
+  }
+};
+
 void init_fi(py::module &m) {
   using namespace datasketches;
 
@@ -112,5 +184,6 @@ void init_fi(py::module &m) {
     .value("NO_FALSE_NEGATIVES", NO_FALSE_NEGATIVES)
     .export_values();
 
-  bind_fi_sketch<std::string>(m, "frequent_strings_sketch");
+  bind_fi_sketch<std::string, uint64_t, std::hash<std::string>, std::equal_to<std::string>>(m, "frequent_strings_sketch");
+  bind_fi_sketch<py::object, uint64_t, py_hash_caller, py_equal_caller>(m, "frequent_items_sketch");  // items hashed/compared via Python __hash__/__eq__
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datasketches.apache.org
For additional commands, e-mail: commits-help@datasketches.apache.org