You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/01 11:27:26 UTC

[tvm] branch main updated: [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d380fc  [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)
3d380fc is described below

commit 3d380fc5817ce335ea823ce7e6b7e35717e579cd
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat May 1 04:27:06 2021 -0700

    [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)
---
 src/target/intrin_rule.cc                |  24 ++++---
 src/target/llvm/intrin_rule_llvm.cc      | 109 ++++++++++++++++---------------
 src/target/spirv/intrin_rule_spirv.cc    |  53 ++++++++-------
 src/tir/transforms/lower_intrin.cc       |  43 ++++++------
 tests/python/unittest/test_tir_intrin.py |  69 ++++++++++++++++++-
 5 files changed, 188 insertions(+), 110 deletions(-)

diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index bfc3fe6..e697d9b 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -112,19 +112,25 @@ TVM_REGISTER_OP("tir.ceil")
 TVM_REGISTER_OP("tir.round")
     .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
 
+TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
+                                                     DispatchPureExtern<FloatSuffix>);
+
+}  // namespace intrin
+
+namespace legalize {
+
+using namespace tir;
+
 TVM_REGISTER_OP("tir.rsqrt")
-    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       const CallNode* call = e.as<CallNode>();
       ICHECK(call != nullptr);
       auto one = make_const(call->args[0].dtype(), 1);
       return one / sqrt(call->args[0]);
     });
 
-TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
-                                                     DispatchPureExtern<FloatSuffix>);
-
 TVM_REGISTER_OP("tir.sigmoid")
-    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       const CallNode* call = e.as<CallNode>();
       ICHECK(call != nullptr);
       auto one = make_const(call->args[0].dtype(), 1);
@@ -132,21 +138,21 @@ TVM_REGISTER_OP("tir.sigmoid")
     });
 
 TVM_REGISTER_OP("tir.isfinite")
-    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       const CallNode* call = e.as<CallNode>();
       ICHECK(call != nullptr);
       return isfinite(call->args[0]);
     });
 
 TVM_REGISTER_OP("tir.isinf")
-    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       const CallNode* call = e.as<CallNode>();
       ICHECK(call != nullptr);
       return isinf(call->args[0]);
     });
 
 TVM_REGISTER_OP("tir.q_multiply_shift")
-    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       using tir::make_const;
 
       const tir::CallNode* call = e.as<tir::CallNode>();
@@ -222,6 +228,6 @@ TVM_REGISTER_OP("tir.q_multiply_shift")
       }
     });
 
-}  // namespace intrin
+}  // namespace legalize
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index 2d30c20..adbd105 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -30,6 +30,7 @@
 namespace tvm {
 namespace codegen {
 namespace llvm {
+namespace intrin {
 using tir::FLowerIntrinsic;
 
 TVM_REGISTER_OP("tir.prefetch")
@@ -43,20 +44,6 @@ TVM_REGISTER_OP("tir.exp2")
     .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
                                DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
 
-// TODO(tvm-team): migrate the legalization transformations as a separate
-//                 set of rules in TIR that can be shared across backends.
-TVM_REGISTER_OP("tir.exp10")
-    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
-      using tir::make_const;
-      using tir::make_zero;
-      const tir::CallNode* call = e.as<tir::CallNode>();
-      ICHECK(call != nullptr);
-      const PrimExpr& x = call->args[0];
-      PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
-      PrimExpr ret = exp(x * ln10);
-      return ret;
-    });
-
 TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
     "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
 
@@ -99,8 +86,37 @@ TVM_REGISTER_OP("tir.nearbyint")
     .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
                                DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
 
+TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
+    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
+
+TVM_REGISTER_OP("tir.popcount")
+    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
+                               DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
+
+TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
+    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+
+TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
+    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+}  // namespace intrin
+
+namespace legalize {
+using tir::FLegalize;
+
+TVM_REGISTER_OP("tir.exp10")
+    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+      using tir::make_const;
+      using tir::make_zero;
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      ICHECK(call != nullptr);
+      const PrimExpr& x = call->args[0];
+      PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
+      PrimExpr ret = exp(x * ln10);
+      return ret;
+    });
+
 TVM_REGISTER_OP("tir.tanh")
-    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       using tir::make_const;
       using tir::make_zero;
       const tir::CallNode* call = e.as<tir::CallNode>();
@@ -118,28 +134,16 @@ TVM_REGISTER_OP("tir.tanh")
       return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
     });
 
-TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
-    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
-
-TVM_REGISTER_OP("tir.popcount")
-    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
-                               DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
-
-TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
-                                                     [](const PrimExpr& e) -> PrimExpr {
-                                                       const tir::CallNode* call =
-                                                           e.as<tir::CallNode>();
-                                                       ICHECK(call != nullptr);
-                                                       const PrimExpr& x = call->args[0];
-                                                       PrimExpr tan_x = sin(x) / cos(x);
-                                                       return tan_x;
-                                                     });
-
-TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
-    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+TVM_REGISTER_OP("tir.tan").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr);
+  const PrimExpr& x = call->args[0];
+  PrimExpr tan_x = sin(x) / cos(x);
+  return tan_x;
+});
 
 TVM_REGISTER_OP("tir.cosh")
-    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       using tir::make_const;
       using tir::make_zero;
       const tir::CallNode* call = e.as<tir::CallNode>();
@@ -153,11 +157,8 @@ TVM_REGISTER_OP("tir.cosh")
       return ret;
     });
 
-TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
-    "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
-
 TVM_REGISTER_OP("tir.sinh")
-    .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+    .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       using tir::make_const;
       using tir::make_zero;
       const tir::CallNode* call = e.as<tir::CallNode>();
@@ -171,21 +172,21 @@ TVM_REGISTER_OP("tir.sinh")
       return ret;
     });
 
-TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
-    "llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
-      const tir::CallNode* call = e.as<tir::CallNode>();
-      ICHECK(call != nullptr);
-      ICHECK_EQ(call->args.size(), 1);
-      Array<PrimExpr> cargs;
-      cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
-      cargs.push_back(IntImm(DataType::UInt(32), 2));
-      cargs.push_back(call->args[0]);
-      cargs.push_back(IntImm(DataType::Int(1), 1));  // is_zero_undef
-      // LLVM requires that the return type must match the first argument type
-      auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
-      return cast(call->dtype, clz);
-    });
-
+TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 1);
+  Array<PrimExpr> cargs;
+  cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
+  cargs.push_back(IntImm(DataType::UInt(32), 2));
+  cargs.push_back(call->args[0]);
+  cargs.push_back(IntImm(DataType::Int(1), 1));  // is_zero_undef
+  // LLVM requires that the return type must match the first argument type
+  auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
+  return cast(call->dtype, clz);
+});
+
+}  // namespace legalize
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index fa38f8f..eca7c4c 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -30,8 +30,6 @@
 namespace tvm {
 namespace codegen {
 namespace spirv {
-using tir::FLowerIntrinsic;
-
 // num_signature means number of arguments used to query signature
 template <unsigned id>
 PrimExpr CallGLSLIntrin(PrimExpr e, const Array<PrimExpr>& args) {
@@ -59,6 +57,8 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) {
   return CallGLSLIntrin<id>(e);
 }
 
+namespace intrin {
+using tir::FLowerIntrinsic;
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);
 
@@ -98,29 +98,6 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
 TVM_REGISTER_OP("tir.tanh")
     .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
 
-TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
-    "vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
-      const tir::CallNode* call = e.as<tir::CallNode>();
-      ICHECK(call != nullptr);
-      ICHECK_EQ(call->args.size(), 1);
-      PrimExpr arg = call->args[0];
-      PrimExpr msb;
-      if (arg.dtype().bits() == 64) {
-        // SPIR-V FindUMsb intrinsic only supports 32 bit input
-        auto int32 = DataType::Int(32);
-        PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
-        PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
-        PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
-        PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
-        msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
-      } else if (arg.dtype().bits() == 32) {
-        msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
-      } else {
-        LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
-      }
-      return PrimExpr(arg.dtype().bits() - 1) - msb;
-    });
-
 // WebGPU rules.
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);
@@ -151,7 +128,33 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
 
 TVM_REGISTER_OP("tir.tanh")
     .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+}  // namespace intrin
 
+namespace legalize {
+using tir::FLegalize;
+TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>(
+    "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      ICHECK(call != nullptr);
+      ICHECK_EQ(call->args.size(), 1);
+      PrimExpr arg = call->args[0];
+      PrimExpr msb;
+      if (arg.dtype().bits() == 64) {
+        // SPIR-V FindUMsb intrinsic only supports 32 bit input
+        auto int32 = DataType::Int(32);
+        PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
+        PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
+        PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
+        PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
+        msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
+      } else if (arg.dtype().bits() == 32) {
+        msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+      } else {
+        LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
+      }
+      return PrimExpr(arg.dtype().bits() - 1) - msb;
+    });
+}  // namespace legalize
 }  // namespace spirv
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc
index 4101891..2555002 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -39,33 +39,34 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
  public:
   using IRMutatorWithAnalyzer::VisitExpr_;
   using IRMutatorWithAnalyzer::VisitStmt_;
+  using FLowerGeneral = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
 
   IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
       : IRMutatorWithAnalyzer(analyzer) {
-    std::vector<std::string> patterns_;
-    patterns_.push_back(target + ".FLowerIntrinsic");
-
+    std::vector<std::string> patterns;
+    patterns.push_back(target + ".FLowerIntrinsic");
+    patterns.push_back(target + ".FLegalize");
     bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
     if (is_llvm_aarch64) {
-      patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
-    }
-
-    patterns_.push_back("default.FLowerIntrinsic");
-
-    fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
-    if (target == "stackvm") {
-      support_bitwise_op_ = false;
+      patterns.push_back(target + ".aarch64.FLowerIntrinsic");
+      patterns.push_back(target + ".aarch64.FLegalize");
     }
-
-    for (const std::string& pattern : patterns_)
-      if (Op::HasAttrMap(pattern))
-        lower_intrin_maps_.push_back(Op::GetAttrMap<FLowerIntrinsic>(pattern));
+    patterns.push_back("default.FLowerIntrinsic");
+    patterns.push_back("default.FLegalize");
+
+    for (const std::string& pattern : patterns)
+      if (Op::HasAttrMap(pattern)) {
+        attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
+        if (fma_ == nullptr) {
+          fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr);
+        }
+      }
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      for (const auto& f_lower_intrin_map : lower_intrin_maps_) {
-        FLowerIntrinsic f = f_lower_intrin_map.get(GetRef<Op>(ptr_op), nullptr);
+      for (const auto& f_attr_map : attr_maps_) {
+        FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
         if (f != nullptr) {
           PrimExpr e = GetRef<PrimExpr>(op);
           PrimExpr r = f(e);
@@ -269,7 +270,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     PrimExpr rhs = SwapBroadcastCast(b);
 
     if (fma_ != nullptr && op->dtype.is_float()) {
-      PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
+      PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
       if (r.defined()) return this->VisitExpr(r);
     } else {
       if (!lhs.same_as(a) || !rhs.same_as(b)) {
@@ -280,9 +281,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  // patterns
-  std::vector<OpAttrMap<FLowerIntrinsic>> lower_intrin_maps_;
-  const PackedFunc* fma_{nullptr};
+  // attribute maps, shared only when FLegalize == FLowerIntrinsic
+  std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
+  FLowerGeneral fma_{nullptr};
   bool support_bitwise_op_{true};
 };
 
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 8512d1c..79b2819 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -16,9 +16,10 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, tir
 from tvm import topi
 from tvm.contrib import utils, clang
+from tvm.script import ty
 import numpy as np
 import ctypes
 import math
@@ -184,6 +185,71 @@ def test_clz():
                 np.testing.assert_equal(b.asnumpy(), ref)
 
 
+@tvm.script.tir
+class Module:
+    def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None:
+        # function attr dict
+        tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
+        n = tir.var("int32")
+        stride = tir.var("int32")
+        stride_1 = tir.var("int32")
+        stride_2 = tir.var("int32")
+        stride_3 = tir.var("int32")
+        A_1 = tir.match_buffer(
+            A,
+            [n],
+            strides=[stride],
+            elem_offset=0,
+            align=128,
+            offset_factor=1,
+            type="auto",
+        )
+        B_1 = tir.match_buffer(
+            B,
+            [n],
+            strides=[stride_1],
+            elem_offset=0,
+            align=128,
+            offset_factor=1,
+            type="auto",
+        )
+        C_1 = tir.match_buffer(
+            C,
+            [n],
+            strides=[stride_2],
+            elem_offset=0,
+            align=128,
+            offset_factor=1,
+            type="auto",
+        )
+        d_1 = tir.match_buffer(
+            d,
+            [n],
+            strides=[stride_3],
+            elem_offset=0,
+            align=128,
+            offset_factor=1,
+            type="auto",
+        )
+        # body
+        for i in tir.serial(0, n):
+            d_1.data[(i * stride_3)] = (
+                tir.load("float32", A_1.data, (i * stride))
+                * tir.load("float32", B_1.data, (i * stride_1))
+            ) + tir.load("float32", C_1.data, (i * stride_2))
+
+
+def test_fma():
+    opt = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))),
+            tvm.tir.transform.LowerIntrin(),
+        ]
+    )
+    mod = opt(Module())
+    assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"
+
+
 if __name__ == "__main__":
     test_nearbyint()
     test_unary_intrin()
@@ -191,3 +257,4 @@ if __name__ == "__main__":
     test_binary_intrin()
     test_ldexp()
     test_clz()
+    test_fma()