You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/30 07:18:37 UTC
[tvm] branch main updated: [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 559f0c76a0 [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)
559f0c76a0 is described below
commit 559f0c76a0a8ee9c1620ee29ecd8ce1ced07093e
Author: Florin Blanaru <fl...@gmail.com>
AuthorDate: Mon May 30 08:18:32 2022 +0100
[Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498)
---
include/tvm/ir/transform.h | 4 ++
src/ir/transform.cc | 25 +++++++++-
tests/cpp/pass_immutable_module_test.cc | 86 +++++++++++++++++++++++++++++++++
3 files changed, 114 insertions(+), 1 deletion(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d8f6632a66..febcca5c01 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -390,6 +390,10 @@ class Pass : public ObjectRef {
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
+
+ private:
+ IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
+ const PassContext& pass_ctx);
};
/*!
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index dfd307d715..d945278abc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -24,6 +24,7 @@
#include <dmlc/thread_local.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
@@ -41,6 +42,8 @@ using tvm::ReprPrinter;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
+TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
+
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
@@ -264,11 +267,31 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
<< " with opt level: " << pass_info->opt_level;
return mod;
}
- auto ret = node->operator()(std::move(mod), pass_ctx);
+ IRModule ret;
+ if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
+ ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
+ } else {
+ ret = node->operator()(std::move(mod), pass_ctx);
+ }
pass_ctx.InstrumentAfterPass(ret, pass_info);
return std::move(ret);
}
+// Runs `node` on `mod` and returns the pass result. Reports a fatal error if the structural
+// hash of the input module is different after the pass, i.e. the pass modified its input.
+IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
+                                     const PassContext& pass_ctx) {
+  size_t before_pass_hash = tvm::StructuralHash()(mod);
+  IRModule copy_mod = IRModule(ObjectRef::GetDataPtr<Object>(mod));
+  IRModule ret = node->operator()(mod, pass_ctx);
+  size_t after_pass_hash = tvm::StructuralHash()(copy_mod);
+  if (before_pass_hash != after_pass_hash) {
+    // A mutation could in principle keep the hash unchanged, but that should be very unlikely.
+    LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name;
+  }
+  return ret;
+}
+
/*!
* \brief Module-level passes are designed to implement global
* analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
diff --git a/tests/cpp/pass_immutable_module_test.cc b/tests/cpp/pass_immutable_module_test.cc
new file mode 100644
index 0000000000..b90f1deee7
--- /dev/null
+++ b/tests/cpp/pass_immutable_module_test.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/te/operation.h>
+
+using namespace tvm;
+using namespace transform;
+
+// Builds a module-level pass that removes "dummyFunction" directly from the module it is given.
+Pass MutateModulePass() {
+  auto remove_in_place = [](IRModule mod, PassContext pc) -> IRModule {
+    mod->Remove(mod->GetGlobalVar("dummyFunction"));
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(remove_in_place, 1, "ImmutableModulev1", {});
+}
+
+// Builds a module-level pass that removes "dummyFunction" from a newly constructed IRModule.
+Pass DoNotMutateModulePass() {
+  auto remove_from_new = [](IRModule mod, PassContext pc) -> IRModule {
+    IRModule fresh(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map,
+                   mod->attrs);
+    fresh->Remove(fresh->GetGlobalVar("dummyFunction"));
+    return fresh;
+  };
+  return tvm::transform::CreateModulePass(remove_from_new, 1, "ImmutableModulev2", {});
+}
+
+// Test fixture: builds a module from a one-parameter function whose body is that parameter,
+// passing FromExpr a map that binds the global var "dummyFunction" to the same function.
+IRModule preamble() {
+  relay::Var param("x", relay::Type());
+  relay::Function identity(tvm::Array<relay::Var>{param}, param, relay::Type(), {});
+  ICHECK(identity->IsInstance<BaseFuncNode>());
+  GlobalVar dummy_name("dummyFunction");
+  return IRModule::FromExpr(identity, {{dummy_name, identity}}, {});
+}
+
+TEST(Relay, ModuleIsMutated) {
+ IRModule mod = preamble();
+ // MutateModulePass edits its argument, so with "testing.immutable_module" set it must throw.
+ EXPECT_THROW(
+ {
+ auto pass_ctx = relay::transform::PassContext::Create();
+ pass_ctx->config.Set("testing.immutable_module", Bool(true));
+ {
+ tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+ mod = MutateModulePass()(mod);
+ }
+ },
+ runtime::InternalError);
+}
+
+TEST(Relay, ModuleIsNotMutated) {
+ IRModule mod = preamble();
+ // DoNotMutateModulePass should run without throwing while "testing.immutable_module" is set.
+ auto pass_ctx = relay::transform::PassContext::Create();
+ pass_ctx->config.Set("testing.immutable_module", Bool(true));
+ {
+ tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
+ mod = DoNotMutateModulePass()(mod);
+ }
+}