You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/10 05:34:40 UTC
[tvm] branch main updated: [TVMScript] Printer Registry (#12237)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fae79bbc3e [TVMScript] Printer Registry (#12237)
fae79bbc3e is described below
commit fae79bbc3e499f3b9f26c9a13743896f948b723d
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Wed Aug 10 01:34:35 2022 -0400
[TVMScript] Printer Registry (#12237)
This PR:
- Adds the registry of printing functions (traced_object_functor.cc)
Compared to the prototype version, this:
- Consolidates the implementation into a single class, since this class is only for the TVMScript printer.
- Deduces the TObjectRef when calling set_dispatch.
Tracking issue: https://github.com/apache/tvm/issues/11912
Co-authored-by: Greg Bonik <gb...@octoml.ai>
---
include/tvm/script/printer/traced_object.h | 2 +
include/tvm/script/printer/traced_object_functor.h | 183 +++++++++++++++++++++
src/script/printer/traced_object_functor.cc | 75 +++++++++
...tvmscript_printer_traced_object_functor_test.cc | 171 +++++++++++++++++++
4 files changed, 431 insertions(+)
diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h
index 6f04b66cec..4c09b0a41b 100644
--- a/include/tvm/script/printer/traced_object.h
+++ b/include/tvm/script/printer/traced_object.h
@@ -86,6 +86,8 @@ class TracedObject {
using ObjectType = typename RefT::ContainerType;
public:
+ using ObjectRefType = RefT;
+
// Don't use this direcly. For convenience, call MakeTraced() instead.
explicit TracedObject(const RefT& object_ref, ObjectPath path)
: ref_(object_ref), path_(std::move(path)) {}
diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h
new file mode 100644
index 0000000000..05fbbf79f2
--- /dev/null
+++ b/include/tvm/script/printer/traced_object_functor.h
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
+#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
+
+#include <tvm/node/node.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/script/printer/traced_object.h>
+
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// These helpers live in a named namespace on purpose: an anonymous namespace in a header
+// gives every including translation unit its own copy, while TracedObjectFunctor (external
+// linkage) names them in its default template arguments, which is an ODR hazard.
+namespace detail {
+/*!
+ * \brief Helper template class to extract the type of first argument of a function
+ * \tparam FType The function type.
+ */
+template <typename FType>
+struct FirstArgTypeGetter;
+
+template <typename R, typename ArgOne, typename... OtherArgs>
+struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
+ using T = ArgOne;
+};
+
+/*!
+ * \brief Template alias for the type of first argument of a function
+ * \tparam FType The function or callable type (e.g. a lambda); its signature is resolved
+ * through tvm::runtime::detail::function_signature.
+ *
+ * For example, for a lambda taking `(TracedObject<SomeRef> o, int x)` this is
+ * `TracedObject<SomeRef>`.
+ */
+template <typename FType>
+using FirstArgType = typename detail::FirstArgTypeGetter<
+ typename tvm::runtime::detail::function_signature<FType>::FType>::T;
+} // namespace detail
+
+/*! \brief Dispatch token under which default (fallback) dispatch functions are registered. */
+constexpr const char* kDefaultDispatchToken = "";
+
+// DispatchTable maps a dispatch token to dispatch functions indexed by object type index.
+// Its contents are only read and written through the out-of-line free functions below, which
+// reduces binary bloat from templates and keeps implementation details out of this header.
+using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
+
+/*!
+ * \brief Look up the dispatch function for (token, type_index), falling back to
+ * kDefaultDispatchToken when nothing is registered for type_index under the requested token.
+ * \param dispatch_table The dispatch table.
+ * \param token The dispatch token.
+ * \param type_index The type index of the Object type to be dispatched.
+ * \return The dispatch function. Fails (ICHECK) if neither token has one for type_index.
+ */
+const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
+ const String& token, uint32_t type_index);
+
+/*!
+ * \brief Register a dispatch function. Fails (ICHECK) if the slot is already occupied.
+ * \param dispatch_table The dispatch table to modify.
+ * \param token The dispatch token.
+ * \param type_index The type index of the Object type to be dispatched.
+ * \param f The dispatch function.
+ */
+void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
+ runtime::PackedFunc f);
+
+/*!
+ * \brief Dynamic dispatch functor based on TracedObject.
+ *
+ * Dispatches on the input dispatch token and on the runtime type index of the object held by
+ * the input TracedObject. The type index must match exactly; parent types are not searched.
+ * Public member functions are in snake case to be consistent with tvm/node/functor.h.
+ */
+template <typename R, typename... Args>
+class TracedObjectFunctor {
+ private:
+ using TSelf = TracedObjectFunctor<R, Args...>;
+
+ template <class TObjectRef, class TCallable>
+ using IsDispatchFunction =
+ typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;
+
+ public:
+ /*!
+ * \brief Call the dispatch function.
+ * \param token The dispatch token.
+ * \param traced_object The traced object. It must not hold a null reference.
+ * \param args Other args.
+ * \return The return value of the dispatch function.
+ *
+ * If no function is registered for the object's type under `token`, the one registered
+ * under kDefaultDispatchToken is used; if there is none either, this fails (ICHECK).
+ */
+ template <class TObjectRef>
+ R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
+ // Fail cleanly instead of dereferencing a null object in type_index() below.
+ ICHECK(traced_object.Get().defined())
+ << "TracedObjectFunctor cannot dispatch on a null object (token: " << token << ")";
+ const runtime::PackedFunc& dispatch_function =
+ GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index());
+ return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...);
+ }
+
+ /*!
+ * \brief Set the dispatch function from a type-erased PackedFunc.
+ * \param token The dispatch token.
+ * \param type_index The TVM object type index for this dispatch function.
+ * \param f The dispatch function; it is called with (object, path, args...).
+ * \return *this, so that calls can be chained.
+ * Intended for use through the FFI boundary, e.g. registering dispatch functions from Python.
+ */
+ TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
+ SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
+ return *this;
+ }
+
+ /*!
+ * \brief Set the dispatch function for the TObjectRef deduced from the first parameter of `f`.
+ * \param token The dispatch token.
+ * \param f The dispatch function, with signature `R(TracedObject<TObjectRef>, Args...)`.
+ * \return *this, so that calls can be chained.
+ */
+ template <typename TCallable,
+ typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
+ typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
+ TSelf& set_dispatch(String token, TCallable f) {
+ return set_dispatch(
+ std::move(token), //
+ TObjectRef::ContainerType::RuntimeTypeIndex(), //
+ runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
+ [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R {
+ return f(MakeTraced(object, path), args...);
+ }));
+ }
+
+ /*!
+ * \brief Set the default dispatch function, i.e. the one registered under
+ * kDefaultDispatchToken (the empty string). It is used when no function is registered
+ * for the object's type under the requested dispatch token.
+ * \param f The dispatch function.
+ * \return *this, so that calls can be chained.
+ */
+ template <typename TCallable>
+ TSelf& set_dispatch(TCallable&& f) {
+ return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
+ }
+
+ private:
+ DispatchTable dispatch_table_;
+};
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc
new file mode 100644
index 0000000000..a018099a1d
--- /dev/null
+++ b/src/script/printer/traced_object_functor.cc
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/script/printer/traced_object_functor.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+namespace {
+/*!
+ * \brief Find the function registered for exactly (token, type_index).
+ * \return The function, or nullptr if no defined function occupies that slot.
+ */
+const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table,
+ const String& token, uint32_t type_index) {
+ auto it = table.find(token);
+ if (it == table.end()) {
+ return nullptr;
+ }
+ const std::vector<runtime::PackedFunc>& tab = it->second;
+ if (type_index >= tab.size()) {
+ return nullptr;
+ }
+ const runtime::PackedFunc* f = &tab[type_index];
+ return f->defined() ? f : nullptr;
+}
+
+} // namespace
+
+const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
+ const String& token, uint32_t type_index) {
+ const runtime::PackedFunc* pf = GetDispatchFunctionForToken(dispatch_table, token, type_index);
+ if (pf == nullptr) {
+ // Fallback to function with the default dispatch token
+ pf = GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index);
+ }
+ // ICHECK throws on failure, so there is no unreachable path needing a bare `throw;`.
+ ICHECK(pf != nullptr) << "ObjectFunctor calls un-registered function on type: "
+ << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
+ return *pf;
+}
+
+void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
+ runtime::PackedFunc f) {
+ // Slots are indexed by type index; grow the per-token vector with empty functions as needed.
+ std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
+ if (table->size() <= type_index) {
+ table->resize(type_index + 1, nullptr);
+ }
+ runtime::PackedFunc& slot = (*table)[type_index];
+ // Each (token, type index) pair may be registered at most once.
+ ICHECK(slot == nullptr) << "Dispatch for type is already registered: "
+ << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
+ slot = std::move(f);
+}
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
new file mode 100644
index 0000000000..3fd52d44aa
--- /dev/null
+++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/node/object_path.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/script/printer/traced_object.h>
+#include <tvm/script/printer/traced_object_functor.h>
+
+using namespace tvm;
+using namespace tvm::script::printer;
+
+namespace {
+
+// Minimal object type with no attributes (empty VisitAttrs); a dispatch target in the tests below.
+class FooObjectNode : public Object {
+ public:
+ void VisitAttrs(AttrVisitor* v) {}
+
+ static constexpr const char* _type_key = "test.FooObject";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object);
+};
+
+class FooObject : public ObjectRef {
+ public:
+ // The default constructor allocates a fresh node, so FooObject() never holds null.
+ FooObject() { this->data_ = make_object<FooObjectNode>(); }
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode);
+};
+
+TVM_REGISTER_NODE_TYPE(FooObjectNode);
+
+// A second, distinct object type so the tests can check that dispatch depends on the type.
+class BarObjectNode : public Object {
+ public:
+ void VisitAttrs(AttrVisitor* v) {}
+
+ static constexpr const char* _type_key = "test.BarObject";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object);
+};
+
+class BarObject : public ObjectRef {
+ public:
+ // The default constructor allocates a fresh node, so BarObject() never holds null.
+ BarObject() { this->data_ = make_object<BarObjectNode>(); }
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode);
+};
+
+TVM_REGISTER_NODE_TYPE(BarObjectNode);
+
+// A free function (rather than a lambda), used to check that set_dispatch also accepts one.
+String ComputeFoo(TracedObject<FooObject> foo) { return "Foo"; }
+
+} // anonymous namespace
+
+TEST(TracedObjectFunctorTest, NormalRegistration) {
+ TracedObjectFunctor<String> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
+ functor.set_dispatch([](TracedObject<BarObject> o) -> String { return "Bar"; });
+
+ // EXPECT_EQ (rather than ICHECK_EQ) reports a readable gtest failure and keeps checking.
+ EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
+ EXPECT_EQ(functor("", MakeTraced(BarObject(), path)), "Bar");
+}
+
+TEST(TracedObjectFunctorTest, RegistrationWithFunction) {
+ TracedObjectFunctor<String> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
+ // A free function can be registered just like a lambda.
+ functor.set_dispatch("tir", ComputeFoo);
+
+ EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda");
+ EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
+}
+
+TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) {
+ TracedObjectFunctor<String> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
+ functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+ functor.set_dispatch("relax", [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
+
+ EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
+ EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
+ EXPECT_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax");
+ // Nothing is registered under "xyz", so the default-token function is used.
+ EXPECT_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo");
+}
+
+TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) {
+ TracedObjectFunctor<String> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); };
+ auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); };
+
+ // Type-erased registration: the type index is given explicitly instead of being deduced.
+ functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default));
+ functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir));
+
+ EXPECT_EQ(functor("", MakeTraced(FooObject(), path)), "default");
+ EXPECT_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir");
+}
+
+TEST(TracedObjectFunctorTest, ExtraArg) {
+ TracedObjectFunctor<int, int> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch([](TracedObject<BarObject> o, int x) { return x + 1; });
+
+ // Extra arguments are forwarded to the dispatch function.
+ EXPECT_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
+ EXPECT_EQ(functor("", MakeTraced(BarObject(), path), 2), 3);
+ // "tir" has no registration, so the default-token function is used.
+ EXPECT_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3);
+}
+
+TEST(TracedObjectFunctorTest, CallWithUnregisteredType) {
+ TracedObjectFunctor<int, int> functor;
+ ObjectPath path = ObjectPath::Root();
+
+ // Nothing is registered for FooObject under any token, so the call itself must throw.
+ // Only the call is checked: wrapping a value check (ICHECK_EQ) in the same try block
+ // would let a wrong return value, which also throws, make this test pass.
+ EXPECT_ANY_THROW(functor("", MakeTraced(FooObject(), path), 2));
+}
+
+TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) {
+ TracedObjectFunctor<int, int> functor;
+
+ functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+
+ // A second registration for the same (default token, FooObject) slot must be rejected.
+ EXPECT_ANY_THROW(functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; }));
+}
+
+TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) {
+ TracedObjectFunctor<int, int> functor;
+
+ functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+
+ // A second registration for the same ("tir", FooObject) slot must be rejected.
+ EXPECT_ANY_THROW(
+ functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; }));
+}