You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this conversion.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/01 11:27:26 UTC
[tvm] branch main updated: [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3d380fc [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)
3d380fc is described below
commit 3d380fc5817ce335ea823ce7e6b7e35717e579cd
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Sat May 1 04:27:06 2021 -0700
[Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936)
---
src/target/intrin_rule.cc | 24 ++++---
src/target/llvm/intrin_rule_llvm.cc | 109 ++++++++++++++++---------------
src/target/spirv/intrin_rule_spirv.cc | 53 ++++++++-------
src/tir/transforms/lower_intrin.cc | 43 ++++++------
tests/python/unittest/test_tir_intrin.py | 69 ++++++++++++++++++-
5 files changed, 188 insertions(+), 110 deletions(-)
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index bfc3fe6..e697d9b 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -112,19 +112,25 @@ TVM_REGISTER_OP("tir.ceil")
TVM_REGISTER_OP("tir.round")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
+TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
+ DispatchPureExtern<FloatSuffix>);
+
+} // namespace intrin
+
+namespace legalize {
+
+using namespace tir;
+
TVM_REGISTER_OP("tir.rsqrt")
- .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
return one / sqrt(call->args[0]);
});
-TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
- DispatchPureExtern<FloatSuffix>);
-
TVM_REGISTER_OP("tir.sigmoid")
- .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
@@ -132,21 +138,21 @@ TVM_REGISTER_OP("tir.sigmoid")
});
TVM_REGISTER_OP("tir.isfinite")
- .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
return isfinite(call->args[0]);
});
TVM_REGISTER_OP("tir.isinf")
- .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
return isinf(call->args[0]);
});
TVM_REGISTER_OP("tir.q_multiply_shift")
- .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
const tir::CallNode* call = e.as<tir::CallNode>();
@@ -222,6 +228,6 @@ TVM_REGISTER_OP("tir.q_multiply_shift")
}
});
-} // namespace intrin
+} // namespace legalize
} // namespace codegen
} // namespace tvm
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index 2d30c20..adbd105 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -30,6 +30,7 @@
namespace tvm {
namespace codegen {
namespace llvm {
+namespace intrin {
using tir::FLowerIntrinsic;
TVM_REGISTER_OP("tir.prefetch")
@@ -43,20 +44,6 @@ TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
-// TODO(tvm-team): migrate the legalization transformations as a separate
-// set of rules in TIR that can be shared across backends.
-TVM_REGISTER_OP("tir.exp10")
- .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
- using tir::make_const;
- using tir::make_zero;
- const tir::CallNode* call = e.as<tir::CallNode>();
- ICHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
- PrimExpr ret = exp(x * ln10);
- return ret;
- });
-
TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
"llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
@@ -99,8 +86,37 @@ TVM_REGISTER_OP("tir.nearbyint")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
+TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
+ "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
+
+TVM_REGISTER_OP("tir.popcount")
+ .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
+ DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
+
+TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
+ "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+
+TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
+ "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+} // namespace intrin
+
+namespace legalize {
+using tir::FLegalize;
+
+TVM_REGISTER_OP("tir.exp10")
+ .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+ using tir::make_const;
+ using tir::make_zero;
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
+ PrimExpr ret = exp(x * ln10);
+ return ret;
+ });
+
TVM_REGISTER_OP("tir.tanh")
- .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
@@ -118,28 +134,16 @@ TVM_REGISTER_OP("tir.tanh")
return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});
-TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
- "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
-
-TVM_REGISTER_OP("tir.popcount")
- .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
- DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
-
-TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
- [](const PrimExpr& e) -> PrimExpr {
- const tir::CallNode* call =
- e.as<tir::CallNode>();
- ICHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr tan_x = sin(x) / cos(x);
- return tan_x;
- });
-
-TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
- "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+TVM_REGISTER_OP("tir.tan").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+ const PrimExpr& x = call->args[0];
+ PrimExpr tan_x = sin(x) / cos(x);
+ return tan_x;
+});
TVM_REGISTER_OP("tir.cosh")
- .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
@@ -153,11 +157,8 @@ TVM_REGISTER_OP("tir.cosh")
return ret;
});
-TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
- "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
-
TVM_REGISTER_OP("tir.sinh")
- .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
+ .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using tir::make_zero;
const tir::CallNode* call = e.as<tir::CallNode>();
@@ -171,21 +172,21 @@ TVM_REGISTER_OP("tir.sinh")
return ret;
});
-TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
- "llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
- const tir::CallNode* call = e.as<tir::CallNode>();
- ICHECK(call != nullptr);
- ICHECK_EQ(call->args.size(), 1);
- Array<PrimExpr> cargs;
- cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
- cargs.push_back(IntImm(DataType::UInt(32), 2));
- cargs.push_back(call->args[0]);
- cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
- // LLVM requires that the return type must match the first argument type
- auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
- return cast(call->dtype, clz);
- });
-
+TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+ ICHECK_EQ(call->args.size(), 1);
+ Array<PrimExpr> cargs;
+ cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
+ cargs.push_back(IntImm(DataType::UInt(32), 2));
+ cargs.push_back(call->args[0]);
+ cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
+ // LLVM requires that the return type must match the first argument type
+ auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs);
+ return cast(call->dtype, clz);
+});
+
+} // namespace legalize
} // namespace llvm
} // namespace codegen
} // namespace tvm
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index fa38f8f..eca7c4c 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -30,8 +30,6 @@
namespace tvm {
namespace codegen {
namespace spirv {
-using tir::FLowerIntrinsic;
-
// num_signature means number of arguments used to query signature
template <unsigned id>
PrimExpr CallGLSLIntrin(PrimExpr e, const Array<PrimExpr>& args) {
@@ -59,6 +57,8 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) {
return CallGLSLIntrin<id>(e);
}
+namespace intrin {
+using tir::FLowerIntrinsic;
TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);
@@ -98,29 +98,6 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
-TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
- "vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
- const tir::CallNode* call = e.as<tir::CallNode>();
- ICHECK(call != nullptr);
- ICHECK_EQ(call->args.size(), 1);
- PrimExpr arg = call->args[0];
- PrimExpr msb;
- if (arg.dtype().bits() == 64) {
- // SPIR-V FindUMsb intrinsic only supports 32 bit input
- auto int32 = DataType::Int(32);
- PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
- PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
- PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
- PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
- msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
- } else if (arg.dtype().bits() == 32) {
- msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
- } else {
- LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
- }
- return PrimExpr(arg.dtype().bits() - 1) - msb;
- });
-
// WebGPU rules.
TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Floor>);
@@ -151,7 +128,33 @@ TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+} // namespace intrin
+namespace legalize {
+using tir::FLegalize;
+TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>(
+ "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+ ICHECK_EQ(call->args.size(), 1);
+ PrimExpr arg = call->args[0];
+ PrimExpr msb;
+ if (arg.dtype().bits() == 64) {
+ // SPIR-V FindUMsb intrinsic only supports 32 bit input
+ auto int32 = DataType::Int(32);
+ PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
+ PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
+ PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
+ PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
+ msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
+ } else if (arg.dtype().bits() == 32) {
+ msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+ } else {
+ LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
+ }
+ return PrimExpr(arg.dtype().bits() - 1) - msb;
+ });
+} // namespace legalize
} // namespace spirv
} // namespace codegen
} // namespace tvm
diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc
index 4101891..2555002 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -39,33 +39,34 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
+ using FLowerGeneral = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
- std::vector<std::string> patterns_;
- patterns_.push_back(target + ".FLowerIntrinsic");
-
+ std::vector<std::string> patterns;
+ patterns.push_back(target + ".FLowerIntrinsic");
+ patterns.push_back(target + ".FLegalize");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
- patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
- }
-
- patterns_.push_back("default.FLowerIntrinsic");
-
- fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
- if (target == "stackvm") {
- support_bitwise_op_ = false;
+ patterns.push_back(target + ".aarch64.FLowerIntrinsic");
+ patterns.push_back(target + ".aarch64.FLegalize");
}
-
- for (const std::string& pattern : patterns_)
- if (Op::HasAttrMap(pattern))
- lower_intrin_maps_.push_back(Op::GetAttrMap<FLowerIntrinsic>(pattern));
+ patterns.push_back("default.FLowerIntrinsic");
+ patterns.push_back("default.FLegalize");
+
+ for (const std::string& pattern : patterns)
+ if (Op::HasAttrMap(pattern)) {
+ attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
+ if (fma_ == nullptr) {
+ fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr);
+ }
+ }
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
- for (const auto& f_lower_intrin_map : lower_intrin_maps_) {
- FLowerIntrinsic f = f_lower_intrin_map.get(GetRef<Op>(ptr_op), nullptr);
+ for (const auto& f_attr_map : attr_maps_) {
+ FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr r = f(e);
@@ -269,7 +270,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
+ PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
@@ -280,9 +281,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- // patterns
- std::vector<OpAttrMap<FLowerIntrinsic>> lower_intrin_maps_;
- const PackedFunc* fma_{nullptr};
+ // attribute maps, shared only when FLegalize == FLowerIntrinsic
+ std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
+ FLowerGeneral fma_{nullptr};
bool support_bitwise_op_{true};
};
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 8512d1c..79b2819 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -16,9 +16,10 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, tir
from tvm import topi
from tvm.contrib import utils, clang
+from tvm.script import ty
import numpy as np
import ctypes
import math
@@ -184,6 +185,71 @@ def test_clz():
np.testing.assert_equal(b.asnumpy(), ref)
+@tvm.script.tir
+class Module:
+ def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None:
+ # function attr dict
+ tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
+ n = tir.var("int32")
+ stride = tir.var("int32")
+ stride_1 = tir.var("int32")
+ stride_2 = tir.var("int32")
+ stride_3 = tir.var("int32")
+ A_1 = tir.match_buffer(
+ A,
+ [n],
+ strides=[stride],
+ elem_offset=0,
+ align=128,
+ offset_factor=1,
+ type="auto",
+ )
+ B_1 = tir.match_buffer(
+ B,
+ [n],
+ strides=[stride_1],
+ elem_offset=0,
+ align=128,
+ offset_factor=1,
+ type="auto",
+ )
+ C_1 = tir.match_buffer(
+ C,
+ [n],
+ strides=[stride_2],
+ elem_offset=0,
+ align=128,
+ offset_factor=1,
+ type="auto",
+ )
+ d_1 = tir.match_buffer(
+ d,
+ [n],
+ strides=[stride_3],
+ elem_offset=0,
+ align=128,
+ offset_factor=1,
+ type="auto",
+ )
+ # body
+ for i in tir.serial(0, n):
+ d_1.data[(i * stride_3)] = (
+ tir.load("float32", A_1.data, (i * stride))
+ * tir.load("float32", B_1.data, (i * stride_1))
+ ) + tir.load("float32", C_1.data, (i * stride_2))
+
+
+def test_fma():
+ opt = tvm.transform.Sequential(
+ [
+ tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))),
+ tvm.tir.transform.LowerIntrin(),
+ ]
+ )
+ mod = opt(Module())
+ assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"
+
+
if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
@@ -191,3 +257,4 @@ if __name__ == "__main__":
test_binary_intrin()
test_ldexp()
test_clz()
+ test_fma()