You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/30 07:18:37 UTC

[tvm] branch main updated: [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 559f0c76a0 [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)
559f0c76a0 is described below

commit 559f0c76a0a8ee9c1620ee29ecd8ce1ced07093e
Author: Florin Blanaru <fl...@gmail.com>
AuthorDate: Mon May 30 08:18:32 2022 +0100

    [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)
---
 include/tvm/ir/transform.h              |  4 ++
 src/ir/transform.cc                     | 25 +++++++++-
 tests/cpp/pass_immutable_module_test.cc | 86 +++++++++++++++++++++++++++++++++
 3 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d8f6632a66..febcca5c01 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -390,6 +390,10 @@ class Pass : public ObjectRef {
   IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
 
   TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
+
+ private:
+  IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
+                                        const PassContext& pass_ctx);
 };
 
 /*!
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index dfd307d715..d945278abc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -24,6 +24,7 @@
 #include <dmlc/thread_local.h>
 #include <tvm/ir/transform.h>
 #include <tvm/node/repr_printer.h>
+#include <tvm/node/structural_hash.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 
@@ -41,6 +42,8 @@ using tvm::ReprPrinter;
 using tvm::runtime::TVMArgs;
 using tvm::runtime::TVMRetValue;
 
+TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
+
 struct PassContextThreadLocalEntry {
   /*! \brief The default pass context. */
   PassContext default_context;
@@ -264,11 +267,31 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
                << " with opt level: " << pass_info->opt_level;
     return mod;
   }
-  auto ret = node->operator()(std::move(mod), pass_ctx);
+  IRModule ret;
+  if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
+    ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
+  } else {
+    ret = node->operator()(std::move(mod), pass_ctx);
+  }
   pass_ctx.InstrumentAfterPass(ret, pass_info);
   return std::move(ret);
 }
 
+IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
+                                     const PassContext& pass_ctx) {
+  size_t before_pass_hash = tvm::StructuralHash()(mod);  // hash of the input before the pass runs
+  ObjectPtr<Object> module_ptr = ObjectRef::GetDataPtr<Object>(mod);
+  IRModule copy_mod = IRModule(module_ptr);  // handle built from mod's own data pointer
+  IRModule ret = node->operator()(mod, pass_ctx);
+  size_t after_pass_hash = tvm::StructuralHash()(copy_mod);  // re-hash the input after the pass
+  if (before_pass_hash != after_pass_hash) {
+    // Differing hashes mean the pass changed its input module. An undetected mutation would
+    // need a hash collision between the module and its mutated form, which should be rare.
+    LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name;
+  }
+  return ret;  // plain return: std::move on a local here would block copy elision
+}
+
 /*!
  * \brief Module-level passes are designed to implement global
  * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
diff --git a/tests/cpp/pass_immutable_module_test.cc b/tests/cpp/pass_immutable_module_test.cc
new file mode 100644
index 0000000000..b90f1deee7
--- /dev/null
+++ b/tests/cpp/pass_immutable_module_test.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/te/operation.h>
+
+using namespace tvm;
+using namespace transform;
+
+Pass MutateModulePass() {  // module pass "ImmutableModulev1": edits the module it is handed
+  auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule {
+    GlobalVar var = mod->GetGlobalVar("dummyFunction");
+    mod->Remove(var);  // removal is applied through the incoming `mod` handle itself
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev1", {});
+}
+
+Pass DoNotMutateModulePass() {  // module pass "ImmutableModulev2": edits a new module instead
+  auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule {
+    IRModule result(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map,
+                    mod->attrs);  // new IRModule constructed from the input's fields
+    GlobalVar var = result->GetGlobalVar("dummyFunction");
+    result->Remove(var);  // removal targets `result`, not the `mod` argument
+    return result;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev2", {});
+}
+
+IRModule preamble() {  // builds the module that both tests in this file start from
+  auto x = relay::Var("x", relay::Type());
+  auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
+  ICHECK(f->IsInstance<BaseFuncNode>());
+
+  auto global_var = GlobalVar("dummyFunction");  // the name both test passes look up
+  auto mod = IRModule::FromExpr(f, {{global_var, f}}, {});
+  return mod;
+}
+
+TEST(Relay, ModuleIsMutated) {  // expects the in-place mutating pass to raise InternalError
+  IRModule mod = preamble();
+
+  EXPECT_THROW(
+      {
+        auto pass_ctx = relay::transform::PassContext::Create();
+        pass_ctx->config.Set("testing.immutable_module", Bool(true));  // turn the check on
+        {
+          tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+          mod = MutateModulePass()(mod);
+        }
+      },
+      runtime::InternalError);
+}
+
+TEST(Relay, ModuleIsNotMutated) {  // no assertions: only runs the non-mutating pass, check on
+  IRModule mod = preamble();
+
+  auto pass_ctx = relay::transform::PassContext::Create();
+  pass_ctx->config.Set("testing.immutable_module", Bool(true));  // turn the check on
+  {
+    tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+    mod = DoNotMutateModulePass()(mod);
+  }
+}