You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that this plain-text rendering does not preserve.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 20:54:09 UTC

[tvm] branch main updated: [TVMScript] Add more helper functions to the printer infra (#12829)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ecf084eec [TVMScript] Add more helper functions to the printer infra  (#12829)
1ecf084eec is described below

commit 1ecf084eecaff167967df1a8c998de72e1198c24
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Sat Sep 17 16:54:01 2022 -0400

    [TVMScript] Add more helper functions to the printer infra  (#12829)
    
    This PR is split from https://github.com/apache/tvm/pull/12492 to make the printer-infra updates needed by upcoming TIR printer PRs.
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
    
    Co-authored-by: Greg Bonik <gb...@octoml.ai>
---
 include/tvm/script/printer/doc.h                   | 64 +++++++++++++++
 include/tvm/script/printer/traced_object_functor.h | 37 +--------
 include/tvm/script/printer/var_table.h             | 11 +++
 src/script/printer/doc.cc                          | 30 +++++--
 src/script/printer/ir_docsifier.cc                 |  2 +-
 src/script/printer/utils.h                         | 93 ++++++++++++++++++++++
 src/script/printer/var_table.cc                    |  3 +-
 tests/cpp/tvmscript_printer_irdocsifier_test.cc    | 13 ++-
 ...tvmscript_printer_traced_object_functor_test.cc | 37 +++++----
 9 files changed, 228 insertions(+), 62 deletions(-)

diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 72f343354b..1ee7fd6a7f 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -22,6 +22,7 @@
 #include <tvm/ir/expr.h>
 #include <tvm/node/node.h>
 #include <tvm/runtime/data_type.h>
+#include <tvm/script/printer/traced_object.h>
 
 namespace tvm {
 namespace script {
@@ -87,6 +88,15 @@ class ExprDocNode : public DocNode {
    */
   ExprDoc Attr(String attr) const;
 
+  /*!
+   * \brief Create a doc representing attribute access on the current ExprDoc
+   * \param attr The attribute to access.
+   *
+   * The ObjectPath of attr will be pushed to the source_path of the returned
+   * doc.
+   */
+  ExprDoc Attr(TracedObject<String> attr) const;
+
   /*!
    * \brief Create a doc representing index access on the current ExprDoc
    * \param indices The indices to access.
@@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode {
 class LiteralDoc : public ExprDoc {
  protected:
   explicit LiteralDoc(ObjectRef value);
+  LiteralDoc(ObjectRef value, ObjectPath object_path);
 
  public:
   /*!
@@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc {
    */
   static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
 
+  /*!
+   * \brief Create a LiteralDoc to represent None/null/empty value.
+   * \param object_path The source path of the returned Doc.
+   */
+  static LiteralDoc None(ObjectPath object_path) {
+    return LiteralDoc(ObjectRef(nullptr), object_path);
+  }
+
   /*!
    * \brief Create a LiteralDoc to represent integer.
    * \param v The integer value.
    */
   static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
 
+  /*!
+   * \brief Create a LiteralDoc to represent integer.
+   * \param v The integer value.
+   *
+   * The ObjectPath of v will be pushed to the source_path of the returned doc.
+   */
+  static LiteralDoc Int(const TracedObject<IntImm>& v) { return LiteralDoc(v.Get(), v.GetPath()); }
+
+  /*!
+   * \brief Create a LiteralDoc to represent integer.
+   * \param v The integer value.
+   *
+   * The ObjectPath of v will be pushed to the source_path of the returned doc.
+   */
+  static LiteralDoc Int(const TracedBasicValue<int>& v) {
+    return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath());
+  }
   /*!
    * \brief Create a LiteralDoc to represent boolean.
    * \param v The boolean value.
    */
   static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
 
+  /*!
+   * \brief Create a LiteralDoc to represent boolean.
+   * \param v The boolean value.
+   *
+   * The ObjectPath of v will be pushed to the source_path of the returned doc.
+   */
+  static LiteralDoc Boolean(const TracedBasicValue<bool>& v) {
+    return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath());
+  }
+
   /*!
    * \brief Create a LiteralDoc to represent float.
    * \param v The float value.
    */
   static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
 
+  /*!
+   * \brief Create a LiteralDoc to represent float.
+   * \param v The float value.
+   *
+   * The ObjectPath of v will be pushed to the source_path of the returned doc.
+   */
+  static LiteralDoc Float(const TracedObject<FloatImm>& v) {
+    return LiteralDoc(v.Get(), v.GetPath());
+  }
+
   /*!
    * \brief Create a LiteralDoc to represent string.
    * \param v The string value.
    */
   static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
 
+  /*!
+   * \brief Create a LiteralDoc to represent string.
+   * \param v The string value.
+   *
+   * The ObjectPath of v will be pushed to the source_path of the returned doc.
+   */
+  static LiteralDoc Str(const TracedObject<String>& v) { return LiteralDoc(v.Get(), v.GetPath()); }
+
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
 };
 
diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h
index 6caaf8a6e0..8f72d139a5 100644
--- a/include/tvm/script/printer/traced_object_functor.h
+++ b/include/tvm/script/printer/traced_object_functor.h
@@ -34,35 +34,6 @@ namespace tvm {
 namespace script {
 namespace printer {
 
-namespace {
-
-namespace detail {
-/*!
- * \brief Helper template class to extract the type of first argument of a function
- * \tparam FType The function type.
- */
-template <typename FType>
-struct FirstArgTypeGetter;
-
-template <typename R, typename ArgOne, typename... OtherArgs>
-struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
-  using T = ArgOne;
-};
-
-/*!
- * \brief Template alias for the type of first argument of a function
- * \tparam FType The function type.
- *
- * The name of public functions are in snake case to be consistent with
- * tvm/node/functor.h
- */
-template <typename FType>
-using FirstArgType = typename detail::FirstArgTypeGetter<
-    typename tvm::runtime::detail::function_signature<FType>::FType>::T;
-}  // namespace detail
-
-}  // namespace
-
 /*
  * This type alias and the following free functions are created to reduce the binary bloat
  * from template and also hide implementation details from this header
@@ -156,8 +127,7 @@ class TracedObjectFunctor {
    *
    * The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
    */
-  template <typename TCallable,
-            typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
+  template <typename TObjectRef, typename TCallable,
             typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
   TSelf& set_dispatch(String token, TCallable f) {
     return set_dispatch(
@@ -177,9 +147,10 @@ class TracedObjectFunctor {
    *
    * Default dispatch function has an empty string as dispatch token.
    */
-  template <typename TCallable>
+  template <typename TObjectRef, typename TCallable,
+            typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
   TSelf& set_dispatch(TCallable&& f) {
-    return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
+    return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
   }
 
   /*!
diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h
index 9300a976c5..2cd9335213 100644
--- a/include/tvm/script/printer/var_table.h
+++ b/include/tvm/script/printer/var_table.h
@@ -103,6 +103,17 @@ class VarTableNode : public Object {
    */
   Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;
 
+  /*!
+   * \brief Get the doc for a traced variable; forwards obj.Get() and obj.GetPath().
+   * \param obj The traced variable object (taken by const reference, so it is not copied).
+   *
+   * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
+   */
+  template <typename TObjectRef>
+  Optional<ExprDoc> GetVarDoc(const TracedObject<TObjectRef>& obj) const {
+    return GetVarDoc(obj.Get(), obj.GetPath());
+  }
+
   /*!
    * \brief Check if a variable exists in the table.
    * \param obj The variable object.
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index d6f5ff35ab..f3b431bd62 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -27,6 +27,12 @@ namespace printer {
 
 ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }
 
+ExprDoc ExprDocNode::Attr(TracedObject<String> attr) const {
+  ExprDoc access = this->Attr(attr.Get());  // reuse the untraced overload, then tag the path
+  access->source_paths.push_back(attr.GetPath());
+  return access;
+}
+
 ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
   return IndexDoc(GetRef<ExprDoc>(this), indices);
 }
@@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) {
   this->data_ = std::move(n);
 }
 
+LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
+  auto node = make_object<LiteralDocNode>();
+  node->source_paths.push_back(object_path);
+  node->value = value;
+  data_ = std::move(node);
+}
+
 IdDoc::IdDoc(String name) {
   ObjectPtr<IdDocNode> n = make_object<IdDocNode>();
   n->name = name;
@@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
     });
 
 TVM_REGISTER_NODE_TYPE(ExprDocNode);
-TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
+TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr")
+    .set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr);
 TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
     .set_body_method<ExprDoc>(&ExprDocNode::operator[]);
 TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
@@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
 });
 
 TVM_REGISTER_NODE_TYPE(LiteralDocNode);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
-TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
+    .set_body_typed<LiteralDoc(int)>(LiteralDoc::Int);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
+    .set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
+    .set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
+    .set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);
 
 TVM_REGISTER_NODE_TYPE(IdDocNode);
 TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index b72ed48db6..7f032ec502 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
 //     });
 // \endcode
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
+    .set_dispatch<RootNodeContainer>([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
       String top_dispatch_token = p->dispatch_tokens.back();
       ICHECK_NE(top_dispatch_token, "");
       ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
new file mode 100644
index 0000000000..abe7ce5e9a
--- /dev/null
+++ b/src/script/printer/utils.h
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
+#define TVM_SCRIPT_PRINTER_UTILS_H_
+
+#include <tvm/script/printer/doc.h>
+#include <tvm/script/printer/ir_docsifier.h>
+
+#include <initializer_list>
+#include <utility>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Convert each element of a traced array into a doc of type DocType.
+ * \tparam DocType The Doc type each element is converted to (e.g. ExprDoc).
+ * \param refs The traced array to convert.
+ * \param ir_docsifier The IRDocsifier that performs each conversion.
+ * \return The converted docs, in the same order as refs.
+ */
+template <typename DocType, typename NodeType>
+Array<DocType> AsDocArray(const TracedArray<NodeType>& refs, const IRDocsifier& ir_docsifier) {
+  Array<DocType> result;
+  for (const auto& ref : refs) {
+    // AsDoc<DocType> (rather than AsExprDoc) keeps the element type in sync with
+    // Array<DocType>; AsExprDoc only type-checks when DocType is ExprDoc.
+    result.push_back(ir_docsifier->AsDoc<DocType>(ref));
+  }
+  return result;
+}
+
+/*!
+ * \brief Convert each element of a braced list of traced objects into a doc of type DocType.
+ * \tparam DocType The Doc type each element is converted to (e.g. ExprDoc).
+ * \param refs The traced objects to convert. std::initializer_list is a cheap view, so it is
+ *        taken by value; this accepts both braced lists and named initializer_list objects.
+ * \param ir_docsifier The IRDocsifier that performs each conversion.
+ * \return The converted docs, in the same order as refs.
+ */
+template <typename DocType, typename NodeType>
+Array<DocType> AsDocArray(std::initializer_list<NodeType> refs, const IRDocsifier& ir_docsifier) {
+  Array<DocType> result;
+  for (const auto& ref : refs) {
+    result.push_back(ir_docsifier->AsDoc<DocType>(ref));
+  }
+  return result;
+}
+
+/*!
+ * \brief Convert each element of a traced array into an ExprDoc.
+ * \param refs The traced array to convert.
+ * \param ir_docsifier The IRDocsifier that performs each conversion.
+ */
+template <typename RefType>
+Array<ExprDoc> AsExprDocArray(const TracedArray<RefType>& refs, const IRDocsifier& ir_docsifier) {
+  return AsDocArray<ExprDoc>(refs, ir_docsifier);
+}
+
+/*!
+ * \brief Convert each element of a braced list of traced objects into an ExprDoc.
+ * \param refs The traced objects to convert.
+ * \param ir_docsifier The IRDocsifier that performs each conversion.
+ */
+template <typename RefType>
+Array<ExprDoc> AsExprDocArray(std::initializer_list<RefType> refs,
+                              const IRDocsifier& ir_docsifier) {
+  return AsDocArray<ExprDoc>(refs, ir_docsifier);
+}
+
+/*!
+ * \brief Convert a traced String-keyed map into a DictDoc.
+ * \param dict The traced map. Keys become string literals; values are converted with
+ *        AsExprDoc. Entries are emitted in the map's iteration order.
+ * \param ir_docsifier The IRDocsifier that converts the values.
+ * \return The DictDoc, with the path of dict appended to its source_paths.
+ */
+inline DictDoc AsDictDoc(const TracedMap<String, ObjectRef>& dict,
+                         const IRDocsifier& ir_docsifier) {
+  Array<ExprDoc> keys;
+  Array<ExprDoc> values;
+
+  for (const auto& p : dict) {
+    keys.push_back(LiteralDoc::Str(p.first));
+    values.push_back(ir_docsifier->AsExprDoc(p.second));
+  }
+
+  auto doc = DictDoc(keys, values);
+  doc->source_paths.push_back(dict.GetPath());
+  return doc;
+}
+
+/*!
+ * \brief Convert a traced array into a ListDoc of ExprDocs.
+ * \param arr The traced array to convert.
+ * \param ir_docsifier The IRDocsifier that converts each element.
+ * \return The ListDoc, with the path of arr appended to its source_paths.
+ */
+template <typename T>
+inline ListDoc AsListDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
+  auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier));
+  ret->source_paths.push_back(arr.GetPath());
+  return ret;
+}
+
+/*!
+ * \brief Convert a traced array into a TupleDoc of ExprDocs.
+ * \param arr The traced array to convert.
+ * \param ir_docsifier The IRDocsifier that converts each element.
+ * \return The TupleDoc, with the path of arr appended to its source_paths.
+ */
+template <typename T>
+inline TupleDoc AsTupleDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
+  auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier));
+  ret->source_paths.push_back(arr.GetPath());
+  return ret;
+}
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_PRINTER_UTILS_H_
diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc
index 49ba93f9bc..62d8b2f66c 100644
--- a/src/script/printer/var_table.cc
+++ b/src/script/printer/var_table.cc
@@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
           obj, [f = std::move(factory)]() { return f(); }, frame);
     });
 TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
-    .set_body_method<VarTable>(&VarTableNode::GetVarDoc);
+    .set_body_method<VarTable, VarTableNode, Optional<ExprDoc>, const ObjectRef&,
+                     const ObjectPath&>(&VarTableNode::GetVarDoc);
 TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
     .set_body_method<VarTable>(&VarTableNode::IsVarDefined);
 
diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc
index fcdb5ed04e..8c68399df2 100644
--- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc
+++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc
@@ -45,14 +45,19 @@ class TestObject : public ObjectRef {
 TVM_REGISTER_NODE_TYPE(TestObjectNode);
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch([](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("x"); });
+    .set_dispatch<TestObject>([](TracedObject<TestObject> obj, IRDocsifier p) {
+      return IdDoc("x");
+    });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch("tir", [](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("tir"); });
+    .set_dispatch<TestObject>("tir", [](TracedObject<TestObject> obj, IRDocsifier p) {
+      return IdDoc("tir");
+    });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-    .set_dispatch("relax",
-                  [](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("relax"); });
+    .set_dispatch<TestObject>("relax", [](TracedObject<TestObject> obj, IRDocsifier p) {
+      return IdDoc("relax");
+    });
 
 TEST(PrinterIRDocsifierTest, AsDoc) {
   IRDocsifier p(Map<String, String>{});
diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
index 374eb609b6..d662ce1324 100644
--- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
+++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
@@ -33,7 +33,7 @@ class FooObjectNode : public Object {
  public:
   void VisitAttrs(AttrVisitor* v) {}
 
-  static constexpr const char* _type_key = "test.FooObject";
+  static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject";
   TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object);
 };
 
@@ -49,7 +49,7 @@ class BarObjectNode : public Object {
  public:
   void VisitAttrs(AttrVisitor* v) {}
 
-  static constexpr const char* _type_key = "test.BarObject";
+  static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject";
   TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object);
 };
 
@@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) {
   TracedObjectFunctor<String> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
-  functor.set_dispatch([](TracedObject<BarObject> o) -> String { return "Bar"; });
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+  functor.set_dispatch<BarObject>([](TracedObject<BarObject> o) -> String { return "Bar"; });
 
   ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
   ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar");
@@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) {
   TracedObjectFunctor<String> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
-  functor.set_dispatch("tir", ComputeFoo);
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
+  functor.set_dispatch<FooObject>("tir", ComputeFoo);
 
   ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda");
   ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
@@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) {
   TracedObjectFunctor<String> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
-  functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
-  functor.set_dispatch("relax", [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+  functor.set_dispatch<FooObject>("tir",
+                                  [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+  functor.set_dispatch<FooObject>("relax",
+                                  [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
 
   ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
   ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
@@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) {
   TracedObjectFunctor<int, int> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
-  functor.set_dispatch([](TracedObject<BarObject> o, int x) { return x + 1; });
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
+  functor.set_dispatch<BarObject>([](TracedObject<BarObject> o, int x) { return x + 1; });
 
   ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
   ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3);
@@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) {
   TracedObjectFunctor<String> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o) -> String { return "Foo"; });
-  functor.set_dispatch("tir", [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
+  functor.set_dispatch<FooObject>("tir",
+                                  [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
 
   ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
   ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
@@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) {
   TracedObjectFunctor<int, int> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+  functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
 
   bool failed = false;
   try {
-    functor.set_dispatch([](TracedObject<FooObject> o, int x) { return x; });
+    functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
   } catch (...) {
     failed = true;
   }
@@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) {
   TracedObjectFunctor<int, int> functor;
   ObjectPath path = ObjectPath::Root();
 
-  functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+  functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
 
   bool failed = false;
   try {
-    functor.set_dispatch("tir", [](TracedObject<FooObject> o, int x) { return x; });
+    functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
   } catch (...) {
     failed = true;
   }