You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/06 09:37:15 UTC

[GitHub] [tvm] manupa-arm commented on a diff in pull request #10189: [USMP] Adding support for U1 usecase for constant pools

manupa-arm commented on code in PR #10189:
URL: https://github.com/apache/tvm/pull/10189#discussion_r843703446


##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool

Review Comment:
   nit: let's stick to ConstantPools / ConstantPoolInfo



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;
+  for (const auto& kv : this->buffer_info_map_) {
+    const auto& buf = kv.first;
+    for (const auto& pool : buf->pool_candidates) {
+      if (pool->IsInstance<ConstantPoolInfoNode>()) {
+        buffer_info_cons.push_back(buf);
+      } else {
+        buffer_info_vars.push_back(buf);
+      }
+
+      break;
+    }
+  }
+  ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_cons.size())
+      << "missing value";
+
+  // intersect with each other, as all constants should exist at the same time
+  for (const auto& buf1 : buffer_info_cons) {
+    for (const auto& buf2 : buffer_info_cons) {
+      if (buf1->conflicts.end() ==
+          std::find(buf1->conflicts.begin(), buf1->conflicts.end(), buf2)) {
+        buf1->conflicts.push_back(buf2);
+        if (buf2->conflicts.end() ==
+            std::find(buf2->conflicts.begin(), buf2->conflicts.end(), buf1)) {
+          buf2->conflicts.push_back(buf1);
+        }
+      }
+    }
+    // remove conflicts with vars part as non-relevant anymore
+    for (const auto& buf2 : buffer_info_vars) {
+      const auto& it = std::find(buf1->conflicts.begin(), buf1->conflicts.end(), buf2);

Review Comment:
   This is expensive -- we need to come up with a way to remove this find(..)



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;
+  for (const auto& kv : this->buffer_info_map_) {
+    const auto& buf = kv.first;
+    for (const auto& pool : buf->pool_candidates) {
+      if (pool->IsInstance<ConstantPoolInfoNode>()) {
+        buffer_info_cons.push_back(buf);
+      } else {
+        buffer_info_vars.push_back(buf);
+      }
+
+      break;
+    }
+  }
+  ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_cons.size())
+      << "missing value";
+
+  // intersect with each other, as all constants should exist at the same time
+  for (const auto& buf1 : buffer_info_cons) {
+    for (const auto& buf2 : buffer_info_cons) {
+      if (buf1->conflicts.end() ==
+          std::find(buf1->conflicts.begin(), buf1->conflicts.end(), buf2)) {
+        buf1->conflicts.push_back(buf2);
+        if (buf2->conflicts.end() ==
+            std::find(buf2->conflicts.begin(), buf2->conflicts.end(), buf1)) {
+          buf2->conflicts.push_back(buf1);
+        }
+      }
+    }
+    // remove conflicts with vars part as non-relevant anymore
+    for (const auto& buf2 : buffer_info_vars) {
+      const auto& it = std::find(buf1->conflicts.begin(), buf1->conflicts.end(), buf2);
+      if (buf1->conflicts.end() != it) {
+        buf1->conflicts.erase(it);
+        const auto& it = std::find(buf2->conflicts.begin(), buf2->conflicts.end(), buf1);

Review Comment:
   Same here.



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;
+  for (const auto& kv : this->buffer_info_map_) {
+    const auto& buf = kv.first;
+    for (const auto& pool : buf->pool_candidates) {
+      if (pool->IsInstance<ConstantPoolInfoNode>()) {
+        buffer_info_cons.push_back(buf);
+      } else {
+        buffer_info_vars.push_back(buf);
+      }
+
+      break;
+    }
+  }
+  ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_cons.size())
+      << "missing value";
+
+  // intersect with each other, as all constants should exist at the same time
+  for (const auto& buf1 : buffer_info_cons) {
+    for (const auto& buf2 : buffer_info_cons) {
+      if (buf1->conflicts.end() ==
+          std::find(buf1->conflicts.begin(), buf1->conflicts.end(), buf2)) {
+        buf1->conflicts.push_back(buf2);
+        if (buf2->conflicts.end() ==
+            std::find(buf2->conflicts.begin(), buf2->conflicts.end(), buf1)) {
+          buf2->conflicts.push_back(buf1);
+        }
+      }
+    }
+    // remove conflicts with vars part as non-relevant anymore

Review Comment:
   Is this necessary? I understand that it is no longer relevant, because constants will now only be placed in a different pool and therefore liveness across pools does not matter. At the same time, will anything break if every buffer is kept as a conflict of the constant nodes?



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;
+  for (const auto& kv : this->buffer_info_map_) {
+    const auto& buf = kv.first;
+    for (const auto& pool : buf->pool_candidates) {

Review Comment:
   Can we not just check the node type of kv.second to determine this instead?
   That would simplify the code here.



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;
+  for (const auto& kv : this->buffer_info_map_) {
+    const auto& buf = kv.first;
+    for (const auto& pool : buf->pool_candidates) {
+      if (pool->IsInstance<ConstantPoolInfoNode>()) {
+        buffer_info_cons.push_back(buf);
+      } else {
+        buffer_info_vars.push_back(buf);
+      }
+
+      break;
+    }
+  }
+  ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_cons.size())
+      << "missing value";
+
+  // intersect with each other, as all constants should exist at the same time
+  for (const auto& buf1 : buffer_info_cons) {

Review Comment:
   We can avoid these expensive loop nests (with .find() on arrays) by simply replacing the conflicts array of each buffer node with the entire list of buffer nodes. Considering my other comment on line 598, I think nothing breaks if all the buffer info objects are conflicts of the constant buffer info nodes (and that makes sense in theory as well).



##########
tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py:
##########
@@ -1365,9 +1362,9 @@ def run_model(data: T.handle, output: T.handle) -> None:
 
 def test_multiple_calls_to_same_primfunc():

Review Comment:
   We need a unit test here to check the feature newly added to the extract buffer info pass.



##########
src/tir/usmp/analysis/extract_buffer_info.cc:
##########
@@ -491,6 +559,55 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
       open_set.erase(le_event.buffer_info);
     }
   }
+
+  // All items in RO pool should have conflicts with each other in this RO pool
+  // as they will be placed in RO segment and pre-initialized
+  // Array<BufferInfo> buffer_info_arr =
+  //    ConvertToArrayOfBufferInfo(this->buffer_info_map_);
+
+  // split buffers to vars (RW) and constants (RO)
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_cons;

Review Comment:
   nit: maybe buffer_info_constants is a better name?



##########
src/target/source/source_module.cc:
##########
@@ -236,20 +241,72 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     return reference_arg + "_tvm_value";
   }
 
-  void GenerateInternalWorkspaceBuffers() {
+  void GenerateInternalBuffers() {
     if (metadata_->pool_inputs.defined()) {
       for (const auto& kv : metadata_->pool_inputs.value()) {
         tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second;
         if (allocated_pool_info->pool_info->is_internal) {
-          code_ << "__attribute__((section(\".data.tvm\"), ";
-          code_ << "aligned(" << 16 << ")))\n";
-          code_ << "static uint8_t " << allocated_pool_info->pool_info->pool_name << "["
-                << allocated_pool_info->allocated_size->value << "];\n";
+          if (const auto* pool_info = allocated_pool_info->pool_info.as<ConstantPoolInfoNode>()) {
+            GenerateConstantBuffer(pool_info, allocated_pool_info->allocated_size->value);
+          } else {
+            GenerateWorkspaceBuffer(allocated_pool_info->pool_info.as<WorkspacePoolInfoNode>(),
+                                    allocated_pool_info->allocated_size->value);
+          }
         }
       }
     }
   }
 
+  void GenerateConstantBuffer(const ConstantPoolInfoNode* pool_info, size_t allocated_size) {
+    size_t offset = 0;
+    if (pool_info->constant_info_array.size() > 0) {
+      // Pool is RO, form an initialized struct
+      code_ << "__attribute__((section(\".rodata.tvm\"), ";
+      code_ << "))\n";
+      code_ << "static struct " << pool_info->pool_name << " {\n";
+      // emit struct field names
+      std::vector<ConstantInfo> const_info_vec(pool_info->constant_info_array.begin(),
+                                               pool_info->constant_info_array.end());
+      std::sort(const_info_vec.begin(), const_info_vec.end(),
+                [](const ConstantInfo& a, const ConstantInfo& b) {
+                  return a->byte_offset->value < b->byte_offset->value;
+                });
+      for (const auto& const_info : const_info_vec) {
+        const auto& data = const_info->data;
+        const auto& offs = const_info->byte_offset;
+        int64_t num_elements = std::accumulate(data.Shape().begin(), data.Shape().end(), 1,
+                                               std::multiplies<int64_t>());
+        code_ << "  ";
+        codegen_c_base_.PrintType(data.DataType(), code_);
+        code_ << " " << const_info->name_hint << "[" << num_elements
+              << "] __attribute__((packed, aligned(" << metadata_->workspace_byte_alignment

Review Comment:
   This needs to be addressed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org