You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/06 03:33:14 UTC

[GitHub] [tvm] zhuwenxi commented on a diff in pull request #11326: [LIBXSMM][BYOC] Integrate Libxsmm with BYOC

zhuwenxi commented on code in PR #11326:
URL: https://github.com/apache/tvm/pull/11326#discussion_r963222218


##########
src/runtime/contrib/libxsmm/libxsmm_json_runtime.cc:
##########
@@ -0,0 +1,192 @@
+#include <libxsmm.h>
+#include <libxsmm_typedefs.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+class LibxsmmJSONRuntime : public json::JSONRuntimeBase {
+ public:
+  // Construct the runtime from the serialized JSON subgraph; all state is
+  // forwarded to JSONRuntimeBase, which parses the graph and owns the
+  // constant tensors.
+  LibxsmmJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+                     const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  // Module type key reported to TVM's runtime module registry.
+  const char* type_key() const { return "libxsmm_json"; }
+
+  // Build (JIT-dispatch) the libxsmm BRGEMM kernel for the dense pattern in
+  // this subgraph, optionally fused with bias-add and/or ReLU, as indicated
+  // by the op name. Currently fp32 only.
+  //
+  // NOTE(review): M_/N_/K_ and the kernel handles are single members, so this
+  // loop assumes at most one "kernel" node per subgraph — a second kernel node
+  // would silently overwrite the first. Confirm the partitioner guarantees this.
+  void Init(const Array<NDArray>& consts) override {
+
+    SetupConstants(consts);
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        auto op_name = node.GetOpName();
+
+        // Check if has bias or relu fusion (encoded as suffixes in the op name).
+        has_bias_ = op_name.find("_bias") != std::string::npos;
+        has_relu_ = op_name.find("_relu") != std::string::npos;
+
+        // Get M, N, K, lda, ldb, ldc.
+        auto data_entry = node.GetInputs()[0];
+        auto weight_entry = node.GetInputs()[1];
+        json::JSONGraphNodeEntry out_entry(nid, 0);
+
+        std::vector<int64_t> input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+        std::vector<int64_t> weight_shape =
+            nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
+        std::vector<int64_t> out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
+
+        // Data is [M, K]; weight_shape[0] taken as N suggests dense-style
+        // [N, K] weights (transposed in Run()) — TODO confirm against the
+        // pattern matcher.
+        M_ = input_shape[0];
+        N_ = weight_shape[0];
+        K_ = input_shape[1];
+
+        int lda = N_;
+        int ldb = K_;
+        int ldc = N_;
+
+        // Currently we support fp32 only.
+        libxsmm_datatype dtype = LIBXSMM_DATATYPE_F32;
+
+        // Configure GEMM related parameters.
+        // BETA_0: C is overwritten rather than accumulated; bias (if any) is
+        // applied via the binary post-op below instead.
+        libxsmm_bitfield l_flags = LIBXSMM_GEMM_FLAG_NONE | LIBXSMM_GEMM_FLAG_BETA_0;
+        libxsmm_bitfield l_prefetch_flags = LIBXSMM_GEMM_PREFETCH_NONE;
+        // N and M are swapped — presumably the usual trick of computing the
+        // row-major C via libxsmm's column-major convention; verify leading
+        // dimensions against the libxsmm docs.
+        libxsmm_gemm_shape l_shape =
+            libxsmm_create_gemm_shape(N_, M_, K_, lda, ldb, ldc, dtype, dtype, dtype, dtype);
+        // Batch-reduce strides are in bytes.
+        libxsmm_blasint stride_a = N_ * K_ * sizeof(float);
+        libxsmm_blasint stride_b = K_ * M_ * sizeof(float);
+        libxsmm_gemm_batch_reduce_config l_brconfig = libxsmm_create_gemm_batch_reduce_config(
+            LIBXSMM_GEMM_BATCH_REDUCE_STRIDE, stride_a, stride_b, 0 /*br_unroll_hint*/);
+
+        // Fusion descriptors: argops apply a unary op to C (ReLU), postops
+        // apply a binary op with an extra input D (bias add).
+        libxsmm_gemm_ext_unary_argops l_argops;
+        libxsmm_gemm_ext_binary_postops l_postops;
+        memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops));
+        memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops));
+
+        if (has_bias_) {
+          // Bias is broadcast across rows and added to C after the GEMM.
+          l_postops.d_in_type = dtype;
+          l_postops.d_binary_flags = LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_0;
+          l_postops.d_binary_type = LIBXSMM_MELTW_TYPE_BINARY_ADD;
+          l_postops.ldd = ldc;
+        }
+
+        if (has_relu_) {
+          l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
+          l_argops.cp_unary_type = LIBXSMM_MELTW_TYPE_UNARY_RELU;
+          l_argops.ldcp = ldc;
+          // relu mask should have the same size as matrix C.
+          relu_mask_.resize(M_ * N_, 0);
+        }
+
+        // Use "libxsmm_gemmfunction" for GEMM kernel, and "libxsmm_gemmfunction_ext" for fused GEMM
+        // kernel.
+        if (has_bias_ || has_relu_) {
+          gemm_fusion_kernel_ = libxsmm_dispatch_brgemm_ext_v2(l_shape, l_flags, l_prefetch_flags,
+                                                               l_brconfig, l_argops, l_postops);
+        } else {
+          gemm_kernel_ = libxsmm_dispatch_brgemm_v2(l_shape, l_flags, l_prefetch_flags, l_brconfig);
+        }
+      }
+    }
+  }
+
+  // Execute the dispatched kernel: resolve the data/weight/output buffers from
+  // the graph's entry ids, lazily transpose the weights once, then invoke the
+  // (fused) BRGEMM. (Only the first part of this method is visible in this hunk.)
+  void Run() override {
+    // Get input/output buffers.
+    auto data_eid = EntryID(input_nodes_[0], 0);
+    auto filter_eid = EntryID(input_nodes_[1], 0);
+    auto output_eid = EntryID(outputs_[0]);
+
+    void* data_handle = data_entry_[data_eid]->data;
+    void* filter_handle = data_entry_[filter_eid]->data;
+    void* output_handle = data_entry_[output_eid]->data;
+
+    // Transpose weight matrix since libxsmm only supports GEMM rather than DENSE:
+    // the [N, K] filter is rewritten as [K, N] so the kernel can consume it as B.
+    // NOTE(review): the transpose is computed once and cached — if the filter
+    // buffer can change between runs this copy goes stale; also no matching
+    // TVMDeviceFreeDataSpace is visible in this hunk — confirm the destructor
+    // releases transposed_filter_handle_.
+    if (!transposed_filter_handle_) {
+      TVMDeviceAllocDataSpace(dev_, K_ * N_ * sizeof(float), kAllocAlignment, type_hint_,
+                              &transposed_filter_handle_);
+      for (int k = 0; k < K_; ++k) {
+        for (int n = 0; n < N_; ++n) {
+          static_cast<float*>(transposed_filter_handle_)[k * N_ + n] =
+              static_cast<float*>(filter_handle)[n * K_ + k];
+        }
+      }
+    }

Review Comment:
   Done.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org