Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/22 12:33:47 UTC

[GitHub] [tvm] echuraev commented on a diff in pull request #11878: [Adreno] Add markup pass of relay tensors for static texture planning

echuraev commented on code in PR #11878:
URL: https://github.com/apache/tvm/pull/11878#discussion_r927596550


##########
tests/python/relay/test_conv2d_nchw_texture.py:
##########
@@ -435,3 +435,641 @@ def test_conv2d_vgg16_winograd_4d():
     graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
     matches = re.findall("winograd", graph)
     assert len(matches) > 0
+
+
+@tvm.testing.requires_opencl
+def test_2conv2d():

Review Comment:
   Perhaps we could rename this test to something more meaningful?



##########
src/relay/transforms/annotate_texture_storage.cc:
##########
@@ -0,0 +1,528 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file annotate_texture_storage.cc
+ * \brief Collection of target specific relay passes which collect
+ * storage scope related information.
+ *
+ *  - CollectStorageInfo returns a mapping from relay expr
+ *    to a list of output storage scopes for each output.
+ *    These scopes are used during memory planning as well
+ *    as downstream when doing codegen and in the graph runtime when doing runtime dataspace
+ *    allocations.
+ *
+ *  - AnnotateMemoryScope calls *target.CollectStorageInfo for all targets represented
+ *    in the graph and rewrites the graph, modifying or inserting VirtualDevice annotations
+ *    with the required memory_scope collected from CollectStorageInfo.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/tir/expr.h>
+
+#include <memory>
+#include <unordered_map>
+
+#include "../transforms/device_aware_visitors.h"
+
+namespace tvm {
+namespace relay {
+namespace {
+
+/**
+ * @brief Analyzes the graph and returns a mapping from expressions to their desired memory scope
+ */
+class StorageInfo : private transform::DeviceAwareExprVisitor {
+ public:
+  StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  static Map<Expr, Array<String>> GetStorageMap(const Expr& expr) {
+    StorageInfo storage_info;
+    storage_info.VisitExpr(expr);
+    storage_info.LegalizeProducerStorage();
+    Map<Expr, Array<String>> storage_map;
+    for (auto& kv : storage_info.storage_scope_) {
+      std::vector<String> storage_scopes;
+      std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
+      storage_map.Set(GetRef<Expr>(kv.first), Array<String>{storage_scopes});
+    }
+
+    // Fill the input arguments with "global" scope to handle the PlanDevices algorithm, which
+    // propagates virtual devices from outputs to inputs. At the same time outputs must be
+    // unconstrained to avoid a useless device_copy
+    for (const auto& cs : storage_info.consumer_storage_scopes_) {
+      // Having a record in consumers means the consumer potentially dealt with
+      // textures; it is safe to mark this expr with global scope even without
+      // verifying the consumer's output scopes
+      if (storage_info.CanConsumeTextures(cs.second) &&
+          storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
+        storage_map.Set(GetRef<Expr>(cs.first), Array<String>{"global"});
+      }
+    }
+
+    // The initial algorithm only maps the outputs of each expr, which is not enough: the
+    // VirtualDevice of function variables must also be updated for proper codegen. Add vars to storage_map
+    for (const auto& a : storage_info.args_to_vars_) {
+      if (storage_map.count(a.first)) {
+        for (const auto& v : a.second) {
+          storage_map.Set(v, storage_map[a.first]);
+        }
+      }
+    }
+    return storage_map;
+  }
+
+ private:
+  void Visit(const Expr& expr) {
+    // Pre-order traversal to enable upward propagation
+    // of consumer storage scopes to producers when desirable.
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      this->VisitExpr(fn->body);
+      for (const auto& param : fn->params) {
+        this->VisitExpr(param);
+      }
+    } else {
+      this->VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); }
+
+  void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }
+
+  void DeviceAwareVisitExpr_(const CallNode* call) final {
+    // Check the contents of this primitive function
+    if (DeviceSupportsTextureStorage(GetRef<Expr>(call))) {
+      if (const auto* fn = call->op.as<FunctionNode>()) {
+        if (fn->HasNonzeroAttr(attr::kPrimitive)) {
+          primitive_supports_texture_ = false;
+          Visit(call->op);
+          if (primitive_supports_texture_) {
+            if (call->checked_type().as<TensorTypeNode>()) {
+              std::string scope = "global.texture";
+              if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
+                if (ttype->shape.size() == 5) {
+                  scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
+                }
+              }
+              storage_scope_[call].push_back(scope);
+            } else {
+              const auto* tuple_type = call->type_as<TupleTypeNode>();
+              ICHECK(tuple_type);
+              // TODO(csullivan): Add support for mixed output storage scope.
+              // In current adreno storage planner all outputs of a
+              // primitive function are assumed to be of the same storage
+              // type. This should be easy to extend in the future.
+              for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+                storage_scope_[call].push_back("global.texture");
+              }
+            }
+            for (size_t i = 0; i < fn->params.size(); i++) {
+              args_to_vars_[call->args[i]].push_back(fn->params[i]);
+            }
+          }
+          // Add consumer storage scope information for call arguments
+          for (auto& arg : call->args) {
+            if (storage_scope_.count(call)) {
+              ICHECK(!HasMixedStorageOutputs(call))
+                  << "Mixed output storage scopes are not currently supported";
+              consumer_storage_scopes_[arg.operator->()].push_back("global.texture");
+            } else {
+              consumer_storage_scopes_[arg.operator->()].push_back("global");
+            }
+          }
+        }
+      }
+    }
+
+    primitive_supports_texture_ = SupportsTextureStorage(call);
+
+    for (auto& arg : call->args) {
+      Visit(arg);
+    }
+    // At this point storage_scope_ contains every callee that supports textures.
+    // We need to verify whether this call expects a texture and, if it does not, remove it from
+    // storage_scope_, since storage_scope_ is initially filled only from the knowledge that a
+    // function is able to work with textures, not that a texture is actually expected by the
+    // function's callee
+    for (auto& arg : call->args) {
+      if (consumer_storage_scopes_.count(arg.operator->()) &&
+          GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
+        storage_scope_.erase(arg.operator->());
+        if (const auto* cn = arg.as<CallNode>()) {
+          if (const auto* fn = cn->op.as<FunctionNode>()) {
+            storage_scope_.erase(fn->body.operator->());
+          }
+        }
+      }
+    }
+  }
+
+  std::string Scope(Array<PrimExpr> shape, const VirtualDevice& vd) {
+    if (vd != VirtualDevice::FullyUnconstrained()) {
+      std::map<int, std::string> diffs;
+      int limit =
+          vd->target->GetAttr<Integer>("texture_spatial_limit").value_or(Integer(16384))->value;
+      int a0 = shape[0].as<IntImmNode>()->value;
+      int a1 = shape[1].as<IntImmNode>()->value;
+      int a2 = shape[2].as<IntImmNode>()->value;
+      int a3 = shape[3].as<IntImmNode>()->value;
+
+      int d3l = a0 * a1 * a2;
+      int d3r = a3;
+      int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l;
+      if (d3l < limit && d3r < limit) diffs[diff3] = "";
+
+      int d2l = a0 * a1;
+      int d2r = a2 * a3;
+      int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l;
+      if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc";
+
+      int d1l = a0;
+      int d1r = a1 * a2 * a3;
+      int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l;
+      if (d1l < limit && d1r < limit) diffs[diff1] = "weight";
+      if (!diffs.empty()) {
+        std::string scope = "global.texture";
+        if (!diffs.begin()->second.empty()) {
+          scope += ("-" + diffs.begin()->second);
+        }
+        return scope;
+      }
+    }
+    return "global";
+  }
+
+  void ApplyConsumerScopeToInputs(const ExprNode* expr) {
+    std::string scope;
+    auto consumer_scopes_it = consumer_storage_scopes_.find(expr);
+    if (consumer_scopes_it != consumer_storage_scopes_.end()) {
+      std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second);
+      ICHECK(!storage_scope_.count(expr))
+          << "Already propagated consumer scopes to input: " << GetRef<Expr>(expr);
+
+      bool expr_is_rgba_vectorizable = false;
+      if (const auto* ttype = expr->checked_type().as<TensorTypeNode>()) {
+        if (ttype->shape.size() == 5) {
+          scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(expr)));
+          if (scope != "global") {
+            auto inner_dim = ttype->shape.back().as<IntImmNode>();
+            if (inner_dim && inner_dim->value == 4) {
+              expr_is_rgba_vectorizable = true;
+            }
+          }
+        }
+      }
+
+      // Only propagate texture scope from consumers to input expr if
+      // the input shape of the input expr is rgba vectorizable.
+      if (consumer_scope.find("global.texture") != std::string::npos) {
+        if (expr_is_rgba_vectorizable) {
+          storage_scope_[expr].push_back(scope);
+        }
+      } else {
+        storage_scope_[expr].push_back(consumer_scope);
+      }
+    }
+  }
+
+  void LegalizeProducerStorage() {
+    for (auto& kv : consumer_storage_scopes_) {
+      const ExprNode* producer = kv.first;
+      std::string legal_scope = GetConsumerScope(kv.second);
+      if (storage_scope_.count(producer)) {
+        ICHECK(!HasMixedStorageOutputs(producer))
+            << "Mixed output storage scopes are not currently supported";
+        if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) {
+          for (size_t i = 0; i < storage_scope_[producer].size(); i++) {
+            // Only support uniform storage scope across all outputs for now
+            storage_scope_[producer][i] = legal_scope;
+          }
+        }
+      }
+    }
+  }
+
+  bool DeviceSupportsTextureStorage(const Expr& expr) {
+    auto vd = GetVirtualDevice(expr);
+    if (vd != VirtualDevice::FullyUnconstrained()) {
+      if (Optional<String> t_device = vd->target->GetAttr<String>("device")) {
+        if (vd->target->kind->device_type == kDLOpenCL && t_device.defined()) {
+          if (t_device.value() == "adreno") {

Review Comment:
   Just an idea: probably we could add a method to `t_device` that would report whether the device supports textures or not?
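   Something like the sketch below, just to illustrate the idea (`TargetSupportsTextures` is a hypothetical helper, not an existing API):
   ```cpp
   // Hypothetical helper hiding the device-specific check behind a single query.
   bool TargetSupportsTextures(const Target& target) {
     if (target->kind->device_type != kDLOpenCL) {
       return false;
     }
     if (Optional<String> device = target->GetAttr<String>("device")) {
       return device.value() == "adreno";
     }
     return false;
   }
   ```
   `DeviceSupportsTextureStorage` would then only need to get the virtual device and call this helper on its target.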



##########
tests/python/relay/test_conv2d_nchw_texture.py:
##########
@@ -435,3 +435,641 @@ def test_conv2d_vgg16_winograd_4d():
     graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
     matches = re.findall("winograd", graph)
     assert len(matches) > 0
+
+
+@tvm.testing.requires_opencl
+def test_2conv2d():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape1 = (96, 32, 2, 2)
+    filter_shape2 = (32, 96, 2, 2)
+    bias_shape1 = (1, 96, 1, 1)
+    bias_shape2 = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype)
+
+    # C = relay.nn.relu(A)
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv1, B1)
+    D = relay.op.nn.relu(D)
+
+    conv2 = relay.nn.conv2d(
+        D,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv2, B2)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, W1, B1, W2, B2], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    bias_data1 = np.zeros(bias_shape1).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("bias", bias_data1)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    bias_data2 = np.zeros(bias_shape2).astype(dtype)
+    initializer("weight", filter_data2)
+    initializer("bias", bias_data2)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+        "bias1": tvm.nd.array(bias_data1),
+        "weight2": tvm.nd.array(filter_data2),
+        "bias2": tvm.nd.array(bias_data2),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture-nhwc",
+        "global.texture-weight",
+        "global.texture-weight",
+        "",
+        "",
+    ]
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+@tvm.testing.requires_opencl
+def test_residual_block():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape1 = (32, 32, 2, 2)
+    filter_shape2 = (32, 32, 1, 1)
+    filter_shape3 = (32, 32, 2, 2)
+    bias_shape1 = (1, 32, 1, 1)
+    # bias_shape2 = (1, 32, 1, 1)

Review Comment:
   Please remove the commented-out code.



##########
src/relay/transforms/annotate_texture_storage.cc:
##########
@@ -0,0 +1,528 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file annotate_texture_storage.cc
+ * \brief Collection of target specific relay passes which collect
+ * storage scope related information.
+ *
+ *  - CollectStorageInfo returns a mapping from relay expr
+ *    to a list of output storage scopes for each output.
+ *    These scopes are used during memory planning as well
+ *    as downstream when doing codegen and in the graph runtime when doing runtime dataspace
+ *    allocations.
+ *
+ *  - AnnotateMemoryScope calls *target.CollectStorageInfo for all targets represented
+ *    in the graph and rewrites the graph, modifying or inserting VirtualDevice annotations
+ *    with the required memory_scope collected from CollectStorageInfo.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/tir/expr.h>
+
+#include <memory>
+#include <unordered_map>
+
+#include "../transforms/device_aware_visitors.h"
+
+namespace tvm {
+namespace relay {
+namespace {
+
+/**
+ * @brief Analyzes the graph and returns a mapping from expressions to their desired memory scope
+ */
+class StorageInfo : private transform::DeviceAwareExprVisitor {
+ public:
+  StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  static Map<Expr, Array<String>> GetStorageMap(const Expr& expr) {
+    StorageInfo storage_info;
+    storage_info.VisitExpr(expr);
+    storage_info.LegalizeProducerStorage();
+    Map<Expr, Array<String>> storage_map;
+    for (auto& kv : storage_info.storage_scope_) {
+      std::vector<String> storage_scopes;
+      std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
+      storage_map.Set(GetRef<Expr>(kv.first), Array<String>{storage_scopes});
+    }
+
+    // Fill the input arguments with "global" scope to handle the PlanDevices algorithm, which
+    // propagates virtual devices from outputs to inputs. At the same time outputs must be
+    // unconstrained to avoid a useless device_copy
+    for (const auto& cs : storage_info.consumer_storage_scopes_) {
+      // Having a record in consumers means the consumer potentially dealt with
+      // textures; it is safe to mark this expr with global scope even without
+      // verifying the consumer's output scopes
+      if (storage_info.CanConsumeTextures(cs.second) &&
+          storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
+        storage_map.Set(GetRef<Expr>(cs.first), Array<String>{"global"});
+      }
+    }
+
+    // The initial algorithm only maps the outputs of each expr, which is not enough: the
+    // VirtualDevice of function variables must also be updated for proper codegen. Add vars to storage_map
+    for (const auto& a : storage_info.args_to_vars_) {
+      if (storage_map.count(a.first)) {
+        for (const auto& v : a.second) {
+          storage_map.Set(v, storage_map[a.first]);
+        }
+      }
+    }
+    return storage_map;
+  }
+
+ private:
+  void Visit(const Expr& expr) {
+    // Pre-order traversal to enable upward propagation
+    // of consumer storage scopes to producers when desirable.
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      this->VisitExpr(fn->body);
+      for (const auto& param : fn->params) {
+        this->VisitExpr(param);
+      }
+    } else {
+      this->VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); }
+
+  void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }
+
+  void DeviceAwareVisitExpr_(const CallNode* call) final {
+    // Check the contents of this primitive function
+    if (DeviceSupportsTextureStorage(GetRef<Expr>(call))) {
+      if (const auto* fn = call->op.as<FunctionNode>()) {
+        if (fn->HasNonzeroAttr(attr::kPrimitive)) {
+          primitive_supports_texture_ = false;
+          Visit(call->op);
+          if (primitive_supports_texture_) {
+            if (call->checked_type().as<TensorTypeNode>()) {
+              std::string scope = "global.texture";
+              if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
+                if (ttype->shape.size() == 5) {
+                  scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
+                }
+              }
+              storage_scope_[call].push_back(scope);
+            } else {
+              const auto* tuple_type = call->type_as<TupleTypeNode>();
+              ICHECK(tuple_type);
+              // TODO(csullivan): Add support for mixed output storage scope.
+              // In current adreno storage planner all outputs of a
+              // primitive function are assumed to be of the same storage
+              // type. This should be easy to extend in the future.
+              for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+                storage_scope_[call].push_back("global.texture");
+              }
+            }
+            for (size_t i = 0; i < fn->params.size(); i++) {
+              args_to_vars_[call->args[i]].push_back(fn->params[i]);
+            }
+          }
+          // Add consumer storage scope information for call arguments
+          for (auto& arg : call->args) {
+            if (storage_scope_.count(call)) {
+              ICHECK(!HasMixedStorageOutputs(call))
+                  << "Mixed output storage scopes are not currently supported";
+              consumer_storage_scopes_[arg.operator->()].push_back("global.texture");
+            } else {
+              consumer_storage_scopes_[arg.operator->()].push_back("global");
+            }
+          }
+        }
+      }
+    }
+
+    primitive_supports_texture_ = SupportsTextureStorage(call);
+
+    for (auto& arg : call->args) {
+      Visit(arg);
+    }
+    // At this point storage_scope_ contains every callee that supports textures.
+    // We need to verify whether this call expects a texture and, if it does not, remove it from
+    // storage_scope_, since storage_scope_ is initially filled only from the knowledge that a
+    // function is able to work with textures, not that a texture is actually expected by the
+    // function's callee
+    for (auto& arg : call->args) {
+      if (consumer_storage_scopes_.count(arg.operator->()) &&
+          GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
+        storage_scope_.erase(arg.operator->());
+        if (const auto* cn = arg.as<CallNode>()) {
+          if (const auto* fn = cn->op.as<FunctionNode>()) {
+            storage_scope_.erase(fn->body.operator->());
+          }
+        }
+      }
+    }
+  }
+
+  std::string Scope(Array<PrimExpr> shape, const VirtualDevice& vd) {
+    if (vd != VirtualDevice::FullyUnconstrained()) {
+      std::map<int, std::string> diffs;
+      int limit =
+          vd->target->GetAttr<Integer>("texture_spatial_limit").value_or(Integer(16384))->value;
+      int a0 = shape[0].as<IntImmNode>()->value;
+      int a1 = shape[1].as<IntImmNode>()->value;
+      int a2 = shape[2].as<IntImmNode>()->value;
+      int a3 = shape[3].as<IntImmNode>()->value;
+
+      int d3l = a0 * a1 * a2;
+      int d3r = a3;
+      int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l;
+      if (d3l < limit && d3r < limit) diffs[diff3] = "";
+
+      int d2l = a0 * a1;
+      int d2r = a2 * a3;
+      int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l;
+      if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc";
+
+      int d1l = a0;
+      int d1r = a1 * a2 * a3;
+      int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l;
+      if (d1l < limit && d1r < limit) diffs[diff1] = "weight";
+      if (!diffs.empty()) {
+        std::string scope = "global.texture";
+        if (!diffs.begin()->second.empty()) {
+          scope += ("-" + diffs.begin()->second);
+        }
+        return scope;
+      }
+    }
+    return "global";
+  }
+
+  void ApplyConsumerScopeToInputs(const ExprNode* expr) {
+    std::string scope;
+    auto consumer_scopes_it = consumer_storage_scopes_.find(expr);
+    if (consumer_scopes_it != consumer_storage_scopes_.end()) {
+      std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second);
+      ICHECK(!storage_scope_.count(expr))
+          << "Already propagated consumer scopes to input: " << GetRef<Expr>(expr);
+
+      bool expr_is_rgba_vectorizable = false;
+      if (const auto* ttype = expr->checked_type().as<TensorTypeNode>()) {
+        if (ttype->shape.size() == 5) {
+          scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(expr)));
+          if (scope != "global") {
+            auto inner_dim = ttype->shape.back().as<IntImmNode>();
+            if (inner_dim && inner_dim->value == 4) {
+              expr_is_rgba_vectorizable = true;
+            }
+          }
+        }
+      }
+
+      // Only propagate texture scope from consumers to input expr if
+      // the input shape of the input expr is rgba vectorizable.
+      if (consumer_scope.find("global.texture") != std::string::npos) {
+        if (expr_is_rgba_vectorizable) {
+          storage_scope_[expr].push_back(scope);
+        }
+      } else {
+        storage_scope_[expr].push_back(consumer_scope);
+      }
+    }
+  }
+
+  void LegalizeProducerStorage() {
+    for (auto& kv : consumer_storage_scopes_) {
+      const ExprNode* producer = kv.first;
+      std::string legal_scope = GetConsumerScope(kv.second);
+      if (storage_scope_.count(producer)) {
+        ICHECK(!HasMixedStorageOutputs(producer))
+            << "Mixed output storage scopes are not currently supported";
+        if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) {
+          for (size_t i = 0; i < storage_scope_[producer].size(); i++) {
+            // Only support uniform storage scope across all outputs for now
+            storage_scope_[producer][i] = legal_scope;
+          }
+        }
+      }
+    }
+  }
+
+  bool DeviceSupportsTextureStorage(const Expr& expr) {
+    auto vd = GetVirtualDevice(expr);
+    if (vd != VirtualDevice::FullyUnconstrained()) {
+      if (Optional<String> t_device = vd->target->GetAttr<String>("device")) {
+        if (vd->target->kind->device_type == kDLOpenCL && t_device.defined()) {
+          if (t_device.value() == "adreno") {
+            return true;
+          }
+        }
+      }
+    }
+    return false;
+  }
+
+  std::string GetConsumerScope(const std::vector<std::string>& consumer_scopes) const {
+    if (!consumer_scopes.size()) {
+      return "global";
+    }
+    std::string texture_tag = "global.texture";
+    for (auto& consumer_scope : consumer_scopes) {
+      if (consumer_scope.find(texture_tag) == std::string::npos) {
+        return "global";
+      }
+    }
+    return texture_tag;
+  }
+
+  bool CanConsumeTextures(const std::vector<std::string>& consumer_scopes) const {
+    if (!consumer_scopes.size()) {
+      return false;
+    }

Review Comment:
   Can we remove this check? We would not enter the loop anyway and `false` would be returned. Same comment for `GetConsumerScope`.
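   For `CanConsumeTextures` the simplification would look roughly like the sketch below (assuming the part of the function not shown in this hunk is just a loop over the scopes that returns `true` on a texture match):
   ```cpp
   bool CanConsumeTextures(const std::vector<std::string>& consumer_scopes) const {
     std::string texture_tag = "global.texture";
     for (auto& consumer_scope : consumer_scopes) {
       if (consumer_scope.find(texture_tag) == 0) {
         return true;
       }
     }
     // An empty vector skips the loop and falls through to false anyway.
     return false;
   }
   ```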



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org