You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 20:54:09 UTC
[tvm] branch main updated: [TVMScript] Add more helper functions to the printer infra (#12829)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1ecf084eec [TVMScript] Add more helper functions to the printer infra (#12829)
1ecf084eec is described below
commit 1ecf084eecaff167967df1a8c998de72e1198c24
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Sat Sep 17 16:54:01 2022 -0400
[TVMScript] Add more helper functions to the printer infra (#12829)
This PR is split out from https://github.com/apache/tvm/pull/12492 and makes the printer-infra updates needed by upcoming TIR printer PRs.
Tracking issue: https://github.com/apache/tvm/issues/11912
Co-authored-by: Greg Bonik <gb...@octoml.ai>
---
include/tvm/script/printer/doc.h | 64 +++++++++++++++
include/tvm/script/printer/traced_object_functor.h | 37 +--------
include/tvm/script/printer/var_table.h | 11 +++
src/script/printer/doc.cc | 30 +++++--
src/script/printer/ir_docsifier.cc | 2 +-
src/script/printer/utils.h | 93 ++++++++++++++++++++++
src/script/printer/var_table.cc | 3 +-
tests/cpp/tvmscript_printer_irdocsifier_test.cc | 13 ++-
...tvmscript_printer_traced_object_functor_test.cc | 37 +++++----
9 files changed, 228 insertions(+), 62 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 72f343354b..1ee7fd6a7f 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -22,6 +22,7 @@
#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
+#include <tvm/script/printer/traced_object.h>
namespace tvm {
namespace script {
@@ -87,6 +88,15 @@ class ExprDocNode : public DocNode {
*/
ExprDoc Attr(String attr) const;
+ /*!
+ * \brief Create a doc representing attribute access on the current ExprDoc
+ * \param attr The traced name of the attribute to access.
+ *
+ * Behaves like Attr(String), and additionally appends the ObjectPath of attr
+ * to the source_paths of the returned doc.
+ */
+ ExprDoc Attr(TracedObject<String> attr) const;
+
/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
@@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode {
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value);
+ /*! \brief Construct a LiteralDoc holding value, with object_path recorded in its source_paths. */
+ LiteralDoc(ObjectRef value, ObjectPath object_path);
public:
/*!
@@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc {
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
+  /*!
+   * \brief Create a LiteralDoc for a None/null/empty value.
+   * \param object_path The ObjectPath recorded in the source_paths of the returned doc.
+   */
+  static LiteralDoc None(ObjectPath object_path) {
+    return LiteralDoc(ObjectRef(), object_path);
+  }
+
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*/
static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
+  /*!
+   * \brief Create a LiteralDoc for a traced IntImm, keeping the IntImm as-is.
+   * \param v The traced integer. Its ObjectPath is appended to the doc's source_paths.
+   */
+  static LiteralDoc Int(const TracedObject<IntImm>& v) {
+    return LiteralDoc(v.Get(), v.GetPath());
+  }
+
+  /*!
+   * \brief Create a LiteralDoc for a traced plain int, stored as a 64-bit IntImm.
+   * \param v The traced integer value.
+   * The ObjectPath of v is appended to the source_paths of the returned doc.
+   */
+  static LiteralDoc Int(const TracedBasicValue<int>& v) {
+    IntImm literal(DataType::Int(64), v.Get());
+    return LiteralDoc(literal, v.GetPath());
+  }
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
+  /*!
+   * \brief Create a LiteralDoc for a traced boolean, stored as an IntImm of dtype bool.
+   * \param v The traced boolean value.
+   * The ObjectPath of v is appended to the source_paths of the returned doc.
+   */
+  static LiteralDoc Boolean(const TracedBasicValue<bool>& v) {
+    IntImm literal(DataType::Bool(), v.Get());
+    return LiteralDoc(literal, v.GetPath());
+  }
+
/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
+  /*!
+   * \brief Create a LiteralDoc for a traced FloatImm, keeping the FloatImm as-is.
+   * \param v The traced float value.
+   * The ObjectPath of v is appended to the source_paths of the returned doc.
+   */
+  static LiteralDoc Float(const TracedObject<FloatImm>& v) {
+    FloatImm literal = v.Get();
+    return LiteralDoc(literal, v.GetPath());
+  }
+
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
+  /*!
+   * \brief Create a LiteralDoc for a traced string.
+   * \param v The traced string. Its ObjectPath is appended to the doc's source_paths.
+   */
+  static LiteralDoc Str(const TracedObject<String>& v) {
+    return LiteralDoc(v.Get(), v.GetPath());
+  }
+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
};
diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h
index 6caaf8a6e0..8f72d139a5 100644
--- a/include/tvm/script/printer/traced_object_functor.h
+++ b/include/tvm/script/printer/traced_object_functor.h
@@ -34,35 +34,6 @@ namespace tvm {
namespace script {
namespace printer {
-namespace {
-
-namespace detail {
-/*!
- * \brief Helper template class to extract the type of first argument of a function
- * \tparam FType The function type.
- */
-template <typename FType>
-struct FirstArgTypeGetter;
-
-template <typename R, typename ArgOne, typename... OtherArgs>
-struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
- using T = ArgOne;
-};
-
-/*!
- * \brief Template alias for the type of first argument of a function
- * \tparam FType The function type.
- *
- * The name of public functions are in snake case to be consistent with
- * tvm/node/functor.h
- */
-template <typename FType>
-using FirstArgType = typename detail::FirstArgTypeGetter<
- typename tvm::runtime::detail::function_signature<FType>::FType>::T;
-} // namespace detail
-
-} // namespace
-
/*
* This type alias and the following free functions are created to reduce the binary bloat
* from template and also hide implementation details from this header
@@ -156,8 +127,7 @@ class TracedObjectFunctor {
*
* The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
*/
- template <typename TCallable,
- typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
+ template <typename TObjectRef, typename TCallable,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(String token, TCallable f) {
return set_dispatch(
@@ -177,9 +147,10 @@ class TracedObjectFunctor {
*
* Default dispatch function has an empty string as dispatch token.
*/
- template <typename TCallable>
+ // TObjectRef does not appear in the parameter list, so it cannot be deduced:
+ // callers must name it explicitly, e.g. set_dispatch<MyObjectRef>(f).
+ template <typename TObjectRef, typename TCallable,
+ typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(TCallable&& f) {
- return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
+ return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
}
/*!
diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h
index 9300a976c5..2cd9335213 100644
--- a/include/tvm/script/printer/var_table.h
+++ b/include/tvm/script/printer/var_table.h
@@ -103,6 +103,17 @@ class VarTableNode : public Object {
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;
+  /*!
+   * \brief Get the doc for a traced variable.
+   * \param obj The traced variable object. Its value and ObjectPath are forwarded to
+   *        GetVarDoc(const ObjectRef&, const ObjectPath&).
+   * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
+   */
+  template <typename TObjectRef>
+  Optional<ExprDoc> GetVarDoc(const TracedObject<TObjectRef>& obj) const {
+    return GetVarDoc(obj.Get(), obj.GetPath());
+  }
+
/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index d6f5ff35ab..f3b431bd62 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -27,6 +27,12 @@ namespace printer {
ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }
+ExprDoc ExprDocNode::Attr(TracedObject<String> attr) const {
+  AttrAccessDoc access(GetRef<ExprDoc>(this), attr.Get());
+  access->source_paths.push_back(attr.GetPath());  // record where the attribute name came from
+  return access;
+}
+
ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
return IndexDoc(GetRef<ExprDoc>(this), indices);
}
@@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) {
this->data_ = std::move(n);
}
+LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
+  auto node = make_object<LiteralDocNode>();
+  node->value = value;
+  node->source_paths.push_back(object_path);  // traced origin of this literal
+  data_ = std::move(node);
+}
+
IdDoc::IdDoc(String name) {
ObjectPtr<IdDocNode> n = make_object<IdDocNode>();
n->name = name;
@@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
});
TVM_REGISTER_NODE_TYPE(ExprDocNode);
-TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
+TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr")
+ .set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
.set_body_method<ExprDoc>(&ExprDocNode::operator[]);
TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
@@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
});
TVM_REGISTER_NODE_TYPE(LiteralDocNode);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
+ .set_body_typed<LiteralDoc(int)>(LiteralDoc::Int);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
+ .set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
+ .set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
+ .set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);
TVM_REGISTER_NODE_TYPE(IdDocNode);
TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index b72ed48db6..7f032ec502 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
+ .set_dispatch<RootNodeContainer>([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
new file mode 100644
index 0000000000..abe7ce5e9a
--- /dev/null
+++ b/src/script/printer/utils.h
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
+#define TVM_SCRIPT_PRINTER_UTILS_H_
+
+#include <tvm/script/printer/doc.h>
+#include <tvm/script/printer/ir_docsifier.h>
+
+#include <initializer_list>
+#include <utility>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Convert each element of a traced array into a doc of type DocType.
+ * \tparam DocType The doc type of each result element, e.g. ExprDoc.
+ * \tparam NodeType The element type of the traced array.
+ * \param refs The traced array to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ * \return The converted docs, in the same order as refs.
+ */
+template <typename DocType, typename NodeType>
+Array<DocType> AsDocArray(const TracedArray<NodeType>& refs, const IRDocsifier& ir_docsifier) {
+  Array<DocType> result;
+  for (auto ref : refs) {
+    // Convert to the requested DocType instead of hard-coding ExprDoc.
+    result.push_back(ir_docsifier->AsDoc<DocType>(ref));
+  }
+  return result;
+}
+
+/*!
+ * \brief Convert each element of an initializer list into a doc of type DocType.
+ * \param refs The traced objects to convert. std::initializer_list is a
+ *        lightweight view, so it is taken by value.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ * \return The converted docs, in the same order as refs.
+ */
+template <typename DocType, typename NodeType>
+Array<DocType> AsDocArray(std::initializer_list<NodeType> refs, const IRDocsifier& ir_docsifier) {
+  Array<DocType> result;
+  for (const NodeType& ref : refs) {
+    result.push_back(ir_docsifier->AsDoc<DocType>(ref));
+  }
+  return result;
+}
+
+/*!
+ * \brief Convert each element of a traced array into an ExprDoc.
+ * \param refs The traced array to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ */
+template <typename RefType>
+Array<ExprDoc> AsExprDocArray(const TracedArray<RefType>& refs, const IRDocsifier& ir_docsifier) {
+  return AsDocArray<ExprDoc>(refs, ir_docsifier);
+}
+
+/*!
+ * \brief Convert each element of an initializer list into an ExprDoc.
+ * \param refs The traced objects to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ */
+template <typename RefType>
+Array<ExprDoc> AsExprDocArray(std::initializer_list<RefType> refs,
+                              const IRDocsifier& ir_docsifier) {
+  return AsDocArray<ExprDoc>(refs, ir_docsifier);
+}
+
+/*!
+ * \brief Convert a traced String-to-ObjectRef map into a DictDoc.
+ * \param dict The traced map to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each value.
+ * \return A DictDoc with string-literal keys and converted values, whose
+ *         source_paths include the ObjectPath of dict.
+ */
+inline DictDoc AsDictDoc(const TracedMap<String, ObjectRef>& dict,
+                         const IRDocsifier& ir_docsifier) {
+  Array<ExprDoc> keys;
+  Array<ExprDoc> values;
+
+  for (auto p : dict) {
+    keys.push_back(LiteralDoc::Str(p.first));
+    values.push_back(ir_docsifier->AsExprDoc(p.second));
+  }
+
+  auto doc = DictDoc(keys, values);
+  doc->source_paths.push_back(dict.GetPath());
+  return doc;
+}
+
+/*!
+ * \brief Convert a traced array into a ListDoc of ExprDocs.
+ * \param arr The traced array to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ * \return The ListDoc, whose source_paths include the ObjectPath of arr.
+ */
+template <typename T>
+inline ListDoc AsListDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
+  auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier));
+  ret->source_paths.push_back(arr.GetPath());
+  return ret;
+}
+
+/*!
+ * \brief Convert a traced array into a TupleDoc of ExprDocs.
+ * \param arr The traced array to convert.
+ * \param ir_docsifier The IRDocsifier used to convert each element.
+ * \return The TupleDoc, whose source_paths include the ObjectPath of arr.
+ */
+template <typename T>
+inline TupleDoc AsTupleDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
+  auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier));
+  ret->source_paths.push_back(arr.GetPath());
+  return ret;
+}
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_PRINTER_UTILS_H_
diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc
index 49ba93f9bc..62d8b2f66c 100644
--- a/src/script/printer/var_table.cc
+++ b/src/script/printer/var_table.cc
@@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
obj, [f = std::move(factory)]() { return f(); }, frame);
});
TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
- .set_body_method<VarTable>(&VarTableNode::GetVarDoc);
+ .set_body_method<VarTable, VarTableNode, Optional<ExprDoc>, const ObjectRef&,
+ const ObjectPath&>(&VarTableNode::GetVarDoc);
TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
.set_body_method<VarTable>(&VarTableNode::IsVarDefined);
diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc
index fcdb5ed04e..8c68399df2 100644
--- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc
+++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc
@@ -45,14 +45,19 @@ class TestObject : public ObjectRef {
TVM_REGISTER_NODE_TYPE(TestObjectNode);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch([](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("x"); });
+    .set_dispatch<TestObject>([](TracedObject<TestObject> /*obj*/, IRDocsifier /*p*/) {
+      return IdDoc("x");
+    });
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch("tir", [](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("tir"); });
+    .set_dispatch<TestObject>("tir", [](TracedObject<TestObject> /*obj*/, IRDocsifier /*p*/) {
+      return IdDoc("tir");
+    });
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch("relax",
- [](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("relax"); });
+    .set_dispatch<TestObject>("relax", [](TracedObject<TestObject> /*obj*/, IRDocsifier /*p*/) {
+      return IdDoc("relax");
+    });
TEST(PrinterIRDocsifierTest, AsDoc) {
IRDocsifier p(Map<String, String>{});
diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
index 374eb609b6..d662ce1324 100644
--- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
+++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
@@ -33,7 +33,7 @@ class FooObjectNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {}
- static constexpr const char* _type_key = "test.FooObject";
+ static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject";
TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object);
};
@@ -49,7 +49,7 @@ class BarObjectNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {}
- static constexpr const char* _type_key = "test.BarObject";
+ static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject";
TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object);
};
@@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) {
TracedObjectFunctor<String> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch([](TracedObject<BarObject> o) -> String { return "Bar"; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+ functor.set_dispatch<BarObject>([](TracedObject<BarObject> o) -> String { return "Bar"; });
ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar");
@@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) {
TracedObjectFunctor<String> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
- functor.set_dispatch("tir", ComputeFoo);
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
+ functor.set_dispatch<FooObject>("tir", ComputeFoo);
ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda");
ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
@@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) {
TracedObjectFunctor<String> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
- functor.set_dispatch("relax", [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+ functor.set_dispatch<FooObject>("tir",
+ [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+ functor.set_dispatch<FooObject>("relax",
+ [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
@@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) {
TracedObjectFunctor<int, int> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
- functor.set_dispatch([](TracedObject<BarObject> o, int x) { return x + 1; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch<BarObject>([](TracedObject<BarObject> o, int x) { return x + 1; });
ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3);
@@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) {
TracedObjectFunctor<String> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+ functor.set_dispatch<FooObject>("tir",
+ [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
@@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) {
TracedObjectFunctor<int, int> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
bool failed = false;
try {
- functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
} catch (...) {
failed = true;
}
@@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) {
TracedObjectFunctor<int, int> functor;
ObjectPath path = ObjectPath::Root();
- functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
bool failed = false;
try {
- functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+ functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
} catch (...) {
failed = true;
}