You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/10 05:34:40 UTC

[tvm] branch main updated: [TVMScript] Printer Registry (#12237)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fae79bbc3e [TVMScript] Printer Registry (#12237)
fae79bbc3e is described below

commit fae79bbc3e499f3b9f26c9a13743896f948b723d
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Wed Aug 10 01:34:35 2022 -0400

    [TVMScript] Printer Registry (#12237)
    
    This PR:
    
    - Adds a registry of printing functions (traced_object_functor.cc)
    
    Compared to the prototype version, this:
    - Consolidates the implementation into a single class, since this class is only for the TVMScript printer.
    - Deduces TObjectRef from the dispatch function's first parameter when calling set_dispatch.
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
    
    Co-authored-by: Greg Bonik <gb...@octoml.ai>
---
 include/tvm/script/printer/traced_object.h         |   2 +
 include/tvm/script/printer/traced_object_functor.h | 183 +++++++++++++++++++++
 src/script/printer/traced_object_functor.cc        |  75 +++++++++
 ...tvmscript_printer_traced_object_functor_test.cc | 171 +++++++++++++++++++
 4 files changed, 431 insertions(+)

diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h
index 6f04b66cec..4c09b0a41b 100644
--- a/include/tvm/script/printer/traced_object.h
+++ b/include/tvm/script/printer/traced_object.h
@@ -86,6 +86,8 @@ class TracedObject {
   using ObjectType = typename RefT::ContainerType;
 
  public:
+  using ObjectRefType = RefT;
+
+  // Don't use this directly. For convenience, call MakeTraced() instead.
   explicit TracedObject(const RefT& object_ref, ObjectPath path)
       : ref_(object_ref), path_(std::move(path)) {}
diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h
new file mode 100644
index 0000000000..05fbbf79f2
--- /dev/null
+++ b/include/tvm/script/printer/traced_object_functor.h
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
+#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
+
+#include <tvm/node/node.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/script/printer/traced_object.h>
+
+#include <functional>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+namespace detail {
+/*!
+ * \brief Helper template class to extract the type of the first argument of a function.
+ * \tparam FType The function type, in the form `R(Args...)`.
+ *
+ * Only the partial specialization below is defined, so instantiating this with a
+ * function type that takes no arguments is a compile-time error.
+ */
+template <typename FType>
+struct FirstArgTypeGetter;
+
+template <typename R, typename ArgOne, typename... OtherArgs>
+struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
+  // Strip references and cv-qualifiers so that a callable taking
+  // `const TracedObject<T>&` yields the same type as one taking `TracedObject<T>`.
+  using T = std::decay_t<ArgOne>;
+};
+
+/*!
+ * \brief Template alias for the (decayed) type of the first argument of a callable.
+ * \tparam FType The callable type (function, function pointer, or functor/lambda).
+ *
+ * This is deliberately a named `detail` namespace and not an anonymous one. An
+ * anonymous namespace in a header gives every translation unit its own copy of
+ * these entities. Referring to them from default template arguments of the
+ * externally-linked class templates in this header can then violate the ODR.
+ */
+template <typename FType>
+using FirstArgType = typename detail::FirstArgTypeGetter<
+    typename tvm::runtime::detail::function_signature<FType>::FType>::T;
+}  // namespace detail
+
+/*! \brief The dispatch token under which default (fallback) functions are registered. */
+constexpr const char* kDefaultDispatchToken = "";
+
+/*
+ * DispatchTable and the two free functions below are deliberately kept outside the
+ * TracedObjectFunctor template. Their bodies live in traced_object_functor.cc and are
+ * compiled once rather than per template instantiation, which limits binary bloat and
+ * keeps implementation details out of this header.
+ */
+
+/*!
+ * \brief Storage for dispatch functions.
+ *
+ * Maps a dispatch token to a vector indexed by object runtime type index. An
+ * undefined PackedFunc in that vector means nothing is registered for that type.
+ */
+using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
+
+/*!
+ * \brief Look up the dispatch function for an object type.
+ * \param dispatch_table The dispatch table to search.
+ * \param token The dispatch token. If nothing is registered for `type_index` under it,
+ *        the function registered under kDefaultDispatchToken is returned instead.
+ * \param type_index The runtime type index of the object being dispatched on.
+ *
+ * \return A reference to the stored dispatch function. It points into `dispatch_table`
+ *         and may be invalidated by a later SetDispatchFunction call on the same table.
+ */
+const runtime::PackedFunc& GetDispatchFunction(
+    const DispatchTable& dispatch_table, const String& token, uint32_t type_index);
+
+/*!
+ * \brief Register a dispatch function for an object type.
+ * \param dispatch_table The dispatch table to modify.
+ * \param token The dispatch token.
+ * \param type_index The runtime type index of the object type handled by `f`.
+ * \param f The dispatch function.
+ */
+void SetDispatchFunction(
+    DispatchTable* dispatch_table, const String& token, uint32_t type_index,
+    runtime::PackedFunc f);
+
+/*!
+ * \brief Dynamic dispatch functor based on TracedObject.
+ * \tparam R The return type of the dispatch functions.
+ * \tparam Args Types of the extra arguments passed through to the dispatch functions.
+ *
+ * This functor dispatches on two things: the runtime type of the object held by the
+ * input TracedObject, and a dispatch token. If nothing is registered for the token,
+ * the function registered for the same type under kDefaultDispatchToken is used.
+ *
+ * The registration methods are named in snake case to be consistent with
+ * tvm/node/functor.h.
+ */
+template <typename R, typename... Args>
+class TracedObjectFunctor {
+ private:
+  using TSelf = TracedObjectFunctor<R, Args...>;
+
+  /*! \brief True if TCallable is usable as `R(TracedObject<TObjectRef>, Args...)`. */
+  template <class TObjectRef, class TCallable>
+  using IsDispatchFunction =
+      typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;
+
+ public:
+  /*!
+   * \brief Call the dispatch function.
+   * \param token The dispatch token.
+   * \param traced_object The traced object. It must hold a non-null object.
+   * \param args Other args, passed through to the dispatch function.
+   *
+   * \return The return value of the dispatch function.
+   *
+   * The function is selected by the exact runtime type index of the held object, not
+   * by the static TObjectRef. If no function is registered for that type with `token`,
+   * the one registered with kDefaultDispatchToken is used. If neither exists, this
+   * fails with an ICHECK error.
+   */
+  template <class TObjectRef>
+  R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
+    // Reading the type index below dereferences the object, so reject a null ref with
+    // a clear error instead of invoking undefined behavior.
+    ICHECK(traced_object.Get().defined())
+        << "TracedObjectFunctor cannot dispatch on a null object (token: " << token << ")";
+    const runtime::PackedFunc& dispatch_function =
+        GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index());
+    return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...);
+  }
+
+  /*!
+   * \brief Set the dispatch function from a type-erased PackedFunc.
+   * \param token The dispatch token.
+   * \param type_index The TVM object type index for this dispatch function.
+   * \param f The dispatch function. It is called with the object ref, its ObjectPath,
+   *        and then the extra Args.
+   * \return A reference to this functor, so calls can be chained.
+   *
+   * This is meant for use across the FFI boundary, for example when registering a
+   * dispatch function from Python. Registering twice for the same token and type
+   * index fails with an ICHECK error.
+   */
+  TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
+    SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
+    return *this;
+  }
+
+  /*!
+   * \brief Set the dispatch function from a typed callable.
+   * \param token The dispatch token.
+   * \param f The dispatch function, with signature `R(TracedObject<TObjectRef>, Args...)`.
+   * \return A reference to this functor, so calls can be chained.
+   *
+   * TObjectRef is deduced from the type of f's first parameter. f is registered for
+   * the runtime type index of TObjectRef's container type.
+   */
+  template <typename TCallable,
+            typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
+            typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
+  TSelf& set_dispatch(String token, TCallable f) {
+    return set_dispatch(
+        std::move(token),                               //
+        TObjectRef::ContainerType::RuntimeTypeIndex(),  //
+        runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
+            [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R {
+              return f(MakeTraced(object, path), args...);
+            }));
+  }
+
+  /*!
+   * \brief Set the default dispatch function.
+   * \param f The dispatch function, with signature `R(TracedObject<TObjectRef>, Args...)`.
+   * \return A reference to this functor, so calls can be chained.
+   *
+   * The function is registered under kDefaultDispatchToken (the empty string). It is
+   * used whenever no function is registered for the requested token.
+   */
+  template <typename TCallable>
+  TSelf& set_dispatch(TCallable&& f) {
+    return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
+  }
+
+ private:
+  /*! \brief Registered functions, keyed by dispatch token and then by type index. */
+  DispatchTable dispatch_table_;
+};
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+#endif  // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc
new file mode 100644
index 0000000000..a018099a1d
--- /dev/null
+++ b/src/script/printer/traced_object_functor.cc
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/script/printer/traced_object_functor.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+namespace {
+
+/*!
+ * \brief Look up the function registered for `type_index` under exactly `token`.
+ * \param table The dispatch table to search.
+ * \param token The dispatch token. No fallback to the default token happens here.
+ * \param type_index The runtime type index of the object being dispatched on.
+ * \return A pointer to the registered function. Returns nullptr if `token` has no
+ *         entry, if `type_index` is past the end of its vector, or if the slot holds
+ *         an undefined function.
+ *
+ * This helper is not declared in the header, so it is given internal linkage to keep
+ * it from clashing with other symbols in this namespace.
+ */
+const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table,
+                                                       const String& token, uint32_t type_index) {
+  auto it = table.find(token);
+  if (it == table.end()) {
+    return nullptr;
+  }
+  const std::vector<runtime::PackedFunc>& functions = it->second;
+  if (type_index >= functions.size()) {
+    return nullptr;
+  }
+  const runtime::PackedFunc& f = functions[type_index];
+  return f.defined() ? &f : nullptr;
+}
+
+}  // namespace
+
+/*!
+ * \brief Resolve the dispatch function for `type_index`.
+ *
+ * Tries `token` first, then kDefaultDispatchToken. If neither has a function for the
+ * type, this fails via ICHECK and names the type key and the token in the message.
+ */
+const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
+                                               const String& token, uint32_t type_index) {
+  const runtime::PackedFunc* found =
+      GetDispatchFunctionForToken(dispatch_table, token, type_index);
+  if (found == nullptr) {
+    // Nothing registered under the requested token: fall back to the default token.
+    found = GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index);
+  }
+  if (found != nullptr) {
+    return *found;
+  }
+  ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
+                << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
+  // ICHECK(false) does not return normally. The bare `throw;` only exists so the
+  // compiler does not warn about a missing return value.
+  throw;
+}
+
+/*!
+ * \brief Register `f` for `type_index` under `token`.
+ *
+ * The per-token vector grows as needed. This fails with an ICHECK error in two
+ * cases: `f` is undefined, or a function is already registered for the same
+ * token and type index.
+ */
+void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
+                         runtime::PackedFunc f) {
+  // An undefined function is indistinguishable from an empty slot. Lookups would then
+  // silently fall back to the default token, and a later registration would be accepted.
+  ICHECK(f != nullptr) << "Cannot register an undefined dispatch function for type: "
+                       << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token
+                       << ")";
+  std::vector<runtime::PackedFunc>& table = (*dispatch_table)[token];
+  if (table.size() <= type_index) {
+    table.resize(type_index + 1, nullptr);
+  }
+  runtime::PackedFunc& slot = table[type_index];
+  if (slot != nullptr) {
+    ICHECK(false) << "Dispatch for type is already registered: "
+                  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
+  }
+  // `f` is a by-value sink parameter, so move it into place instead of copying it
+  // (a copy would bump the PackedFunc's reference count).
+  slot = std::move(f);
+}
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
new file mode 100644
index 0000000000..3fd52d44aa
--- /dev/null
+++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/node/object_path.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/script/printer/traced_object.h>
+#include <tvm/script/printer/traced_object_functor.h>
+
+using namespace tvm;
+using namespace tvm::script::printer;
+
+namespace {
+
+// Two minimal node/ref pairs used as dispatch targets by the tests below. Each node
+// has its own type key, so the functor sees them as different object types.
+class FooObjectNode : public Object {
+ public:
+  // No fields to visit.
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "test.FooObject";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object);
+};
+
+// Reference type for FooObjectNode. The default constructor allocates a fresh node.
+class FooObject : public ObjectRef {
+ public:
+  FooObject() { this->data_ = make_object<FooObjectNode>(); }
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode);
+};
+
+TVM_REGISTER_NODE_TYPE(FooObjectNode);
+
+// Second, unrelated object type, used to check that dispatch distinguishes types.
+class BarObjectNode : public Object {
+ public:
+  // No fields to visit.
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "test.BarObject";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object);
+};
+
+// Reference type for BarObjectNode. The default constructor allocates a fresh node.
+class BarObject : public ObjectRef {
+ public:
+  BarObject() { this->data_ = make_object<BarObjectNode>(); }
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode);
+};
+
+TVM_REGISTER_NODE_TYPE(BarObjectNode);
+
+// A plain free function (rather than a lambda). RegistrationWithFunction uses it to
+// check that set_dispatch accepts ordinary functions.
+String ComputeFoo(TracedObject<FooObject> foo) { return "Foo"; }
+
+}  // anonymous namespace
+
+// Default-token registrations for two types dispatch to their own functions.
+TEST(TracedObjectFunctorTest, NormalRegistration) {
+  TracedObjectFunctor<String> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
+  functor.set_dispatch([](TracedObject<BarObject> o) -> String { return "Bar"; });
+
+  EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
+  EXPECT_EQ(functor("", MakeTraced(BarObject(), path)), "Bar");
+}
+
+// A lambda under the default token and a plain function under "tir" can coexist.
+TEST(TracedObjectFunctorTest, RegistrationWithFunction) {
+  TracedObjectFunctor<String> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
+  functor.set_dispatch("tir", ComputeFoo);
+
+  EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda");
+  EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
+}
+
+// Each token selects its own function. An unknown token falls back to the default one.
+TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) {
+  TracedObjectFunctor<String> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
+  functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+  functor.set_dispatch("relax", [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
+
+  EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
+  EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
+  EXPECT_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax");
+  EXPECT_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo");
+}
+
+// Type-erased PackedFuncs registered by explicit type index are dispatched like typed ones.
+TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) {
+  TracedObjectFunctor<String> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); };
+  auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); };
+
+  functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default));
+  functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir));
+
+  EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "default");
+  EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir");
+}
+
+// Extra arguments are passed through. "tir" has no registration, so it uses the default.
+TEST(TracedObjectFunctorTest, ExtraArg) {
+  TracedObjectFunctor<int, int> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+  functor.set_dispatch([](TracedObject<BarObject> o, int x) { return x + 1; });
+
+  EXPECT_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
+  EXPECT_EQ(functor("", MakeTraced(BarObject(), path), 2), 3);
+  EXPECT_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3);
+}
+
+// Calling with a type that has no registration under any token must fail.
+TEST(TracedObjectFunctorTest, CallWithUnregisteredType) {
+  TracedObjectFunctor<int, int> functor;
+  ObjectPath path = ObjectPath::Root();
+
+  EXPECT_ANY_THROW(functor("", MakeTraced(FooObject(), path), 2));
+}
+
+// Registering the same type twice under the default token must fail.
+TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) {
+  TracedObjectFunctor<int, int> functor;
+
+  functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+
+  EXPECT_ANY_THROW(functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; }));
+}
+
+// Registering the same type twice under the same explicit token must fail.
+TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) {
+  TracedObjectFunctor<int, int> functor;
+
+  functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+
+  EXPECT_ANY_THROW(
+      functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; }));
+}