Posted to commits@tvm.apache.org by cs...@apache.org on 2022/08/02 16:10:57 UTC

[tvm] branch main updated: [Adreno] Add markup pass of relay tensors for static texture planning (#11878)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b05dca1f19 [Adreno] Add markup pass of relay tensors for static texture planning (#11878)
b05dca1f19 is described below

commit b05dca1f194614fb3f0310230274a3a19be23a86
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Tue Aug 2 19:10:48 2022 +0300

    [Adreno] Add markup pass of relay tensors for static texture planning (#11878)
    
    * [Adreno] Add static texture markup relay pass
    
    Co-authored-by: Chris Sullivan <cs...@octoml.ai>
    
    * lint check
    
    * Remove hardcoded texture limit, check through target options
    
    * fix cpplint
    
    * Add winograd into annotation pass
    
    * fix clang
    
    * Remove extra call of PlanDevice in OptimizeImpl
    
    * Remove one more extra call of PlanDevice in OptimizeImpl
    
    * Fix/add scopes for static texture planning tests
    
    * Remove test_2conv2d as duplication of test_plan_device_issue
    
    * remove comments in test_residual_block
    
    * address review comments
    
    * fix black hits
    
    * Add textures test descriptions
    
    * Address PR comments
    
    Co-authored-by: Chris Sullivan <cs...@octoml.ai>
---
 include/tvm/relay/transform.h                      |   5 +
 python/tvm/topi/adreno/conv2d_nchw.py              |  19 +-
 python/tvm/topi/adreno/conv2d_nhwc.py              |  16 +-
 src/relay/backend/build_module.cc                  |   1 +
 src/relay/transforms/annotate_texture_storage.cc   | 523 +++++++++++++++++++
 tests/python/relay/test_conv2d_nchw_texture.py     | 569 ++++++++++++++++++++-
 tests/python/relay/test_conv2d_nhwc_texture.py     |   7 +-
 .../relay/test_depthwise_conv2d_nchw_texture.py    |   4 +-
 .../relay/test_depthwise_conv2d_nhwc_texture.py    |   2 +-
 tests/python/relay/utils/adreno_utils.py           |  15 +
 10 files changed, 1142 insertions(+), 19 deletions(-)
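
With this change no extra user action is needed: AnnotateMemoryScope runs inside relay.build whenever the Adreno OpenCL target is selected. A minimal Python sketch with illustrative shapes (whether textures are actually assigned depends on the layouts chosen during optimization):

    import tvm
    from tvm import relay

    dtype = "float16"
    data = relay.var("data", shape=(1, 32, 40, 40), dtype=dtype)
    weight = relay.var("weight", shape=(32, 32, 1, 1), dtype=dtype)
    conv = relay.nn.conv2d(data, weight, kernel_size=(1, 1), channels=32, out_dtype=dtype)
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    # "opencl --device=adreno" selects the Adreno strategy; the new pass then
    # records scopes such as "global.texture" in the graph "storage_scope" attribute.
    target = tvm.target.Target("opencl --device=adreno", host="llvm")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)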

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 042ad1ef02..f60912fb01 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -580,6 +580,11 @@ TVM_DLL Pass AnnotateUsedMemory();
  */
 TVM_DLL Pass CapturePostDfsIndexInSpans();
 
+/*!
+ * \brief Calls the device-dependent memory scope analysis pass, collects the mapping of
+ * desired expr->memory_scope, and annotates expressions via VirtualDevice with the required memory_scope
+ */
+TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config);
 }  // namespace transform
 
 /*!
diff --git a/python/tvm/topi/adreno/conv2d_nchw.py b/python/tvm/topi/adreno/conv2d_nchw.py
index 2a8f6028b7..16ecaa84d0 100644
--- a/python/tvm/topi/adreno/conv2d_nchw.py
+++ b/python/tvm/topi/adreno/conv2d_nchw.py
@@ -29,6 +29,7 @@ from .utils import (
     add_pad,
     bind_data_copy,
     get_default_conv2d_config,
+    get_texture_storage,
 )
 
 
@@ -214,8 +215,11 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
       5d tensors
     4. pad should be scheduled separately to create independent opencl kernel. If pad is
        inlined into convolution, this gives 1.5x performance drop
-    5. We are using cache_read to produce texture and guarantee the best performance
-       on the next stage.
+    5. We are using cache_read for intermediate tensors to produce texture and guarantee
+       the best performance on the next stage.
+       The weights are managed through the static texture planning mechanism and are
+       guaranteed to arrive in the texture memory scope.
+       Thus we call cache_read only for the data tensor.
     6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
        for textures
        For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
@@ -288,10 +292,15 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
         s[output].compute_inline()
 
     # create cache stage
-    AT = s.cache_read(pad_data, "global.texture", [conv])
+    AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
     bind_data_copy(s[AT])
-    WT = s.cache_read(kernel, "global.texture-weight", [conv])
-    bind_data_copy(s[WT])
+    if (
+        autotvm.GLOBAL_SCOPE.in_tuning
+        or isinstance(kernel.op, tvm.te.ComputeOp)
+        and "filter_pack" in kernel.op.tag
+    ):
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, fc, y, x, fb = s[latest_blocked].op.axis
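
The hunk above replaces the hardcoded "global.texture" scope with get_texture_storage, which picks a texture layout from the tensor shape. A simplified sketch of that selection rule (the real helper lives in python/tvm/topi/adreno/utils.py; the 16384 limit here is an assumption matching the default texture_spatial_limit target attribute):

    def get_texture_storage_sketch(shape, limit=16384):
        # Flatten the 5d packed shape into 2d texture extents and pick the
        # first layout whose extents fit the hardware limit.
        if shape[0] * shape[1] * shape[2] < limit and shape[3] < limit:
            return "global.texture"
        if shape[0] * shape[1] < limit and shape[2] * shape[3] < limit:
            return "global.texture-nhwc"
        return "global.texture-weight"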
diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py b/python/tvm/topi/adreno/conv2d_nhwc.py
index 388f606ecb..ce7bf0ccc9 100644
--- a/python/tvm/topi/adreno/conv2d_nhwc.py
+++ b/python/tvm/topi/adreno/conv2d_nhwc.py
@@ -210,8 +210,11 @@ def schedule_conv2d_NHWC(cfg, s, output):
       5d tensors
     4. pad should be scheduled separately to create independent opencl kernel. If pad is
        inlined into convolution, this gives 1.5x performance drop
-    5. We are using cache_read to produce texture and guarantee the best performance
-       on the next stage.
+    5. We are using cache_read for intermediate tensors to produce texture and guarantee
+       the best performance on the next stage.
+       The weights are managed through the static texture planning mechanism and are
+       guaranteed to arrive in the texture memory scope.
+       Thus we call cache_read only for the data tensor.
     6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
        for textures
        For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
@@ -287,8 +290,13 @@ def schedule_conv2d_NHWC(cfg, s, output):
     # create cache stage
     AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
     bind_data_copy(s[AT])
-    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
-    bind_data_copy(s[WT])
+    if (
+        autotvm.GLOBAL_SCOPE.in_tuning
+        or isinstance(kernel.op, tvm.te.ComputeOp)
+        and "filter_pack" in kernel.op.tag
+    ):
+        WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+        bind_data_copy(s[WT])
 
     # tile and bind spatial axes
     n, y, x, fc, fb = s[latest_blocked].op.axis
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 39f2e7761a..7b39cb4443 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -396,6 +396,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     relay_module = transform::Inline()(relay_module);
     relay_module = transform::InferType()(relay_module);
     relay_module = transform::LabelOps()(relay_module);
+    relay_module = transform::AnnotateMemoryScope(config_)(relay_module);
 
     ICHECK(relay_module.defined());
 
diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc
new file mode 100644
index 0000000000..3dd918d962
--- /dev/null
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file annotate_texture_storage.cc
+ * \brief Collection of target specific relay passes which collect
+ * and annotate storage scope related information.
+ *
+ *  - CollectStorageInfo returns a mapping from relay expr
+ *    to a list of output storage scopes for each output.
+ *    These scopes are used during memory planning as well
+ *    as downstream when doing codegen and in the graph runtime when doing runtime dataspace
+ *    allocations.
+ *
+ *  - AnnotateMemoryScope calls *target.CollectStorageInfo for all targets represented
+ *    in the graph and rewrites the graph, modifying or inserting VirtualDevice annotations
+ *    with the required memory_scope collected by CollectStorageInfo
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/tir/expr.h>
+
+#include <memory>
+#include <unordered_map>
+
+#include "../transforms/device_aware_visitors.h"
+
+namespace tvm {
+namespace relay {
+namespace {
+
+/**
+ * @brief Analyzes the graph and returns a mapping from expressions to desired memory scopes
+ */
+class StorageInfo : private transform::DeviceAwareExprVisitor {
+ public:
+  StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  static Map<Expr, Array<String>> GetStorageMap(const Expr& expr) {
+    StorageInfo storage_info;
+    storage_info.VisitExpr(expr);
+    storage_info.LegalizeProducerStorage();
+    Map<Expr, Array<String>> storage_map;
+    for (auto& kv : storage_info.storage_scope_) {
+      std::vector<String> storage_scopes;
+      std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
+      storage_map.Set(GetRef<Expr>(kv.first), Array<String>{storage_scopes});
+    }
+
+    // Fill the input arguments with the "global" scope to handle the PlanDevices algorithm,
+    // which propagates virtual devices from outputs to inputs. At the same time the outputs
+    // must be left unconstrained to avoid a useless device_copy
+    for (const auto& cs : storage_info.consumer_storage_scopes_) {
+      // a record in consumer_storage_scopes_ means the consumer potentially
+      // dealt with textures, so it is safe to mark this expr with the global scope
+      // even without verifying the consumer's output scopes
+      if (storage_info.CanConsumeTextures(cs.second) &&
+          storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
+        storage_map.Set(GetRef<Expr>(cs.first), Array<String>{"global"});
+      }
+    }
+
+    // the initial algorithm only maps the outputs of each expr, which is not enough: the
+    // VirtualDevice of function variables must also be updated for proper codegen. Add vars to storage_map
+    for (const auto& a : storage_info.args_to_vars_) {
+      if (storage_map.count(a.first)) {
+        for (const auto& v : a.second) {
+          storage_map.Set(v, storage_map[a.first]);
+        }
+      }
+    }
+    return storage_map;
+  }
+
+ private:
+  void Visit(const Expr& expr) {
+    // Pre-order traversal to enable upward propagation
+    // of consumer storage scopes to producers when desirable.
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      this->VisitExpr(fn->body);
+      for (const auto& param : fn->params) {
+        this->VisitExpr(param);
+      }
+    } else {
+      this->VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); }
+
+  void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }
+
+  void DeviceAwareVisitExpr_(const CallNode* call) final {
+    // Check the contents of this primitive function
+    if (const auto* fn = call->op.as<FunctionNode>()) {
+      if (fn->HasNonzeroAttr(attr::kPrimitive)) {
+        primitive_supports_texture_ = false;
+        Visit(call->op);
+        if (primitive_supports_texture_) {
+          if (call->checked_type().as<TensorTypeNode>()) {
+            std::string scope = "global.texture";
+            if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
+              scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
+            }
+            storage_scope_[call].push_back(scope);
+          } else {
+            const auto* tuple_type = call->type_as<TupleTypeNode>();
+            ICHECK(tuple_type);
+            // TODO(csullivan): Add support for mixed output storage scope.
+            // In current adreno storage planner all outputs of a
+            // primitive function are assumed to be of the same storage
+            // type. This should be easy to extend in the future.
+            for (size_t i = 0; i < tuple_type->fields.size(); i++) {
+              storage_scope_[call].push_back("global.texture");
+            }
+          }
+          for (size_t i = 0; i < fn->params.size(); i++) {
+            args_to_vars_[call->args[i]].push_back(fn->params[i]);
+          }
+        }
+        // Add consumer storage scope information for call arguments
+        for (auto& arg : call->args) {
+          if (storage_scope_.count(call)) {
+            ICHECK(!HasMixedStorageOutputs(call))
+                << "Mixed output storage scopes are not currently supported";
+            consumer_storage_scopes_[arg.operator->()].push_back("global.texture");
+          } else {
+            consumer_storage_scopes_[arg.operator->()].push_back("global");
+          }
+        }
+      }
+    }
+
+    primitive_supports_texture_ = SupportsTextureStorage(call);
+
+    for (auto& arg : call->args) {
+      Visit(arg);
+    }
+    // We have filled storage_scope_ for all callees that support textures.
+    // Now verify whether this call actually expects a texture; if it does not, remove the
+    // entry from storage_scope_, since storage_scope_ was initially filled only from the
+    // knowledge that the function is able to work with textures, not that a texture is
+    // necessarily expected by the function's callee
+    for (auto& arg : call->args) {
+      if (consumer_storage_scopes_.count(arg.operator->()) &&
+          GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
+        storage_scope_.erase(arg.operator->());
+        if (const auto* cn = arg.as<CallNode>()) {
+          if (const auto* fn = cn->op.as<FunctionNode>()) {
+            storage_scope_.erase(fn->body.operator->());
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Determines the name of the memory scope which can fit a tensor of the required shape
+   *
+   * The scope is "global" if the tensor does not satisfy the current flattening rules for
+   * textures (a texture currently has to be a 5d tensor with a value of 4 in the last
+   * dimension)
+   *
+   * The packing layout inside the texture scope (the part after the dash) is determined
+   * by the shape itself. Hardware can limit the texture spatial dimensions, and we must
+   * not exceed these sizes. In addition to fitting the hardware limit we want balanced
+   * packing, where the final spatial sizes of the texture are not too different
+   * @param shape shape to be analyzed
+   * @param vd VirtualDevice of the tensor whose memory scope is being determined
+   * @return string representing memory scope: either "global" or "global.texture-layout"
+   */
+  std::string Scope(Array<PrimExpr> shape, const VirtualDevice& vd) {
+    // currently we support only textures made from 5d tensors
+    // the 5d requirement is not a limitation of textures in general; it is a limitation of
+    // how we represent memory scopes/layouts and flatten textures in tir
+    if (vd != VirtualDevice::FullyUnconstrained() && shape.size() == 5 &&
+        shape[4].as<IntImmNode>()->value == 4) {
+      std::map<int, std::string> diffs;
+      int limit =
+          vd->target->GetAttr<Integer>("texture_spatial_limit").value_or(Integer(16384))->value;
+      int a0 = shape[0].as<IntImmNode>()->value;
+      int a1 = shape[1].as<IntImmNode>()->value;
+      int a2 = shape[2].as<IntImmNode>()->value;
+      int a3 = shape[3].as<IntImmNode>()->value;
+
+      int d3l = a0 * a1 * a2;
+      int d3r = a3;
+      int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l;
+      if (d3l < limit && d3r < limit) diffs[diff3] = "";
+
+      int d2l = a0 * a1;
+      int d2r = a2 * a3;
+      int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l;
+      if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc";
+
+      int d1l = a0;
+      int d1r = a1 * a2 * a3;
+      int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l;
+      if (d1l < limit && d1r < limit) diffs[diff1] = "weight";
+      if (!diffs.empty()) {
+        std::string scope = "global.texture";
+        if (!diffs.begin()->second.empty()) {
+          scope += ("-" + diffs.begin()->second);
+        }
+        return scope;
+      }
+    }
+    return "global";
+  }
+
+  void ApplyConsumerScopeToInputs(const ExprNode* expr) {
+    std::string scope;
+    auto consumer_scopes_it = consumer_storage_scopes_.find(expr);
+    if (consumer_scopes_it != consumer_storage_scopes_.end()) {
+      std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second);
+      ICHECK(!storage_scope_.count(expr))
+          << "Already propagated consumer scopes to input: " << GetRef<Expr>(expr);
+
+      bool expr_is_rgba_vectorizable = false;
+      if (const auto* ttype = expr->checked_type().as<TensorTypeNode>()) {
+        scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(expr)));
+        if (scope != "global") {
+          auto inner_dim = ttype->shape.back().as<IntImmNode>();
+          if (inner_dim && inner_dim->value == 4) {
+            expr_is_rgba_vectorizable = true;
+          }
+        }
+      }
+
+      // Only propagate texture scope from consumers to input expr if
+      // the input shape of the input expr is rgba vectorizable.
+      if (consumer_scope.find("global.texture") != std::string::npos) {
+        if (expr_is_rgba_vectorizable) {
+          storage_scope_[expr].push_back(scope);
+        }
+      } else {
+        storage_scope_[expr].push_back(consumer_scope);
+      }
+    }
+  }
+
+  void LegalizeProducerStorage() {
+    for (auto& kv : consumer_storage_scopes_) {
+      const ExprNode* producer = kv.first;
+      std::string legal_scope = GetConsumerScope(kv.second);
+      if (storage_scope_.count(producer)) {
+        ICHECK(!HasMixedStorageOutputs(producer))
+            << "Mixed output storage scopes are not currently supported";
+        if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) {
+          for (size_t i = 0; i < storage_scope_[producer].size(); i++) {
+            // Only support uniform storage scope across all outputs for now
+            storage_scope_[producer][i] = legal_scope;
+          }
+        }
+      }
+    }
+  }
+
+  std::string GetConsumerScope(const std::vector<std::string>& consumer_scopes) const {
+    if (!consumer_scopes.size()) {
+      return "global";
+    }
+    std::string texture_tag = "global.texture";
+    for (auto& consumer_scope : consumer_scopes) {
+      if (consumer_scope.find(texture_tag) == std::string::npos) {
+        return "global";
+      }
+    }
+    return texture_tag;
+  }
+
+  bool CanConsumeTextures(const std::vector<std::string>& consumer_scopes) const {
+    std::string texture_tag = "global.texture";
+    for (auto& consumer_scope : consumer_scopes) {
+      if (consumer_scope.find(texture_tag) == 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool HasMixedStorageOutputs(const ExprNode* expr) {
+    if (storage_scope_.count(expr)) {
+      std::string ref_scope = storage_scope_[expr][0];
+      for (std::string& scope : storage_scope_[expr]) {
+        if (scope != ref_scope) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  bool SupportsTextureStorage(const CallNode* call) const {
+    bool supports_texture_storage = false;
+    if (auto attrs = call->attrs.as<Conv2DAttrs>()) {
+      if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "OIHW4o") {
+        supports_texture_storage = true;
+      } else if (attrs->data_layout == "NHWC4c" &&
+                 (attrs->kernel_layout == "HWOI4o" || attrs->kernel_layout == "HWIO4o" ||
+                  attrs->kernel_layout == "OIHW4o")) {
+        supports_texture_storage = true;
+      }
+    } else if (auto attrs = call->attrs.as<Conv2DWinogradAttrs>()) {
+      if ((attrs->data_layout == "NCHW4c" || attrs->data_layout == "NHWC4c") &&
+          (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) {
+        supports_texture_storage = true;
+      }
+    } else if (auto attrs = call->attrs.as<GlobalPool2DAttrs>()) {
+      if (attrs->layout == "NCHW4c") {
+        supports_texture_storage = true;
+      }
+    } else if (auto attrs = call->attrs.as<MaxPool2DAttrs>()) {
+      if (attrs->layout == "NCHW4c") {
+        supports_texture_storage = true;
+      }
+    } else if (auto attrs = call->attrs.as<AvgPool2DAttrs>()) {
+      if (attrs->layout == "NCHW4c") {
+        supports_texture_storage = true;
+      }
+    }
+
+    return supports_texture_storage;
+  }
+
+  /*! \brief Temporary state for marking whether a visited function
+   *         primitive supports texture storage scope */
+  bool primitive_supports_texture_ = false;
+  /*! \brief expr storage scope mapping for each output  */
+  std::unordered_map<const ExprNode*, std::vector<std::string>> storage_scope_;
+  /*! \brief output storage scopes used by consumers of expr key  */
+  std::unordered_map<const ExprNode*, std::vector<std::string>> consumer_storage_scopes_;
+  /*! \brief mapping of arguments to call to function variables*/
+  std::unordered_map<Expr, std::vector<Var>, ObjectPtrHash, ObjectPtrEqual> args_to_vars_;
+};
+
+}  // namespace
+
+/**
+ * @brief Rewrites the memory_scope part of the VirtualDevice for expressions
+ * identified by the StorageInfo analysis pass
+ *
+ * Currently this workflow supports analysis and rewriting of VirtualDevice for
+ * Constants and function Variables
+ */
+class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator {
+  using VarMap = std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ public:
+  explicit RewriteVDStorageScopes(const Map<Expr, Array<String>>& storage_scope)
+      : transform::DeviceAwareExprMutator(Optional<IRModule>()), storage_scope_(storage_scope) {}
+
+  Function Rewrite(const Expr& expr) { return Downcast<Function>(Mutate(expr)); }
+
+  Expr VisitExpr_(const VarNode* vn) final {
+    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
+        storage_scope_[GetRef<Expr>(vn)][0] != "global") {
+      Var c = Var(vn->vid, vn->type_annotation, vn->span);
+      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
+      c->virtual_device_ =
+          VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+                        virtual_device->target, storage_scope_[GetRef<Expr>(vn)][0]);
+      return c;
+    }
+    return GetRef<Var>(vn);
+  }
+
+  Expr VisitExpr_(const ConstantNode* vn) final {
+    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end()) {
+      Expr c = Constant(vn->data, vn->span);
+      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
+      c = OnDevice(c,
+                   VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+                                 virtual_device->target, storage_scope_[GetRef<Expr>(vn)][0]),
+                   true);
+      return c;
+    }
+    return GetRef<Constant>(vn);
+  }
+
+  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    auto new_call = transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+    auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
+    std::string memory_scope = "";
+    if (storage_scope_.find(GetRef<Expr>(call_node)) != storage_scope_.end()) {
+      memory_scope = storage_scope_[GetRef<Expr>(call_node)][0];
+    } else if (virtual_device->memory_scope != "") {
+      memory_scope = virtual_device->memory_scope;
+    } else if (!call_node->op.as<FunctionNode>()) {
+      memory_scope = "";
+    }
+    if (!memory_scope.empty()) {
+      new_call =
+          OnDevice(new_call,
+                   VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
+                                 virtual_device->target, memory_scope),
+                   true);
+    }
+    return new_call;
+  }
+
+ private:
+  Map<Expr, Array<String>> storage_scope_;
+  VarMap new_vars_;
+  Array<String> current_function_scope_;
+};
+
+Map<Expr, Array<String>> CollectTextureStorage(const Expr& expr) {
+  return StorageInfo::GetStorageMap(expr);
+}
+
+/**
+ * @brief Collects all target devices participating in the graph
+ */
+class CollectVirtualDevices : public transform::DeviceAwareExprVisitor {
+ public:
+  CollectVirtualDevices() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+  /**
+   * @brief Get all unique devices from the target of each VirtualDevice
+   *
+   * @param expr - IR
+   * @return set of devices
+   */
+  std::set<std::string> GetDevices(const Expr& expr) {
+    this->Run(expr);
+    return std::move(devices_);
+  }
+
+  void Visit(const Expr& expr) {
+    // Pre-order traversal to enable upward propagation
+    // of consumer storage scopes to producers when desirable.
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      this->VisitExpr(fn->body);
+      for (const auto& param : fn->params) {
+        this->VisitExpr(param);
+      }
+    } else {
+      this->VisitExpr(expr);
+    }
+  }
+
+  void DeviceAwareVisitExpr_(const CallNode* call) final {
+    auto vd = GetVirtualDevice(GetRef<Expr>(call));
+    if (vd != VirtualDevice::FullyUnconstrained()) {
+      if (Optional<String> t_device = vd->target->GetAttr<String>("device")) {
+        devices_.insert(vd->target->kind->name + "." + t_device.value());
+      }
+    }
+    for (auto& arg : call->args) {
+      Visit(arg);
+    }
+  }
+
+  void Run(const Expr& expr) { VisitExpr(expr); }
+  using transform::DeviceAwareExprVisitor::VisitExpr_;
+  std::set<std::string> devices_;
+};
+
+/*!
+ * \brief Collect the target specific tensor storage info for each expression's output.
+ * \param expr The expression.
+ * \return The device based storage mapping.
+ */
+Map<Expr, Array<String>> CollectStorageInfo(const Expr& expr) {
+  std::set<std::string> device_types = CollectVirtualDevices().GetDevices(expr);
+  // TODO(amalyshe): the current approach collects all targets within the graph and calls the
+  // single function corresponding to all of these targets in alphabetical order.
+  // This works reliably only for the single-device case and should be redesigned
+  // to handle the common case
+  std::string ftarget_prefix = "relay.backend";
+  for (auto& dev_id : device_types) {
+    ftarget_prefix += (std::string(".") + dev_id);
+  }
+
+  Map<Expr, Array<String>> storage_info = {};
+  if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) {
+    storage_info = (*f)(expr);
+  }
+  return storage_info;
+}
+
+Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod, CompilationConfig config) {
+  auto storage_scope = CollectStorageInfo(expr);
+  if (storage_scope.size()) {
+    return RewriteVDStorageScopes(storage_scope).Rewrite(expr);
+  } else {
+    return expr;
+  }
+}
+
+namespace transform {
+tvm::transform::Pass AnnotateMemoryScope(CompilationConfig config) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [config = std::move(config)](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(AnnotateMemoryScopeExpr(f, m, config));
+      };
+  return CreateFunctionPass(pass_func, 2, "AnnotateMemoryScope", {});
+}
+}  // namespace transform
+
+TVM_REGISTER_GLOBAL("relay.backend.opencl.adreno._CollectStorageInfo")
+    .set_body_typed(CollectTextureStorage);
+
+}  // namespace relay
+}  // namespace tvm
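
To make the balanced-packing rule in Scope() concrete, here is a hypothetical Python re-implementation with one worked shape (the 16384 default mirrors the texture_spatial_limit fallback above):

    def scope_for(shape, limit=16384):
        a0, a1, a2, a3, a4 = shape
        if a4 != 4:
            return "global"
        diffs = {}
        # candidate 2d flattenings: (rows, cols, layout tag)
        for left, right, tag in (
            (a0 * a1 * a2, a3, ""),
            (a0 * a1, a2 * a3, "nhwc"),
            (a0, a1 * a2 * a3, "weight"),
        ):
            if left < limit and right < limit:
                diffs[abs(left - right)] = tag
        if not diffs:
            return "global"
        tag = diffs[min(diffs)]  # most balanced candidate wins
        return "global.texture" + ("-" + tag if tag else "")

    # For (1, 32, 40, 40, 4): 1280x40 (diff 1240) beats 32x1600 (diff 1568),
    # and 1x51200 exceeds the limit, so the result is "global.texture".
    print(scope_for((1, 32, 40, 40, 4)))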
diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py
index 2dd88f6118..58590998fd 100644
--- a/tests/python/relay/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/test_conv2d_nchw_texture.py
@@ -63,7 +63,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -105,7 +105,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -147,7 +147,7 @@ def test_conv2d_inceptionv3_35_35_strides():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -493,3 +493,566 @@ def test_conv2d_winograd_conv():
     )
     matches = re.findall("winograd", graph)
     assert len(matches) > 0
+
+
+@tvm.testing.requires_opencl
+def test_residual_block():
+    """
+    - a residual block followed by a convolution, so that a texture appears after the block
+    - verification that a scalar operand is mapped to the global memory scope
+        layout_transform (NCHW->NCHW4c)
+                  |                      <- buffer
+                conv2d (1)                  <- to get textures as output
+               /         \
+            conv2d (2)    |
+                 \       /
+                    add                     <- add should be fused into conv2d (2)
+                multiply by scalar          <- the scalar operand is a buffer input
+                    relu
+                     |                      <- texture in intermediate tensor
+                  conv2d (3)
+                   relu
+                     |                      <- buffer
+               layout_transform (NCHW4c->NCHW)
+    """
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape1 = (32, 32, 2, 2)
+    filter_shape2 = (32, 32, 1, 1)
+    filter_shape3 = (32, 32, 2, 2)
+    bias_shape1 = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
+
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv1, B1)
+    D = relay.op.nn.relu(D)
+
+    conv2 = relay.nn.conv2d(
+        D,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+    D = relay.op.add(conv2, D)
+    D = D * relay.const(0.15, "float16")
+    D = relay.op.nn.relu(D)
+
+    conv3 = relay.nn.conv2d(
+        D,
+        W3,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.nn.relu(conv3)
+
+    mod = relay.Function([A, W1, B1, W2, W3], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    bias_data1 = np.zeros(bias_shape1).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("bias", bias_data1)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    initializer("weight", filter_data2)
+    filter_data3 = np.zeros(filter_shape3).astype(dtype)
+    initializer("weight", filter_data3)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+        "bias1": tvm.nd.array(bias_data1),
+        "weight2": tvm.nd.array(filter_data2),
+        "weight3": tvm.nd.array(filter_data3),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture-weight",
+        "global",
+        "global.texture",
+        "global.texture-weight",
+        "",
+        "",
+    ]
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+@tvm.testing.requires_opencl
+def test_concat():
+    """
+        layout_transform (NCHW->NCHW4c)
+                  |                      <- buffer
+                conv2d (1)               <- to get textures as output
+               /         \
+            conv2d (2)    conv2d (3)
+                 \       /               <- concat does not support textures, so here we should have buffers
+                concatenation
+                     |                   <- buffer
+               layout_transform (NCHW4c->NCHW)
+    """
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape1 = (96, 32, 2, 2)
+    filter_shape2 = (32, 96, 2, 2)
+    filter_shape3 = (5, 96, 2, 2)
+    bias_shape1 = (1, 96, 1, 1)
+    bias_shape2 = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
+    B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype)
+
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv1, B1)
+    D = relay.op.nn.relu(D)
+
+    conv2 = relay.nn.conv2d(
+        D,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    conv2 = relay.op.add(conv2, B2)
+    conv2 = relay.op.nn.relu(conv2)
+
+    conv3 = relay.nn.conv2d(
+        D,
+        W3,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=5,
+        kernel_size=(2, 2),
+    )
+
+    t = relay.Tuple([conv2, conv3])
+    c = relay.op.concatenate(t, axis=1)
+
+    mod = relay.Function([A, W1, B1, W2, B2, W3], c)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    bias_data1 = np.zeros(bias_shape1).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("bias", bias_data1)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    bias_data2 = np.zeros(bias_shape2).astype(dtype)
+    initializer("weight", filter_data2)
+    initializer("bias", bias_data2)
+    filter_data3 = np.zeros(filter_shape3).astype(dtype)
+    initializer("weight", filter_data3)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+        "bias1": tvm.nd.array(bias_data1),
+        "weight2": tvm.nd.array(filter_data2),
+        "bias2": tvm.nd.array(bias_data2),
+        "weight3": tvm.nd.array(filter_data3),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global",
+        "global.texture-weight",
+        "global.texture-weight",
+        "",
+        "",
+        "",
+        "",
+        "",
+    ]
+
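+    # NB: an empty list disables the storage_scope check in build_run_compare;
+    # the expected scopes above are kept for documentation only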
+    static_memory_scope = []
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+@tvm.testing.requires_opencl
+def test_pooling_branching_texture_params():
+    """
+    Verification of pooling together with multiple branches that consume textures
+                layout_transform (NCHW->NCHW4c)
+                         |                        <- buffer
+                      conv2d (0)                  <- to get textures
+                         |                        <- textures
+                     pooling
+               /           \           \          <- textures
+            conv2d (1)    conv2d (2)    conv2d (3)
+                \             /           |
+                     add                  |       <- fused so that there is a single output
+                      \                  /
+                            add                  <- fused so that there is a single output
+                             |                   <- buffer
+                    layout_transform (NCHW4c->NCHW)
+    """
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape0 = (32, 32, 1, 1)
+    filter_shape1 = (32, 32, 2, 2)
+    filter_shape2 = (32, 32, 1, 1)
+    filter_shape3 = (32, 32, 2, 2)
+    bias_shape1 = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W0 = relay.var("weight0", shape=filter_shape0, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
+
+    conv0 = relay.nn.conv2d(
+        A,
+        W0,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    pool = relay.nn.avg_pool2d(conv0, pool_size=(2, 2), strides=(2, 2))
+    conv1 = relay.nn.conv2d(
+        pool,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    conv1 = relay.op.add(conv1, B1)
+    conv1 = relay.op.nn.relu(conv1)
+
+    conv2 = relay.nn.conv2d(
+        pool,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    conv3 = relay.nn.conv2d(
+        pool,
+        W3,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 1, 1, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    conv3 = relay.op.nn.relu(conv3)
+    res = relay.op.add(conv1, conv2)
+    res = relay.op.add(res, conv3)
+
+    mod = relay.Function([A, W0, W1, B1, W2, W3], res)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data0 = np.zeros(filter_shape0).astype(dtype)
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    bias_data1 = np.zeros(bias_shape1).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("bias", bias_data1)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    initializer("weight", filter_data2)
+    filter_data3 = np.zeros(filter_shape3).astype(dtype)
+    initializer("weight", filter_data3)
+    params1 = {
+        "weight0": tvm.nd.array(filter_data0),
+        "weight1": tvm.nd.array(filter_data1),
+        "bias1": tvm.nd.array(bias_data1),
+        "weight2": tvm.nd.array(filter_data2),
+        "weight3": tvm.nd.array(filter_data3),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture-weight",
+        "global.texture",
+        "",
+        "",
+    ]
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+@tvm.testing.requires_opencl
+def test_branching_texture_params():
+    """
+    Verification of passing a texture to several consumers, and of the markup of relay
+    variables in primitive functions + on_device
+
+                layout_transform (NCHW->NCHW4c)
+                         |                      <- buffer
+                      conv2d (0)                <- to get textures
+             /           \           \          <- here should be textures and textures in params
+          conv2d (1)    conv2d (2)    conv2d (3)
+            \             /           |
+                  add                 |         <- to keep a single output
+                    \                /
+                           add                  <- to keep a single output
+                            |                   <- buffer
+                    layout_transform (NCHW4c->NCHW)
+    """
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape0 = (32, 32, 1, 1)
+    filter_shape1 = (32, 32, 2, 2)
+    filter_shape2 = (32, 32, 1, 1)
+    filter_shape3 = (32, 32, 2, 2)
+    bias_shape1 = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W0 = relay.var("weight0", shape=filter_shape0, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
+
+    conv0 = relay.nn.conv2d(
+        A,
+        W0,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    conv1 = relay.nn.conv2d(
+        conv0,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    conv1 = relay.op.add(conv1, B1)
+    conv1 = relay.op.nn.relu(conv1)
+
+    conv2 = relay.nn.conv2d(
+        conv0,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    conv3 = relay.nn.conv2d(
+        conv0,
+        W3,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 1, 1, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(2, 2),
+    )
+    conv3 = relay.op.nn.relu(conv3)
+    res = relay.op.add(conv1, conv2)
+    res = relay.op.add(res, conv3)
+
+    mod = relay.Function([A, W0, W1, B1, W2, W3], res)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data0 = np.zeros(filter_shape0).astype(dtype)
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    bias_data1 = np.zeros(bias_shape1).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("bias", bias_data1)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    initializer("weight", filter_data2)
+    filter_data3 = np.zeros(filter_shape3).astype(dtype)
+    initializer("weight", filter_data3)
+    params1 = {
+        "weight0": tvm.nd.array(filter_data0),
+        "weight1": tvm.nd.array(filter_data1),
+        "bias1": tvm.nd.array(bias_data1),
+        "weight2": tvm.nd.array(filter_data2),
+        "weight3": tvm.nd.array(filter_data3),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture-weight",
+        "global.texture",
+        "",
+        "",
+    ]
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+# the same function is reused; param scopes differ between the reused instances
+@tvm.testing.requires_opencl
+def test_conv2d_different_lowering_same_op():
+    """
+    Use case for verification of caching of compiled functions
+    The three convolutions following each other should be compiled into three
+    different entities and lowered differently because they differ in the memory
+    scopes of their input params and of their outputs
+
+                layout_transform (NCHW->NCHW4c)
+                         |                      <- buffer
+                      conv2d (1)                <- buffer as input tensor and texture as output
+                         |                      <- texture
+                      conv2d (2)                <- texture as input and texture as output
+                         |                      <- texture
+                      conv2d (3)                <- texture as input and buffer as output
+                         |                      <- buffer
+                    layout_transform (NCHW4c->NCHW)
+    """
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape1 = (32, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    conv2 = relay.nn.conv2d(
+        conv1,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    conv3 = relay.nn.conv2d(
+        conv2,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(1, 1),
+    )
+
+    mod = relay.Function([A, W1], conv3)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+    }
+
+    static_memory_scope = [
+        "",
+        "global",
+        "global.texture-weight",
+        "global.texture",
+        "global.texture",
+        "",
+        "",
+    ]
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
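
The static_memory_scope lists in these tests mirror the "storage_scope" attribute of the graph JSON produced by relay.build; empty strings mark storages with no special scope planned. A hedged sketch of inspecting that attribute directly (assuming lib is the factory module returned by relay.build for the Adreno target):

    import json

    graph_json = json.loads(lib.get_graph_json())
    if "storage_scope" in graph_json["attrs"]:
        # stored as ["list_str", [scope_0, scope_1, ...]], one entry per storage id
        print(graph_json["attrs"]["storage_scope"][1])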
diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py
index b6c54e8daa..be5cefd460 100644
--- a/tests/python/relay/test_conv2d_nhwc_texture.py
+++ b/tests/python/relay/test_conv2d_nhwc_texture.py
@@ -225,7 +225,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -267,7 +267,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -309,7 +309,7 @@ def test_conv2d_inceptionv3_35_35_strides():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -493,7 +493,6 @@ def test_conv2d_4x4x4_16c16pad():
     B = relay.var("weight", shape=filter_shape, dtype=dtype)
     bias = relay.var("bias", shape=bias_shape, dtype=dtype)
 
-    # C = relay.nn.relu(A)
     conv = relay.nn.conv2d(
         A,
         B,
diff --git a/tests/python/relay/test_depthwise_conv2d_nchw_texture.py b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py
index 71cf62c5d8..c94d085b51 100644
--- a/tests/python/relay/test_depthwise_conv2d_nchw_texture.py
+++ b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py
@@ -64,7 +64,7 @@ def test_depthwise_conv2d_bias_nchwc():
         "bias": tvm.nd.array(bias_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
@@ -103,7 +103,7 @@ def test_depthwise_conv2d_nchwc():
         "weight": tvm.nd.array(filter_data),
     }
 
-    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)
 
 
 @tvm.testing.requires_opencl
diff --git a/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py
index 16d26c77ca..16f9b87499 100644
--- a/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py
+++ b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py
@@ -20,7 +20,7 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay import testing
-from utils.adreno_utils import gpu_preprocess, build_run_compare
+from utils.adreno_utils import build_run_compare
 
 
 @tvm.testing.requires_opencl
diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py
index 6e353b22cd..27768c3d0c 100644
--- a/tests/python/relay/utils/adreno_utils.py
+++ b/tests/python/relay/utils/adreno_utils.py
@@ -24,6 +24,7 @@ from tvm import autotvm
 from tvm.relay import testing
 from tvm.relay.transform import recast
 from tvm.contrib import graph_runtime
+import json
 
 
 def get_cpu_reference(mod, params1, input_shape, inputs):
@@ -51,6 +52,7 @@ def build_run_compare(
     input_shape,
     dtype="float32",
     target="llvm",
+    static_mem_scopes=[],
     gpu_preprocess=None,
     stat_file=None,
 ):
@@ -82,6 +84,19 @@ def build_run_compare(
                 tvm_mod_nchwc, target_host=target_host, target=target, params=params1
             )
 
+    # verification that storage_scope contains the expected texture scopes
+    graph_json = json.loads(graph)
+    if "storage_scope" in graph_json["attrs"]:
+        assert (
+            len(static_mem_scopes) == len(graph_json["attrs"]["storage_scope"][1])
+            or len(static_mem_scopes) == 0
+        )
+    else:
+        assert len(static_mem_scopes) == 0
+
+    for i in range(0, len(static_mem_scopes)):
+        assert static_mem_scopes[i] == graph_json["attrs"]["storage_scope"][1][i]
+
     if run_on_host:
         ctx = tvm.opencl()
         m = graph_runtime.create(graph, lib, ctx)
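
Since static_mem_scopes was added as a positional parameter before gpu_preprocess, existing callers pass [] to keep their behavior (hence the test updates above), while new tests pass an expected list. A short sketch of both modes; the scope values shown are illustrative, not taken from a real build:

    # empty list: skip the storage_scope verification entirely
    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess)

    # non-empty list: assert the graph's storage scopes entry by entry
    expected_scopes = ["", "global", "global.texture-weight", "global.texture", "", ""]
    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, expected_scopes)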