You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/04 17:04:15 UTC

[GitHub] [incubator-tvm] hypercubestart opened a new pull request #6400: [Relay] Add Defunctionalization Pass

hypercubestart opened a new pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400


   introduces a defunctionalization pass based upon Type-Driven Defunctionalization.
   
   Currently it makes a number of assumptions about the program; it will be extended in future work.
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#discussion_r486048835



##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * so we clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone l = map’ “incr” l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// determine if type contains a FuncType
+bool HasFuncType(const Type& t) {
+  // Records that a FuncType was reached; it does not descend into that FuncType any further.
+  struct FuncTypeVisitor : TypeVisitor {
+    bool has_func_type = false;
+
+    void VisitType_(const FuncTypeNode* /*op*/) override { this->has_func_type = true; }
+  };
+
+  FuncTypeVisitor visitor;
+  visitor.VisitType(t);
+  return visitor.has_func_type;
+}
+// determine if FuncType is a higher order type: some argument type or the return type
+// contains a FuncType
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK_EQ(call->type_args.size(), op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        args.push_back(EncodeArg(arg, type));

Review comment:
       if/else is a better style than continue.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#discussion_r485019376



##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * so we clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone l = map’ “incr” l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that records whether a FuncType occurs anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  // becomes true once a FuncType has been reached
+  bool has_func_type = false;
+
+  // no need to descend into the FuncType's own argument/return types: the answer is known
+  void VisitType_(const FuncTypeNode* /*op*/) override { this->has_func_type = true; }
+};
+// determine if expr contains a FuncType
+bool HasFuncType(const Expr& e) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(e->checked_type());

Review comment:
       can you call the type version of this instead?

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * so we clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone l = map’ “incr” l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that records whether a FuncType occurs anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  // becomes true once a FuncType has been reached
+  bool has_func_type = false;
+
+  // no need to descend into the FuncType's own argument/return types: the answer is known
+  void VisitType_(const FuncTypeNode* /*op*/) override { this->has_func_type = true; }
+};
+// determine if the checked type of expr contains a FuncType
+bool HasFuncType(const Type& t);
+bool HasFuncType(const Expr& e) {
+  // reuse the Type overload rather than duplicating the visitor logic
+  return HasFuncType(e->checked_type());
+}
+// Returns true when type t mentions a FuncType anywhere inside it.
+bool HasFuncType(const Type& t) {
+  FuncTypeVisitor finder;
+  finder.VisitType(t);
+  return finder.has_func_type;
+}
+// determine if FuncType is a higher order type: some argument type or the return type
+// contains a FuncType
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())

Review comment:
       CHECK_EQ

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * so we clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone l = map’ “incr” l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that records whether a FuncType occurs anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  // becomes true once a FuncType has been reached
+  bool has_func_type = false;
+
+  // no need to descend into the FuncType's own argument/return types: the answer is known
+  void VisitType_(const FuncTypeNode* /*op*/) override { this->has_func_type = true; }
+};
+// determine if the checked type of expr contains a FuncType
+bool HasFuncType(const Type& t);
+bool HasFuncType(const Expr& e) {
+  // reuse the Type overload rather than duplicating the visitor logic
+  return HasFuncType(e->checked_type());
+}
+// Returns true when type t mentions a FuncType anywhere inside it.
+bool HasFuncType(const Type& t) {
+  FuncTypeVisitor finder;
+  finder.VisitType(t);
+  return finder.has_func_type;
+}
+// determine if FuncType is a higher order type: some argument type or the return type
+// contains a FuncType
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        // we assume arg is either an identifier (var or globalvar) or a function
+        CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
+        CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
+            << "assume all first-order-parameters are identifiers or functions";
+
+        if (arg.as<VarNode>()) {
+          // variable with functype will be encoded as datatype in surrounding function
+          args.push_back(arg);
+        }
+        if (arg.as<GlobalVarNode>()) {

Review comment:
       else if

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * so we clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone l = map’ “incr” l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that records whether a FuncType occurs anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  // becomes true once a FuncType has been reached
+  bool has_func_type = false;
+
+  // no need to descend into the FuncType's own argument/return types: the answer is known
+  void VisitType_(const FuncTypeNode* /*op*/) override { this->has_func_type = true; }
+};
+// determine if the checked type of expr contains a FuncType
+bool HasFuncType(const Type& t);
+bool HasFuncType(const Expr& e) {
+  // reuse the Type overload rather than duplicating the visitor logic
+  return HasFuncType(e->checked_type());
+}
+// Returns true when type t mentions a FuncType anywhere inside it.
+bool HasFuncType(const Type& t) {
+  FuncTypeVisitor finder;
+  finder.VisitType(t);
+  return finder.has_func_type;
+}
+// determine if FuncType is a higher order type: some argument type or the return type
+// contains a FuncType
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  /*!
+   * \brief defunctionalize a call expression
+   *
+   * - call to a global function whose instantiated type is higher order: the call is redirected
+   *   to a clone of that function specialized to the call's type arguments, and every
+   *   function-typed argument is replaced by a value of the corresponding encoding datatype;
+   * - call to an anonymous function: the arguments are substituted into the body, which is then
+   *   visited;
+   * - call to a variable (encoded as a datatype): the call goes through the matching `apply`
+   *   helper.
+   * Every other call is handled by ExprMutator.
+   */
+  Expr VisitExpr_(const CallNode* call) override {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      auto op_func_type = op->checked_type().as<FuncTypeNode>();
+      CHECK(op_func_type) << "called global var does not have a function type";
+      CHECK_EQ(call->type_args.size(), op_func_type->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op_func_type, call->type_args);
+      CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0U) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+      CHECK_EQ(call->args.size(), op_type->arg_types.size())
+          << "number of call arguments does not match the function type";
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        args.push_back(EncodeArg(call->args[i], op_type->arg_types[i]));
+      }
+
+      auto name = op->name_hint + TypeToString(op_type);
+      auto it = specialized_gv_map.find(name);
+      if (it != specialized_gv_map.end()) {
+        return Call(it->second, args);
+      }
+      auto gv = GlobalVar(name);
+      // gv is recorded before the clone is transformed -- presumably so that recursive calls
+      // met while transforming the clone reuse gv instead of specializing again
+      specialized_gv_map[name] = gv;
+      // clone and specialize with specific type
+      auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op))));
+      auto specialized_function = Specialize(clone, call->type_args);
+      // change var types and change all applications to use `apply` method
+      auto f = Downcast<Function>(FirstifyVars(specialized_function));
+      mod->Add(gv, f);
+      return Call(gv, args);
+    } else if (auto op = call->op.as<FunctionNode>()) {
+      // reduction by applying vars
+      CHECK_EQ(op->params.size(), call->args.size())
+          << "anonymous function applied to the wrong number of arguments";
+      std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map;
+      for (size_t i = 0; i < op->params.size(); i++) {
+        var_binding_map[op->params[i]] = call->args[i];
+      }
+      auto e = Bind(op->body, var_binding_map);
+      return this->VisitExpr(e);
+    } else if (auto op = call->op.as<VarNode>()) {
+      // var node will be encoded as datatype
+      // so we need to use the `apply` helper method
+      // keep the FuncType in a named local so the node pointer passed below stays owned
+      FuncType var_original_type = GetUnencodedType(op->type_annotation);
+      CHECK(var_original_type.defined()) << "var original type not saved in var_save_type map";
+      auto op_type = InstFuncType(var_original_type.as<FuncTypeNode>(), call->type_args);
+
+      Array<Expr> args = {GetRef<Var>(op)};
+      for (const auto& arg : call->args) {
+        args.push_back(this->VisitExpr(arg));
+      }
+
+      return Call(GetApplyFunction(op_type), args);
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  /*!
+   * \brief encode one argument of a call to a higher-order global function
+   *
+   * \param arg the argument expression
+   * \param type the argument's type in the instantiated callee type
+   * \return the defunctionalized arg when its type contains no FuncType; arg itself when it is a
+   * variable of function type; otherwise a datatype value encoding the global var or anonymous
+   * function
+   */
+  Expr EncodeArg(const Expr& arg, const Type& type) {
+    if (!HasFuncType(type)) {
+      // first-order argument: it may still contain higher-order calls that need rewriting
+      return this->VisitExpr(arg);
+    }
+
+    // we assume arg is either an identifier (var or globalvar) or a function
+    CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
+    CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
+        << "assume all first-order-parameters are identifiers or functions";
+
+    if (arg.as<VarNode>()) {
+      // variable with functype will be encoded as datatype in surrounding function
+      return arg;
+    } else if (arg.as<GlobalVarNode>()) {
+      return EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type));
+    } else {
+      return EncodeFunction(Downcast<Function>(arg), Downcast<FuncType>(type));
+    }
+  }
+
+  /*!
+   * \brief encode anonymous function fn of type ft as a call to a fresh constructor of ft's
+   * encoding datatype
+   *
+   * Free variables of fn become the constructor's fields; the case added to ft's `apply`
+   * function binds them again before running the (visited) body.
+   */
+  Expr EncodeFunction(const Function& fn, const FuncType& ft) {
+    Array<Type> arg_types;
+    Array<Pattern> pattern_vars;
+    Array<Expr> call_args;
+    Map<Var, Expr> free_var_bind_map;
+    for (const auto& free_var : FreeVars(fn)) {
+      // free vars are already encoded, can only exist within
+      // specialized functions
+      if (free_var->type_annotation.defined()) {
+        arg_types.push_back(free_var->type_annotation);
+      } else {
+        arg_types.push_back(free_var->checked_type());
+      }
+      auto new_var = Var(free_var->name_hint(), free_var->type_annotation);
+      free_var_bind_map.Set(free_var, new_var);
+      pattern_vars.push_back(PatternVar(new_var));
+      call_args.push_back(free_var);
+    }
+    auto gtv = GetFuncEncode(ft);
+    auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv);
+    AddConstructor(gtv, c);
+
+    auto apply_gv = GetApplyFunction(ft);
+    auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
+    AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
+                 pattern_vars);
+
+    return Call(c, call_args);
+  }
+
+ private:
+  // module being transformed; specialized functions, encoding datatypes and apply functions are
+  // added to it
+  IRModule mod;
+  // gv name + str(type) to specialized clone gv
+  std::unordered_map<std::string, GlobalVar> specialized_gv_map;
+  // "T" + str(func_type) to the ADT handle encoding functions of that type
+  std::unordered_map<std::string, GlobalTypeVar> func_encoding;
+  // "apply" + str(func_type) to the apply function's gv
+  std::unordered_map<std::string, GlobalVar> apply_map;
+  // encoded ADT handle to the FuncType it encodes; ObjectHash hashes by reference, so keys are
+  // compared by reference (ObjectEqual) as well, like the other ObjectRef-keyed maps here
+  std::unordered_map<GlobalTypeVar, Type, ObjectHash, ObjectEqual> original_func_type_map;
+  // gv to (str(func_type) to constructor encoding)
+  std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash,
+                     ObjectEqual>
+      gv_datatype_map;
+  // use monotonically increasing integer to represent new constructor_name
+  uint64_t constructor_counter;
+
+  /*!
+   * \brief append constructor c to datatype gtv; when gtv is not yet a type definition in the
+   * module, register a new TypeData (no type vars) whose only constructor is c
+   */
+  void AddConstructor(GlobalTypeVar gtv, Constructor c) {
+    if (mod->ContainGlobalTypeVar(gtv->name_hint)) {
+      auto existing = mod->LookupTypeDef(gtv);
+      Array<Constructor> ctors = existing->constructors;
+      ctors.push_back(c);
+      mod->UpdateTypeDef(gtv, TypeData(existing->header, existing->type_vars, ctors));
+      return;
+    }
+    mod->AddTypeDef(gtv, TypeData(gtv, {}, {c}));
+  }
+  /*!
+   * \brief add a case to the apply function, creating the function if it does not exist
+   *
+   * The apply function's first parameter is the encoded function value and is matched against
+   * the constructors; its remaining parameters are forwarded to the selected callee.
+   *
+   * \param apply_gv GlobalVar of the apply function
+   * \param ft type of the functions the apply function handles
+   * \param c constructor to add a case for
+   * \param expr callee invoked with the apply function's remaining parameters when c matches
+   * \param patterns PatternVars to match with the constructor's fields, used for handling free
+   * vars in functions
+   */
+  void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr,
+                    const Array<Pattern>& patterns) {
+    CHECK_EQ(c->inputs.size(), patterns.size())
+        << "constructor function and pattern vars have different sizes";
+
+    const bool exists = mod->ContainGlobalVar(apply_gv->name_hint);
+    Array<Var> params;
+    Array<Clause> clauses;
+    Type ret_type;
+    Array<TypeVar> type_params;
+    if (exists) {
+      auto f = Downcast<Function>(mod->Lookup(apply_gv));
+      auto body = f->body.as<MatchNode>();
+      CHECK(body) << "internal invariant broken; apply function body should be a match node";
+      params = f->params;
+      clauses = body->clauses;
+      ret_type = f->ret_type;
+      type_params = f->type_params;
+    } else {
+      params.push_back(Var("x", TypeCall(c->belong_to, {})));
+      for (const auto& t : ft->arg_types) {
+        params.push_back(Var("y", t));
+      }
+      ret_type = ft->ret_type;
+    }
+
+    Array<Expr> args;
+    for (size_t i = 1; i < params.size(); i++) {
+      args.push_back(params[i]);
+    }
+    clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args)));
+
+    // an existing apply function is replaced (update = true); a new one is simply added
+    mod->Add(apply_gv, Function(params, Match(params[0], clauses), ret_type, type_params), exists);
+  }
+
+  /*!
+   * \brief encode a global var with a specialized type with a datatype
+   *
+   * Each (gv, ft) pair gets a single nullary constructor of ft's encoding datatype, plus one case
+   * in ft's apply function that calls gv; later requests for the same pair reuse it.
+   */
+  Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) {
+    // bind by reference: the memoized constructor must be stored in gv_datatype_map itself,
+    // otherwise every call would add another constructor and apply case
+    auto& map = gv_datatype_map[gv];
+    auto type_key = TypeToString(ft);
+    if (map.count(type_key) == 0) {
+      auto gtv = GetFuncEncode(ft);
+      auto c = Constructor(std::to_string(constructor_counter++), {}, gtv);
+      map[type_key] = c;
+      AddConstructor(gtv, c);
+      AddApplyCase(GetApplyFunction(ft), ft, c, gv, {});
+    }
+    return Call(map[type_key], {});
+  }
+
+  /*!
+   * \brief render type t as text through its stream operator; the result is used to build
+   * global names and map keys
+   */
+  std::string TypeToString(const Type& t) {
+    std::ostringstream printed;
+    printed << t;
+    return printed.str();
+  }
+
+  /*!
+   * \brief get the ADT handle (named "T" + str(t), created on first use) that encodes functions
+   * of type t, and remember t as the type it encodes
+   */
+  GlobalTypeVar GetFuncEncode(const Type& t) {
+    const std::string adt_name = "T" + TypeToString(t);
+    auto found = func_encoding.find(adt_name);
+    if (found == func_encoding.end()) {
+      found =
+          func_encoding.emplace(adt_name, GlobalTypeVar(adt_name, TypeKind::kAdtHandle)).first;
+    }
+    GlobalTypeVar handle = found->second;
+    original_func_type_map[handle] = t;
+    return handle;
+  }
+
+  /*!
+   * \brief get original function type represented by type t
+   *
+   * \param t an encoded type: a TypeCall whose callee is an encoding GlobalTypeVar
+   * \return the FuncType recorded for that GlobalTypeVar by GetFuncEncode
+   */
+  FuncType GetUnencodedType(const Type& t) {
+    auto tc = t.as<TypeCallNode>();
+    CHECK(tc) << "expected type call when getting original type from encoded type";
+    auto gv = tc->func.as<GlobalTypeVarNode>();
+    CHECK(gv) << "expected global type var in encoded type";
+    // find() rather than operator[]: a failed lookup must not insert an empty entry
+    auto it = original_func_type_map.find(GetRef<GlobalTypeVar>(gv));
+    CHECK(it != original_func_type_map.end() && it->second.defined())
+        << "reverse mapping from encoded type to original type not found";
+    return Downcast<FuncType>(it->second);
+  }
+
+  /*!
+   * \brief get the apply function (a GlobalVar named "apply" + str(t), created on first use) for
+   * calling datatypes that encode functions of type t
+   */
+  GlobalVar GetApplyFunction(const Type& t) {
+    const std::string f_name = "apply" + TypeToString(t);
+    auto found = apply_map.find(f_name);
+    if (found != apply_map.end()) {
+      return found->second;
+    }
+    GlobalVar apply_gv(f_name);
+    apply_map[f_name] = apply_gv;
+    return apply_gv;
+  }
+
+  /*!
+   * \brief specialize a function type
+   */
+  FuncType InstFuncType(const FuncTypeNode* fty, const Array<Type> type_args) {
+    CHECK(fty) << "InstFuncType functype is null";
+    auto map = tvm::Map<TypeVar, Type>();
+    for (size_t i = 0; i < type_args.size(); i++) {

Review comment:
       check that they are the same size

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program using defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) are transformed into semantically equivalent first-order ones.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher-order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg = case encoding of
+ *     "incr" => arg + 1
+ * fun map' F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map' F xs)
+ * fun addone l = map' "incr" l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments take one of two forms: an identifier or a lambda abstraction
+ * - no functions are stored in datatypes
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that sets has_func_type once a FuncType is found anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type = false;
+
+  // `override` lets the compiler verify this really replaces the base-class visit;
+  // there is no need to descend into the FuncType, finding one is enough.
+  void VisitType_(const FuncTypeNode*) override { this->has_func_type = true; }
+};
+// Returns true when the checked type of expression e contains a FuncType.
+bool HasFuncType(const Expr& e) {
+  FuncTypeVisitor finder;
+  finder.VisitType(e->checked_type());
+  return finder.has_func_type;
+}
+// Returns true when type t is, or contains, a FuncType.
+bool HasFuncType(const Type& t) {
+  FuncTypeVisitor finder;
+  finder.VisitType(t);
+  return finder.has_func_type;
+}
+// Returns true when function type t is higher order, i.e. some argument type or the
+// return type contains a FuncType.
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  /*!
+   * \brief Rewrite a call so that no higher-order call remains.
+   *
+   * - call to a GlobalVar whose instantiated type is higher order: redirect it to a clone of
+   *   the callee specialized for the call's type arguments, passing each function-typed
+   *   argument as a constructor of the ADT that encodes its function type
+   * - call to a Function literal: beta-reduce by binding the arguments into the body
+   * - call through a Var: the var holds an encoded function, so dispatch via `apply`
+   */
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      auto op_fn_type = op->checked_type().as<FuncTypeNode>();
+      CHECK(op_fn_type) << "expected a function type for global var " << op->name_hint;
+      CHECK_EQ(call->type_args.size(), op_fn_type->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op_fn_type, call->type_args);
+      CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0U) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      CHECK_EQ(call->args.size(), op_type->arg_types.size())
+          << "number of call args does not match the callee's parameter count";
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          // first-order argument: still mutate it, since it may itself contain
+          // higher-order calls (e.g. a nested call to a higher-order global)
+          args.push_back(this->VisitExpr(arg));
+          continue;
+        }
+
+        // we assume arg is either an identifier (var or globalvar) or a function
+        CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
+        CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
+            << "assume all first-order-parameters are identifiers or functions";
+
+        if (arg.as<VarNode>()) {
+          // variable with functype will be encoded as datatype in surrounding function
+          args.push_back(arg);
+        } else if (arg.as<GlobalVarNode>()) {
+          args.push_back(EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type)));
+        } else if (auto fn = arg.as<FunctionNode>()) {
+          // we handle free vars in anonymous functions by adding arguments to
+          // the constructor function
+          auto free_vars = FreeVars(arg);
+          auto ft = Downcast<FuncType>(type);
+
+          auto arg_types = Array<Type>();
+          auto pattern_vars = Array<Pattern>();
+          auto call_args = Array<Expr>();
+          Map<Var, Expr> free_var_bind_map;
+          for (auto free_var : free_vars) {
+            // free vars are already encoded, can only exist within
+            // specialized functions
+            Type free_var_type = free_var->type_annotation.defined() ? free_var->type_annotation
+                                                                     : free_var->checked_type();
+            arg_types.push_back(free_var_type);
+            // annotate the rebound var with the same type as the constructor field it binds
+            auto new_var = Var(free_var->name_hint(), free_var_type);
+            free_var_bind_map.Set(free_var, new_var);
+            pattern_vars.push_back(PatternVar(new_var));
+            call_args.push_back(free_var);
+          }
+          auto gtv = GetFuncEncode(ft);
+          auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv);
+          AddConstructor(gtv, c);
+
+          auto apply_gv = GetApplyFunction(ft);
+          auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
+          AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
+                       pattern_vars);
+
+          args.push_back(Call(c, call_args));
+        }
+      }
+      auto name = op->name_hint + TypeToString(op_type);
+      auto gv = GlobalVar(name);
+      if (specialized_gv_map.count(name)) {
+        gv = specialized_gv_map[name];
+      } else {
+        specialized_gv_map[name] = gv;
+        // clone and specialize with specific type
+        auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op))));
+        auto specialized_function = Specialize(clone, call->type_args);
+        // change var types and change all applications to use `apply` method
+        auto f = Downcast<Function>(FirstifyVars(specialized_function));
+        mod->Add(gv, f);
+      }
+      return Call(gv, args);
+    } else if (auto op = call->op.as<FunctionNode>()) {
+      // reduction by applying vars
+      CHECK_EQ(op->params.size(), call->args.size())
+          << "number of args does not match number of params of the applied function";
+      std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map;
+      for (size_t i = 0; i < op->params.size(); i++) {
+        var_binding_map[op->params[i]] = call->args[i];
+      }
+      auto e = Bind(op->body, var_binding_map);
+      return this->VisitExpr(e);
+    } else if (auto op = call->op.as<VarNode>()) {
+      // var node will be encoded as datatype
+      // so we need to use the `apply` helper method
+      // keep a named reference so the FuncTypeNode pointer below is owned by this scope
+      FuncType var_original_type = GetUnencodedType(op->type_annotation);
+      CHECK(var_original_type.defined()) << "var original type not saved in var_save_type map";
+      auto op_type = InstFuncType(var_original_type.as<FuncTypeNode>(), call->type_args);
+
+      Array<Expr> args = {GetRef<Var>(op)};
+      for (auto arg : call->args) {
+        args.push_back(this->VisitExpr(arg));
+      }
+
+      return Call(GetApplyFunction(op_type), args);
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  // module that receives specialized functions, encoding ADTs and apply functions
+  IRModule mod;
+  // gv + str(type) to specialized clone gv
+  std::unordered_map<std::string, GlobalVar> specialized_gv_map;
+  // str(func_type) to ADT
+  std::unordered_map<std::string, GlobalTypeVar> func_encoding;
+  // str(func_type) to apply gv
+  std::unordered_map<std::string, GlobalVar> apply_map;
+  // encoded ADT handle to FuncType.
+  // ObjectHash hashes by object identity, so it must be paired with ObjectEqual (as in
+  // the other maps here); pairing it with StructuralEqual breaks the hash/equality contract.
+  std::unordered_map<GlobalTypeVar, Type, ObjectHash, ObjectEqual> original_func_type_map;
+  // gv to (str(func_type) to constructor encoding)
+  std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash,
+                     ObjectEqual>
+      gv_datatype_map;
+  // use monotonically increasing integer to represent new constructor_name
+  uint64_t constructor_counter;
+
+  /*!
+   * \brief Append constructor c to the type definition of gtv. If the module has no
+   * definition for gtv yet, a new TypeData with no type vars and only c is added instead.
+   */
+  void AddConstructor(GlobalTypeVar gtv, Constructor c) {
+    if (mod->ContainGlobalTypeVar(gtv->name_hint)) {
+      auto existing = mod->LookupTypeDef(gtv);
+      auto ctors = existing->constructors;
+      ctors.push_back(c);
+      mod->UpdateTypeDef(gtv, TypeData(existing->header, existing->type_vars, ctors));
+      return;
+    }
+    mod->AddTypeDef(gtv, TypeData(gtv, {}, {c}));
+  }
+  /*!
+   * \brief add a case to the apply function, creating the function if it does not exist
+   *
+   * The apply function has the shape `apply(x, y...) = match x { ... }`; each clause
+   * matches one constructor of x's ADT and calls `expr` with the remaining params y...
+   *
+   * \param apply_gv GlobalVar of the apply function
+   * \param ft the function type whose encoded values the apply function handles
+   * \param c constructor to add a case for
+   * \param expr the expression called with the apply function's non-encoded params
+   * \param patterns PatternVars to match with the constructor's fields, used for handling
+   *        free vars in functions
+   */
+  void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr,
+                    const Array<Pattern>& patterns) {
+    CHECK_EQ(c->inputs.size(), patterns.size())
+        << "constructor function and pattern vars have different sizes";
+    if (!mod->ContainGlobalVar(apply_gv->name_hint)) {
+      // first case for this function type: build the apply function from scratch
+      auto x = Var("x", TypeCall(c->belong_to, {}));
+      auto vars = Array<Var>({x});
+      auto args = Array<Expr>();
+      for (auto t : ft->arg_types) {
+        auto y = Var("y", t);
+        vars.push_back(y);
+        args.push_back(y);
+      }
+
+      auto clauses = Array<Clause>({Clause(PatternConstructor(c, patterns), Call(expr, args))});
+      auto body = Match(x, clauses);
+      auto f = Function(vars, body, ft->ret_type, {});
+
+      mod->Add(apply_gv, f);
+    } else {
+      // apply function already exists: append one clause to its top-level match
+      auto f = Downcast<Function>(mod->Lookup(apply_gv));
+      auto body = f->body.as<MatchNode>();
+      CHECK(body) << "internal invariant broken; apply function body should be a match node";
+
+      auto clauses = body->clauses;
+      auto x = f->params[0];
+      auto args = Array<Expr>();
+      for (size_t i = 1; i < f->params.size(); i++) {
+        args.push_back(f->params[i]);
+      }
+      clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args)));
+
+      // the trailing `true` replaces the existing definition of apply_gv
+      mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true);
+    }
+  }
+
+  /*!
+   * \brief encode a global var with a specialized type with a datatype
+   *
+   * Each (gv, function type) pair gets one nullary constructor of the ADT encoding that
+   * type, plus one case in the matching apply function; later requests reuse it.
+   */
+  Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) {
+    // bind by reference: a copy of the inner map would drop the cached constructor, so
+    // every call would mint a fresh constructor and a duplicate apply case
+    auto& encodings = gv_datatype_map[gv];
+    auto type_key = TypeToString(ft);
+    auto it = encodings.find(type_key);
+    if (it == encodings.end()) {
+      auto gtv = GetFuncEncode(ft);
+      auto c = Constructor(std::to_string(constructor_counter++), {}, gtv);
+      it = encodings.emplace(type_key, c).first;
+      AddConstructor(gtv, c);
+      AddApplyCase(GetApplyFunction(ft), ft, c, gv, {});
+    }
+    return Call(it->second, {});
+  }
+
+  /*!
+   * \brief Print a type to text. The text is used as a cache key and as a suffix of
+   * generated global names (see GetFuncEncode / GetApplyFunction).
+   */
+  std::string TypeToString(const Type& t) {
+    std::ostringstream printed;
+    printed << t;
+    return printed.str();
+  }
+
+  /*!
+   * \brief Return the ADT handle that encodes function type t, creating it on first use.
+   *
+   * The handle is named "T" followed by the printed type; the reverse mapping
+   * (handle -> t) is (re)recorded in original_func_type_map on every call.
+   */
+  GlobalTypeVar GetFuncEncode(const Type& t) {
+    const std::string adt_name = "T" + TypeToString(t);
+    auto it = func_encoding.find(adt_name);
+    if (it == func_encoding.end()) {
+      it = func_encoding.emplace(adt_name, GlobalTypeVar(adt_name, TypeKind::kAdtHandle)).first;
+    }
+    original_func_type_map[it->second] = t;
+    return it->second;
+  }
+
+  /*!
+   * \brief Recover the original function type that an encoded ADT type stands for.
+   *
+   * \param t encoded type; expected to be a TypeCall whose callee is a GlobalTypeVar
+   *        previously returned by GetFuncEncode
+   * \return the FuncType recorded for that GlobalTypeVar
+   */
+  FuncType GetUnencodedType(const Type& t) {
+    auto tc = t.as<TypeCallNode>();
+    CHECK(tc) << "expected type call when getting original type from encoded type";
+    auto gv = tc->func.as<GlobalTypeVarNode>();
+    CHECK(gv) << "expected global type var in encoded type";
+    // find() instead of operator[]: a failed lookup must not insert an empty entry
+    auto it = original_func_type_map.find(GetRef<GlobalTypeVar>(gv));
+    CHECK(it != original_func_type_map.end() && it->second.defined())
+        << "reverse mapping from encoded type to original type not found";
+    return Downcast<FuncType>(it->second);
+  }
+
+  /*!
+   * \brief Get, creating on first use, the `apply` GlobalVar that dispatches calls on
+   * values encoding functions of type t. Its name is "apply" followed by the printed type.
+   */
+  GlobalVar GetApplyFunction(const Type& t) {
+    // print the type once and use the same string as both cache key and global name
+    const std::string f_name = "apply" + TypeToString(t);
+    auto it = apply_map.find(f_name);
+    if (it == apply_map.end()) {
+      it = apply_map.emplace(f_name, GlobalVar(f_name)).first;
+    }
+    return it->second;
+  }
+
+  /*!
+   * \brief specialize a function type
+   *
+   * \param fty polymorphic function type to instantiate; must be non-null
+   * \param type_args one type per type param of fty, substituted positionally
+   * \return a copy of fty with type params substituted and removed
+   */
+  FuncType InstFuncType(const FuncTypeNode* fty, const Array<Type> type_args) {
+    CHECK(fty) << "InstFuncType functype is null";
+    // fty->type_params[i] is indexed below, so the counts must agree
+    CHECK_EQ(fty->type_params.size(), type_args.size())
+        << "number of type args does not match number of type params";
+    auto map = tvm::Map<TypeVar, Type>();
+    for (size_t i = 0; i < type_args.size(); i++) {
+      map.Set(fty->type_params[i], type_args[i]);
+    }
+    // copy with typevars removed
+    return Downcast<FuncType>(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map));
+  }
+
+  /*!
+   * \brief specialize a function expression
+   */
+  Function Specialize(const Function& f, const Array<Type> type_args) {
+    auto map = tvm::Map<TypeVar, Type>();
+    for (size_t i = 0; i < type_args.size(); i++) {
+      map.Set(f->type_params[i], type_args[i]);

Review comment:
       Check that `type_args` and `f->type_params` have the same size before indexing.

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program using defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) are transformed into semantically equivalent first-order ones.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher-order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg = case encoding of
+ *     "incr" => arg + 1
+ * fun map' F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map' F xs)
+ * fun addone l = map' "incr" l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments take one of two forms: an identifier or a lambda abstraction
+ * - no functions are stored in datatypes
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type;

Review comment:
       Move this struct into `HasFuncType`.

##########
File path: tests/python/relay/test_pass_defunctionalization.py
##########
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.backend.interpreter import ConstructorValue
+from tvm.relay import transform, ExprVisitor, TypeVisitor
+from tvm.relay.testing import Prelude
+
+# determine if type t is a FuncType or has a nested FuncType
+def has_func_type(t):
+  class FuncTypeVisitor(TypeVisitor):
+    def __init__(self):
+      super().__init__()
+      self.has_func = False
+
+    def visit_func_type(self, ftt):
+      self.has_func = True
+
+  ftvisitor = FuncTypeVisitor()
+  ftvisitor.visit(t)
+  return ftvisitor.has_func
+
+# determine whether a program has any higher order functions
+# a higher order function is defined as one that:
+# - has function type arguments
+# - returns a function
+def assert_no_higher_order_functions(expr, mod):
+  class CheckFirstOrderVisitor(ExprVisitor):
+    def __init__(self, mod):
+      super().__init__()
+      self.mod = mod
+      self.hof = []
+      self.visited_gv = set()
+    
+    def visit_call(self, call):
+      is_higher_order = False
+      # check return type
+      if (has_func_type(call.checked_type)):
+        is_higher_order = True
+      # check argument types
+      for a in call.args:
+        if (has_func_type(a.checked_type)):
+          is_higher_order = True
+      # if it is higher order, save it or debugging later

Review comment:
       ```suggestion
         # if it is higher order, save it for debugging later
   ```

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program using defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) are transformed into semantically equivalent first-order ones.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher-order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, xs) => Cons(F x, map F xs)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone the `map` function and specialize it with the type arguments of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg = case encoding of
+ *     "incr" => arg + 1
+ * fun map' F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map' F xs)
+ * fun addone l = map' "incr" l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments take one of two forms: an identifier or a lambda abstraction
+ * - no functions are stored in datatypes
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// Type visitor that sets has_func_type once a FuncType is found anywhere in the visited type.
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type = false;
+
+  // `override` lets the compiler verify this really replaces the base-class visit;
+  // there is no need to descend into the FuncType, finding one is enough.
+  void VisitType_(const FuncTypeNode*) override { this->has_func_type = true; }
+};
+// Returns true when the checked type of expression e contains a FuncType.
+bool HasFuncType(const Expr& e) {
+  FuncTypeVisitor finder;
+  finder.VisitType(e->checked_type());
+  return finder.has_func_type;
+}
+// Returns true when type t is, or contains, a FuncType.
+bool HasFuncType(const Type& t) {
+  FuncTypeVisitor finder;
+  finder.VisitType(t);
+  return finder.has_func_type;
+}
+// Returns true when function type t is higher order, i.e. some argument type or the
+// return type contains a FuncType.
+bool IsHigherOrderFunc(const FuncType& t) {
+  for (const auto& arg : t->arg_types) {
+    if (HasFuncType(arg)) {
+      return true;
+    }
+  }
+  return HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated";

Review comment:
       Use `CHECK_EQ` here so a failure reports both values.

##########
File path: tests/python/relay/test_pass_defunctionalization.py
##########
@@ -0,0 +1,226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.backend.interpreter import ConstructorValue
+from tvm.relay import transform, ExprVisitor, TypeVisitor
+from tvm.relay.testing import Prelude
+
+# determine if type t is a FuncType or has a nested FuncType
+def has_func_type(t):
+  class FuncTypeVisitor(TypeVisitor):
+    def __init__(self):
+      super().__init__()
+      self.has_func = False
+
+    def visit_func_type(self, ftt):
+      self.has_func = True
+
+  ftvisitor = FuncTypeVisitor()
+  ftvisitor.visit(t)
+  return ftvisitor.has_func
+
+# determine whether a program has any higher order functions
+# a higher order function is defined as one that:
+# - has function type arguments
+# - returns a function
+def assert_no_higher_order_functions(expr, mod):
+  class CheckFirstOrderVisitor(ExprVisitor):
+    def __init__(self, mod):
+      super().__init__()
+      self.mod = mod
+      self.hof = []
+      self.visited_gv = set()
+    
+    def visit_call(self, call):
+      is_higher_order = False
+      # check return type
+      if (has_func_type(call.checked_type)):
+        is_higher_order = True
+      # check argument types
+      for a in call.args:
+        if (has_func_type(a.checked_type)):
+          is_higher_order = True
+      # if it is higher order, save it or debugging later
+      if is_higher_order:
+        self.hof.append(call)
+      super().visit_call(call)
+
+    def visit_global_var(self, gv):
+      # visit global vars to visit entire program
+      if gv not in self.visited_gv:
+        self.visited_gv.add(gv)
+        self.visit(self.mod[gv])
+
+  mod = transform.InferType()(mod)
+  check_fo_visitor = CheckFirstOrderVisitor(mod)
+  check_fo_visitor.visit(expr)
+
+  nl = '\n--------\n'
+  errmsg = f"""found {len(check_fo_visitor.hof)} higher order functions:
+  {nl.join(expr.astext() for expr in check_fo_visitor.hof)}"""
+
+  assert len(check_fo_visitor.hof) == 0, errmsg
+
+# assert that a program is defunctionalized and returns
+# defunctionalized module
+# assumes program starts from mod['main']
+def defunctionalized(mod):
+  mod = transform.InferType()(mod)
+  mod['main'] = transform.Defunctionalization(mod['main'], mod)
+  mod = transform.InferType()(mod)
+  assert_no_higher_order_functions(mod['main'], mod)
+
+  return mod
+
+# adt list to python list
+def to_list(mod, l):
+  list = mod.get_global_type_var('List')
+  list_adt = mod[list]
+  cons = list_adt.constructors[0]
+  nil = list_adt.constructors[1]
+
+  assert isinstance(l, ConstructorValue)
+  val = l
+  ret = []
+  while True:
+      if val.tag == cons.tag:
+          ret.append(val.fields[0].asnumpy())
+          val = val.fields[1]
+      else:
+          assert val.tag == nil.tag
+          break
+  return ret
+
+# list to adt list
+def to_adt_list(mod, arr):
+  expr = mod['main']
+  l = mod.get_global_type_var('List')
+  list_adt = mod[l]
+  cons = list_adt.constructors[0]
+  nil = list_adt.constructors[1]
+
+  li = nil()
+  for a in arr:
+    li = cons(relay.const(a), li)
+  ex = relay.create_executor(mod=mod)
+  adt = ex.evaluate(li)
+  mod['main'] = expr
+  return adt
+
+def test_simple():
+  code = """
+#[version = "0.0.5"]
+def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B {
+  %f(%xs)
+}
+def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] {
+  %0 = fn[A](%x: A) -> A {
+    %x
+  };
+  @simple(%0, %l)
+}
+"""
+  mod = tvm.parser.fromtext(code)
+  defunc_mod = defunctionalized(mod)
+
+  input = np.random.rand(5,5).astype('float32')
+
+  ex = relay.create_executor('debug', mod=mod)
+  defunc_ex = relay.create_executor('debug', mod=defunc_mod)
+
+  out = ex.evaluate()(input)
+  defunc_out = defunc_ex.evaluate()(input)
+
+  np.testing.assert_equal(out.asnumpy(), defunc_out.asnumpy())
+  
+
+def test_global_recursion():
+  code = """
+#[version = "0.0.5"]
+type List[A] {
+  Cons(A, List[A]),
+  Nil,
+}
+def @id[A](%x: A) -> A {
+  %x
+}
+def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] {
+  match (%xs) {
+    Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)),
+    Nil => Nil,
+  }
+}
+def @main(%l: List[float32]) -> List[float32] {
+  @map(@id, %l)
+}
+"""
+  mod = tvm.parser.fromtext(code)
+  defunc_mod = defunctionalized(mod)
+
+  input = np.random.rand(10).astype('float32')
+  
+  ex = relay.create_executor('debug', mod=mod)
+  defunc_ex = relay.create_executor('debug', mod=defunc_mod)
+
+  out = ex.evaluate(mod['main'])(to_adt_list(mod, input))
+  defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input))
+
+  np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out))
+
+def test_recursive_datatype():
+  # CPS will create recursive datatype
+  code = """
+#[version = "0.0.5"]
+type List[A] {
+  Cons(A, List[A]),
+  Nil,
+}
+def @sum(%f: fn(int32) -> int32, %xs: List[int32]) -> int32 {

Review comment:
       The continuation parameter should be named `k` rather than `%f`.

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher order functions (i.e functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, XS) => Cons(F z, map F XS)
+ * fun addone 1 = map (\x -> \x + 1) 1
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone `map ` function and specialize it with the type_params of the call.
+ * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone 1 = map’ “incr” 1
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let binded
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type;
+  FuncTypeVisitor() : has_func_type(false) {}
+
+  void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
+};
+// determine if expr contains a FuncType
+bool HasFuncType(const Expr& e) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(e->checked_type());
+  return visitor.has_func_type;
+}
+// determine if type contains a FuncType
+bool HasFuncType(const Type& t) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(t);
+  return visitor.has_func_type;
+}
+// determine if FuncType is a higher order type
+bool IsHigherOrderFunc(const FuncType& t) {
+  bool higher_order = false;
+  for (auto arg : t->arg_types) {
+    higher_order |= HasFuncType(arg);
+  }
+  return higher_order |= HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        // we assume arg is either an identifier (var or globalvar) or a function

Review comment:
       Refactor the code after the `continue` into a separate function.

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher order functions (i.e functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, XS) => Cons(F z, map F XS)
+ * fun addone 1 = map (\x -> \x + 1) 1
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone `map ` function and specialize it with the type_params of the call.
+ * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone 1 = map’ “incr” 1
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let binded
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type;
+  FuncTypeVisitor() : has_func_type(false) {}
+
+  void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
+};
+// determine if expr contains a FuncType
+bool HasFuncType(const Expr& e) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(e->checked_type());
+  return visitor.has_func_type;
+}
+// determine if type contains a FuncType
+bool HasFuncType(const Type& t) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(t);
+  return visitor.has_func_type;
+}
+// determine if FuncType is a higher order type
+bool IsHigherOrderFunc(const FuncType& t) {
+  bool higher_order = false;
+  for (auto arg : t->arg_types) {
+    higher_order |= HasFuncType(arg);
+  }
+  return higher_order |= HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        // we assume arg is either an identifier (var or globalvar) or a function
+        CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
+        CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
+            << "assume all first-order-parameters are identifiers or functions";
+
+        if (arg.as<VarNode>()) {
+          // variable with functype will be encoded as datatype in surrounding function
+          args.push_back(arg);
+        }
+        if (arg.as<GlobalVarNode>()) {
+          args.push_back(EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type)));
+        }
+        if (auto fn = arg.as<FunctionNode>()) {
+          // we handle free vars in anonymous functions by adding arguments to
+          // the constructor function
+          auto free_vars = FreeVars(arg);
+          auto ft = Downcast<FuncType>(type);
+
+          auto arg_types = Array<Type>();
+          auto pattern_vars = Array<Pattern>();
+          auto call_args = Array<Expr>();
+          Map<Var, Expr> free_var_bind_map;
+          for (auto free_var : free_vars) {
+            // free vars are already encoded, can only exist within
+            // specialized functions
+            if (free_var->type_annotation.defined()) {
+              arg_types.push_back(free_var->type_annotation);
+            } else {
+              arg_types.push_back(free_var->checked_type());
+            }
+            auto new_var = Var(free_var->name_hint(), free_var->type_annotation);
+            free_var_bind_map.Set(free_var, new_var);
+            pattern_vars.push_back(PatternVar(new_var));
+            call_args.push_back(free_var);
+          }
+          auto gtv = GetFuncEncode(ft);
+          auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv);

Review comment:
       `++constructor_counter` is better style, as it aligns with iterator usage.

##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher order functions (i.e functions that take function arguments or return
+ * functions) should be transformed into a semantically equivalent first order one.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, XS) => Cons(F z, map F XS)
+ * fun addone 1 = map (\x -> \x + 1) 1
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone `map ` function and specialize it with the type_params of the call.
+ * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg =  case encoding of
+ *     “incr” => incr arg
+ * fun map’ F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map’ F xs)
+ * fun addone 1 = map’ “incr” 1
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: identifier or a lambda abstraction
+ * - no functions stored in datatype
+ * - functions are not let binded
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+struct FuncTypeVisitor : TypeVisitor {
+  bool has_func_type;
+  FuncTypeVisitor() : has_func_type(false) {}
+
+  void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
+};
+// determine if expr contains a FuncType
+bool HasFuncType(const Expr& e) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(e->checked_type());
+  return visitor.has_func_type;
+}
+// determine if type contains a FuncType
+bool HasFuncType(const Type& t) {
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(t);
+  return visitor.has_func_type;
+}
+// determine if FuncType is a higher order type
+bool IsHigherOrderFunc(const FuncType& t) {
+  bool higher_order = false;
+  for (auto arg : t->arg_types) {
+    higher_order |= HasFuncType(arg);
+  }
+  return higher_order |= HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK(call->type_args.size() == op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        // we assume arg is either an identifier (var or globalvar) or a function
+        CHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
+        CHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
+            << "assume all first-order-parameters are identifiers or functions";
+
+        if (arg.as<VarNode>()) {
+          // variable with functype will be encoded as datatype in surrounding function
+          args.push_back(arg);
+        }
+        if (arg.as<GlobalVarNode>()) {
+          args.push_back(EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type)));
+        }
+        if (auto fn = arg.as<FunctionNode>()) {
+          // we handle free vars in anonymous functions by adding arguments to
+          // the constructor function
+          auto free_vars = FreeVars(arg);
+          auto ft = Downcast<FuncType>(type);
+
+          auto arg_types = Array<Type>();
+          auto pattern_vars = Array<Pattern>();
+          auto call_args = Array<Expr>();
+          Map<Var, Expr> free_var_bind_map;
+          for (auto free_var : free_vars) {
+            // free vars are already encoded, can only exist within
+            // specialized functions
+            if (free_var->type_annotation.defined()) {
+              arg_types.push_back(free_var->type_annotation);
+            } else {
+              arg_types.push_back(free_var->checked_type());
+            }
+            auto new_var = Var(free_var->name_hint(), free_var->type_annotation);
+            free_var_bind_map.Set(free_var, new_var);
+            pattern_vars.push_back(PatternVar(new_var));
+            call_args.push_back(free_var);
+          }
+          auto gtv = GetFuncEncode(ft);
+          auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv);
+          AddConstructor(gtv, c);
+
+          auto apply_gv = GetApplyFunction(ft);
+          auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
+          AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
+                       pattern_vars);
+
+          args.push_back(Call(c, call_args));
+        }
+      }
+      auto name = op->name_hint + TypeToString(op_type);
+      auto gv = GlobalVar(name);
+      if (specialized_gv_map.count(name)) {
+        gv = specialized_gv_map[name];
+      } else {
+        specialized_gv_map[name] = gv;
+        // clone and specialize with specific type
+        auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op))));
+        auto specialized_function = Specialize(clone, call->type_args);
+        // change var types and change all applications to use `apply` method
+        auto f = Downcast<Function>(FirstifyVars(specialized_function));
+        mod->Add(gv, f);
+      }
+      return Call(gv, args);
+    } else if (auto op = call->op.as<FunctionNode>()) {
+      // reduction by applying vars
+      std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map;
+      for (size_t i = 0; i < op->params.size(); i++) {
+        var_binding_map[op->params[i]] = call->args[i];
+      }
+      auto e = Bind(op->body, var_binding_map);
+      return this->VisitExpr(e);
+    } else if (auto op = call->op.as<VarNode>()) {
+      // var node will be encoded as datatype
+      // so we need to use the `apply` helper method
+      auto var_original_type = GetUnencodedType(op->type_annotation).as<FuncTypeNode>();
+      CHECK(var_original_type) << "var original type not saved in var_save_type map";
+      auto op_type = InstFuncType(var_original_type, call->type_args);
+
+      Array<Expr> args = {GetRef<Var>(op)};
+      for (auto arg : call->args) {
+        args.push_back(this->VisitExpr(arg));
+      }
+
+      return Call(GetApplyFunction(op_type), args);
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  // module
+  IRModule mod;
+  // gv + str(type) to specialized clone gv
+  std::unordered_map<std::string, GlobalVar> specialized_gv_map;
+  // str(func_type) to ADT
+  std::unordered_map<std::string, GlobalTypeVar> func_encoding;
+  // str(func_tyoe) to apply gv
+  std::unordered_map<std::string, GlobalVar> apply_map;
+  // encoded ADT handle to FuncType
+  std::unordered_map<GlobalTypeVar, Type, ObjectHash, StructuralEqual> original_func_type_map;
+  // gv to (str(func_type) to constructor encoding)
+  std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash,
+                     ObjectEqual>
+      gv_datatype_map;
+  // use monotonically increasing integer to represent new constructor_name
+  uint64_t constructor_counter;
+
+  /*!
+   * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not
+   * exist
+   */
+  void AddConstructor(GlobalTypeVar gtv, Constructor c) {
+    if (!mod->ContainGlobalTypeVar(gtv->name_hint)) {
+      mod->AddTypeDef(gtv, TypeData(gtv, {}, {c}));
+    } else {
+      auto typedata = mod->LookupTypeDef(gtv);
+      auto constructors = typedata->constructors;
+      constructors.push_back(c);
+      mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors));
+    }
+  }
+  /*!
+   * \brief add a case to the apply function, creating the function if it does not exist
+   *
+   * \param apply_gv GlobalVar of the apply function
+   * \param ft is the type functions the apply function handles
+   * \param c constructor to add a case for
+   * \param expr calls this expr with the args to the apply_gv
+   * \param patterns PatterVars to match with the constructor, used for handling free vars in
+   * functions
+   */
+  void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr,
+                    const Array<Pattern> patterns) {
+    CHECK(c->inputs.size() == patterns.size())
+        << "constructor function and pattern vars have different sizes";
+    if (!mod->ContainGlobalVar(apply_gv->name_hint)) {
+      auto x = Var("x", TypeCall(c->belong_to, {}));
+      auto vars = Array<Var>({x});
+      auto args = Array<Expr>();
+      for (auto t : ft->arg_types) {
+        auto y = Var("y", t);
+        vars.push_back(y);
+        args.push_back(y);
+      }
+
+      auto clauses = Array<Clause>({Clause(PatternConstructor(c, patterns), Call(expr, args))});
+      auto body = Match(x, clauses);
+      auto f = Function(vars, body, ft->ret_type, {});
+
+      mod->Add(apply_gv, f);
+    } else {
+      auto f = Downcast<Function>(mod->Lookup(apply_gv));
+      auto body = f->body.as<MatchNode>();
+      CHECK(body) << "internal invariant broken; apply function body should be a match node";
+
+      auto clauses = body->clauses;
+      auto x = f->params[0];
+      auto args = Array<Expr>();
+      for (size_t i = 1; i < f->params.size(); i++) {
+        args.push_back(f->params[i]);
+      }
+      clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args)));
+
+      mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true);
+    }
+  }
+
+  /*!
+   * \brief encode a global var with a specialized type with a datatype
+   */
+  Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) {
+    auto map = gv_datatype_map[gv];
+    auto type_key = TypeToString(ft);
+    if (map.count(type_key) == 0) {
+      auto gtv = GetFuncEncode(ft);
+      auto c = Constructor(std::to_string(constructor_counter++), {}, gtv);
+      map[type_key] = c;
+      AddConstructor(gtv, c);
+      AddApplyCase(GetApplyFunction(ft), ft, c, gv, {});
+    }
+    return Call(map[type_key], {});
+  }
+
+  /*!
+   * \brief type to string
+   */
+  std::string TypeToString(const Type& t) {
+    std::ostringstream s;
+    s << t;
+    return s.str();
+  }
+
+  /*!
+   * \brief get ADT handle for encoding type t
+   */
+  GlobalTypeVar GetFuncEncode(const Type& t) {
+    auto adt_name = "T" + TypeToString(t);

Review comment:
       ```suggestion
       auto adt_name = "Defunc" + TypeToString(t);
   ```
   A longer name helps avoid name clashes.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#issuecomment-690407237


   Thx @hypercubestart @yzhliu .


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] hypercubestart edited a comment on pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
hypercubestart edited a comment on pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#issuecomment-687312583


   cc: @wweic @MarisaKirisame @zhiics @icemelon9 @ZihengJiang @jroesch @tqchen 
   could you guys please review?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] hypercubestart edited a comment on pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
hypercubestart edited a comment on pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#issuecomment-687312583


   cc: @wweic @MarisaKirisame @zhiics @icemelon9 @ZihengJiang @jroesch @tqchen 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] hypercubestart commented on pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
hypercubestart commented on pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#issuecomment-687312583


   cc: @wweic @MarisaKirisame @zhiics @icemelon9 @ZihengJiang @jroesch 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#discussion_r483774102



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -736,6 +736,23 @@ def gradient(expr, mod=None, mode='higher_order'):
         return _ffi_api.gradient(expr, mod)
     raise Exception('unknown mode')
 
+def Defunctionalization(expr, mod):
+    """

Review comment:
       Please add a description.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] hypercubestart commented on a change in pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
hypercubestart commented on a change in pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400#discussion_r486053273



##########
File path: src/relay/transforms/defunctionalization.cc
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file defunctionalization.cc
+ *
+ * \brief Defunctionalization for Relay IR
+ *
+ * This pass transforms a higher-order program into a first-order program with defunctionalization.
+ * This means that all higher-order functions (i.e. functions that take function arguments or return
+ * functions) should be transformed into semantically equivalent first-order ones.
+ *
+ * This pass implements a basic typed defunctionalization method.
+ * All higher order functions are cloned and specialized (so that there are no type params).
+ * Function type arguments are encoded as datatypes and a helper `apply` function is used
+ * to "call" them.
+ *
+ * For example, take the following higher-order program:
+ * fun map F y = case y of
+ *          Nil => Nil
+ *          | Cons(x, XS) => Cons(F x, map F XS)
+ * fun addone l = map (\x -> x + 1) l
+ *
+ * where `addone` is our program.
+ * When we call the `map` function, we see that it is a higher-order function,
+ * but we can clone the `map` function and specialize it with the type_params of the call.
+ * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
+ * which we will call `incr`, and all calls to `F` in our specialized map function will use the
+ * helper `apply` function.
+ *
+ * After defunctionalization, we get:
+ * fun apply encoding arg = case encoding of
+ *     "incr" => incr arg
+ * fun map' F y = case y of
+ *           Nil => Nil
+ *           | Cons(x, xs) => Cons(apply F x, map' F xs)
+ * fun addone l = map' "incr" l
+ *
+ * Currently, defunctionalization makes the following assumptions:
+ * - functions cannot return function values
+ * - function arguments are in two forms: an identifier or a lambda abstraction
+ * - no functions are stored in datatypes
+ * - functions are not let-bound
+ */
+
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/feature.h>
+#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
+
+#include "../analysis/type_solver.h"
+#include "../transforms/pass_util.h"
+namespace tvm {
+namespace relay {
+
+// determine if type contains a FuncType
+bool HasFuncType(const Type& t) {
+  struct FuncTypeVisitor : TypeVisitor {
+    bool has_func_type;
+    FuncTypeVisitor() : has_func_type(false) {}
+
+    void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
+  };
+
+  auto visitor = FuncTypeVisitor();
+  visitor.VisitType(t);
+  return visitor.has_func_type;
+}
+// determine if FuncType is a higher order type
+bool IsHigherOrderFunc(const FuncType& t) {
+  bool higher_order = false;
+  for (auto arg : t->arg_types) {
+    higher_order |= HasFuncType(arg);
+  }
+  return higher_order |= HasFuncType(t->ret_type);
+}
+
+/*!
+ * \brief mutator for driving the Defunctionalization transformation
+ */
+class DefuncMutator : public ExprMutator {
+ public:
+  explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
+
+  Expr VisitExpr_(const CallNode* call) {
+    if (auto op = call->op.as<GlobalVarNode>()) {
+      CHECK_EQ(call->type_args.size(), op->checked_type().as<FuncTypeNode>()->type_params.size())
+          << "all type args must be explicit";
+
+      auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
+      CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated";
+      CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
+
+      if (!IsHigherOrderFunc(op_type)) {
+        // not higher order function
+        return ExprMutator::VisitExpr_(call);
+      }
+
+      // first we encode function arguments
+      Array<Expr> args;
+      for (size_t i = 0; i < call->args.size(); i++) {
+        auto arg = call->args[i];
+        auto type = op_type->arg_types[i];
+        if (!HasFuncType(type)) {
+          args.push_back(arg);
+          continue;
+        }
+
+        args.push_back(EncodeArg(arg, type));

Review comment:
       yep, looks like CI issue was caused by this: https://github.com/apache/incubator-tvm/pull/6434




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame merged pull request #6400: [Relay] Add Defunctionalization Pass

Posted by GitBox <gi...@apache.org>.
MarisaKirisame merged pull request #6400:
URL: https://github.com/apache/incubator-tvm/pull/6400


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org