You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/28 10:36:33 UTC

[GitHub] [tvm] manupa-arm commented on a diff in pull request #12215: Pass that removes reshapes post LowerTE

manupa-arm commented on code in PR #12215:
URL: https://github.com/apache/tvm/pull/12215#discussion_r932054959


##########
src/relay/transforms/remove_reshapes.cc:
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_reshapes.cc
+ * \brief Relay pass for removing reshapes from lowered graph.
+ */
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../op/call/call.h"
+#include "../op/memory/on_device.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveReshapes", Bool);
+/*! Removes reshapes right after LowerTE. Removes preceding on_device calls
+ * while removing reshapes.
+ */
+class RemoveReshapesMutator : public MixedModeMutator {
+ public:
+  explicit RemoveReshapesMutator(IRModule& mod) : ir_module_(mod) {}
+
+  using MixedModeMutator::VisitExpr_;
+
+  Expr VisitExpr_(const LetNode* let) final {
+    Let ret_let;
+    Var var = Downcast<Var>(this->Mutate(let->var));
+    auto value = this->Mutate(let->value);
+    if (auto* on_device_call = value.as<CallNode>()) {
+      OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call);
+      if (on_device_props.body.defined() && on_device_props.body->IsInstance<CallNode>()) {
+        const Call call_lowered = Downcast<Call>(on_device_props.body);
+        if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) {
+          let_var_to_call_lowered_.Set(var, call_lowered);
+        }
+      }
+    }
+    auto body = this->Mutate(let->body);
+    return WithFields(GetRef<Let>(let), var, value, body);
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    /*

Review Comment:
   nit: it might be worth explaining what the reader should get out of this block.



##########
src/relay/transforms/remove_reshapes.cc:
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_reshapes.cc
+ * \brief Relay pass for removing reshapes from lowered graph.
+ */
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../op/call/call.h"
+#include "../op/memory/on_device.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.RemoveReshapes", Bool);

Review Comment:
   Please use something like `relay.RemoveStandaloneReshapes.enable` or `relay.use_remove_standalone_reshapes`.
   I personally prefer the former, as it creates a namespace for any future options of the pass.



##########
include/tvm/relay/transform.h:
##########
@@ -580,6 +580,14 @@ TVM_DLL Pass AnnotateUsedMemory();
  */
 TVM_DLL Pass CapturePostDfsIndexInSpans();
 
+/*!
+ * \brief Remove reshapes after lowering the graph.
+ *
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveReshapes();

Review Comment:
   Let's rename this to `RemoveStandaloneReshapes`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org