You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/11 20:43:51 UTC

[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #4218: [Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion

icemelon9 commented on a change in pull request #4218: [Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion
URL: https://github.com/apache/incubator-tvm/pull/4218#discussion_r344899691
 
 

 ##########
 File path: src/relay/pass/eta_expand.cc
 ##########
 @@ -20,57 +20,147 @@
 /*!
  * \file eta_expand.cc
  *
- * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x).
+ * \brief Add abstraction over a constructor or global variable bound to a function.
  *
  */
-#include <tvm/relay/type.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/expr_functor.h>
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
+namespace eta_expand {
+
+/*!
+ * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
+ */
+class TypeVarReplacer : public TypeMutator {
+ public:
+  TypeVarReplacer() : replace_map_({}) {}
 
-Expr EtaExpand(const Expr& e, const Module& mod) {
-  tvm::Array<Var> original_params;
-  tvm::Array<Expr> params;
-  tvm::Array<Var> args;
-  tvm::Array<TypeVar> original_type_params;
-  Type ret_type;
-
-  if (e->IsInstance<GlobalVarNode>()) {
-    auto gvar_node = e.as<GlobalVarNode>();
-    auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
-  } else {
-    CHECK(e->IsInstance<FunctionNode>());
-    auto func = GetRef<Function>(e.as<FunctionNode>());
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
+  Type VisitType_(const TypeVarNode* type_var_node) final {
+    const auto type_var = GetRef<TypeVar>(type_var_node);
+    if (replace_map_.find(type_var) == replace_map_.end()) {
+      replace_map_[type_var] = TypeVarNode::make("A", Kind::kType);
+    }
+    return replace_map_[type_var];
   }
 
-  for (size_t i = 0; i < original_params.size(); ++i) {
-    auto var = VarNode::make("a", original_params[i]->type_annotation);
-    params.push_back(var);
-    args.push_back(var);
+ private:
+  /*! \brief variable replacement map to remap old type vars to fresh ones */
+  std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> replace_map_;
+};
+
+/*!
+ * \brief mutator to perform eta expansion on all functions in a module
+ */
+class EtaExpander : public ExprMutator {
+ public:
+  explicit EtaExpander(
+    const Module& mod,
+    bool expand_constructor,
+    bool expand_global_var)
+      : mod_(mod)
+      , type_var_replacer_(TypeVarReplacer())
 
 Review comment:
  Move the leading comma to the end of the previous line (i.e. `: mod_(mod),` instead of starting the next line with `,`)?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services