You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/07 03:27:03 UTC

[GitHub] [incubator-tvm] adityaatluri commented on a change in pull request #4234: Auto TensorCore CodeGen

adityaatluri commented on a change in pull request #4234: Auto TensorCore CodeGen
URL: https://github.com/apache/incubator-tvm/pull/4234#discussion_r343462536
 
 

 ##########
 File path: src/pass/tensor_core.cc
 ##########
 @@ -0,0 +1,1247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tensor_core.cc
+ */
+// IR Passes for TensorCore CodeGen
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/operation.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/expr_operator.h>
+#include <tvm/ir_pass.h>
+#include <tvm/buffer.h>
+#include <tvm/target_info.h>
+#include <tvm/build_module.h>
+#include <tvm/runtime/device_api.h>
+#include <unordered_map>
+#include "ir_util.h"
+#include "../arithmetic/compute_expr.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+using runtime::StorageRank;
+using runtime::StorageScope;
+using runtime::ThreadScope;
+using intrinsic::tvm_address_of;
+
/*!
 * \brief Three integer extents named m, n and k.
 *
 * Every extent defaults to -1, which presumably means "not yet determined"
 * (NOTE(review): the consumers of this struct are not visible here — confirm).
 * Given the TensorCore context, these are likely the M/N/K dimensions of an
 * MMA tile; verify against the code that fills them in.
 */
struct Tile {
  int m = -1;
  int n = -1;
  int k = -1;
};
+
/*!
 * \brief Return the prefix of a name up to, but not including, its first '.'.
 *
 * Examples: "A.local" -> "A", "a.b.c" -> "a". A name without any '.' is
 * returned unchanged, and a name that starts with '.' yields "".
 *
 * \param input The name to shorten. It is only read, so it is taken by const
 *        reference to avoid copying it.
 * \return A new string holding the portion of input before the first '.'.
 */
std::string simplify_name(const std::string& input) {
  // When no '.' exists, find() returns npos. substr() clamps its count to the
  // remaining length, so substr(0, npos) yields the whole string and needs no
  // separate branch.
  return input.substr(0, input.find('.'));
}
+
+// MMAMatcher matches C = Cast(A)*Cast(B)+C,
+// where A & B are fp16/int8 local buffers,
+// and C is fp32/int32 local buffer.
+class MMAMatcher: public IRVisitor {
+ public:
+  explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer,
+                      double cuda_compute_capability, double cuda_version) {
+    for (auto kv : extern_buffer) {
+      BufferInfo bi;
+      bi.name = kv.second->name;
+      bi.dtype = kv.second->dtype;
+      bi.external = true;
+      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi;
+    }
+    // Int wmma is supported when cuda version >= 10.0 && cuda arch >= 720
+    if (cuda_compute_capability >= 7.20 && cuda_version >= 10.0) {
+      support_int_wmma_ = true;
+    }
+  }
+  using IRVisitor::Visit_;
+
+  void Visit_(const AttrStmt* op) final {
+    if (op->attr_key == attr::pragma_tensor_core) {
+      tensor_core_on_ = true;
+      IRVisitor::Visit_(op);
+    } else if (op->attr_key == attr::realize_scope) {
+      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      Visit(op->body);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const Provide* op) final {
+    IRVisitor::Visit_(op);
+    auto it = buf_map_.find(TensorKey{op->func, op->value_index});
+    if (it == buf_map_.end()) {
+      return;
+    }
+    const BufferInfo& bi = it->second;
+    if (bi.released) {
+      return;
+    }
+    if (tensor_core_on_ && mma_sync_match_(op, bi)) {
+      matched_ = true;
+    }
+  }
+
+  void Visit_(const Realize* op) final {
+    TensorKey key{op->func, op->value_index};
+    if (buf_map_.count(key)) {
+      if (!buf_map_.at(key).external) {
+        return;
+      }
+      Visit(op->body);
+    } else {
+      BufferInfo bi;
+      bi.name = key.GetName();
+      bi.dtype = op->type;
+      buf_map_[key] = bi;
+      Visit(op->body);
+      buf_map_[key].released = true;
+    }
+  }
+
+  inline bool Matched() const {return matched_;}
+
+  friend class ScheduleAnalyser;
+  friend class BufferAnalyser;
+
+ private:
+  struct BufferInfo {
+    std::string name;
+    Type dtype;
+    bool external{false};
+    bool released{false};
+    bool same_as(const BufferInfo &bi) {
+      if (this->dtype != bi.dtype) return false;
+      if (this->name != bi.name) return false;
+      if (this->external != bi.external) return false;
+      if (this->released != bi.released) return false;
+      return true;
+    }
+  };
+
+  // Check whether the storage scope is local
+  bool check_local_buffer_(const Call* op, BufferInfo* bi) {
+    if (op->call_type == Call::Halide) {
+      auto it = storage_scope_.find(op->func.get());
+      if (it == storage_scope_.end()) {
+        return false;
+      }
+      const std::string& strkey = it->second;
+      if (strkey != "local") {
+        return false;
+      }
+      auto it1 = buf_map_.find(TensorKey{op->func, op->value_index});
+      if (it1 == buf_map_.end()) {
+        return false;
+      }
+      *bi = it1->second;
+      if (bi->released) {
+        return false;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Do the pattern matching
+  bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) {
 
 Review comment:
   Dumb question: does this function look for wmma and replace it with mma.sync?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services