Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/31 10:21:04 UTC

[GitHub] [tvm] yangulei opened a new pull request #10112: [TIR, Relay] improve bfloat16 support

yangulei opened a new pull request #10112:
URL: https://github.com/apache/tvm/pull/10112


   ### Motivation:
   We are enabling [bfloat16](https://discuss.tvm.apache.org/t/rfc-add-bfloat16-data-type/6778) in [BYOC-oneDNN](https://discuss.tvm.apache.org/t/rfc-byoc-intel-r-onednn-integration/11582) following the path: [float32 graph] --> \<[AMP](https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994)\> --> [bfloat16 graph] --> \<BYOC\> --> [TVM + oneDNN module]. However, some passes, such as `FoldConstant`, do not work with bfloat16 without the improvements below.
   
   ### Changes:
   - Add runtime datatype dispatch and skip asserts for uint16 for bfloat16 compatibility.
   - Add bfloat16 casting for unary intrinsic operators to enable the graph optimization.
   - Improve the bf16_legalize module to enable bfloat16 lowering.
   
   With these improvements, a float32 graph can be converted to bfloat16 through AMP and then lowered to run inference in bfloat16 mode.
   
   ### Tested Models (gluoncv):
   - ResNet<18/34/50/101/152>_v1b
   - VGG<11/13/16/19>
   - VGG<11/13/16/19>_bn
   - DenseNet121
   - InceptionV3
   
   > By tested I mean I confirm it did some transformation on the graph and a forward pass could be run on CPU and matches the fp32 output somewhat. I have nothing on performance metrics or other devices yet.
   
   The note above is quoted from @AndrewZhaoLuo at https://github.com/apache/tvm/pull/8069
   
   ### Pending:
   bfloat16 support in BYOC-oneDNN builds on the [multi-blocking layout transform](https://github.com/apache/tvm/pull/9996) and the [extensions on BYOC-oneDNN](https://github.com/apache/tvm/pull/9995), and is still pending.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] yangulei commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
yangulei commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803256748



##########
File path: tests/python/relay/test_cpp_build_module.py
##########
@@ -93,6 +93,35 @@ def test_fp16_build():
     np.testing.assert_allclose(out.numpy(), X.numpy() + Y.numpy(), atol=1e-5, rtol=1e-5)
 
 
+@tvm.testing.requires_llvm
+def test_bf16_build():
+    data = relay.var("data", shape=(1, 3, 224, 224), dtype='float32')
+    weight = relay.var("weight", shape=(64, 3, 7, 7), dtype='float32')
+    bn_gamma = relay.var("gamma", shape=(64,), dtype='float32')
+    bn_beta = relay.var("beta", shape=(64,), dtype='float32')
+    bn_mean = relay.var("mean", shape=(64,), dtype='float32')
+    bn_var = relay.var("var", shape=(64,), dtype='float32')
+    params = {
+        "weight": np.random.uniform(-1, 1, size=(64, 3, 7, 7)).astype('float32'),
+        "gamma": np.random.uniform(-1, 1, size=(64, )).astype('float32'),
+        "beta": np.random.uniform(-1, 1, size=(64, )).astype('float32'),
+        "mean": np.random.uniform(-1, 1, size=(64, )).astype('float32'),
+        "var": np.random.uniform(-1, 1, size=(64, )).astype('float32'),
+    }
+    conv_bf16 = relay.nn.conv2d(relay.cast(data, 'bfloat16'), relay.cast(weight, 'bfloat16'),
+                                strides=(2, 2), padding=(3, 3, 3, 3), channels=64, kernel_size=(7, 7), out_dtype='bfloat16')
+    bn_bf16 = relay.nn.batch_norm(conv_bf16, relay.cast(bn_gamma, 'bfloat16'),
+                                  relay.cast(bn_beta, 'bfloat16'), relay.cast(bn_mean, 'bfloat16'), relay.cast(bn_var, 'bfloat16'))
+    relu_bf16 = relay.nn.relu(bn_bf16[0])
+    maxpool_bf16 = relay.nn.max_pool2d(
+        relu_bf16, pool_size=(2, 2), strides=(2, 2))
+    avgpool_bf16 = relay.nn.avg_pool2d(
+        maxpool_bf16, pool_size=(2, 2), strides=(2, 2))
+    mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod_bf16, target="llvm", params=params)

Review comment:
       Good point.
   The correctness of a simple bfloat16 addition has been checked at https://github.com/apache/tvm/blob/14d0187ce9cefc41e33aa30b55c08a75a6711732/tests/python/unittest/test_target_codegen_llvm.py#L739
   However, verifying the correctness of bfloat16 inference is much more complex than I first thought.
   We usually use relative error and absolute error to evaluate accuracy, but both can be large even for a simple `multiply` with random inputs.
   On the other hand, the MSE of the outputs (an array of length 1000) of native bfloat16 inference for ResNet\<18/34/50/101/152\>_v1b and VGG\<11/13/16/19\> is about 0.5% to 1%, and less than 0.2% for BYOC-oneDNN inference. Also, the **_envelope curves_** of the bfloat16 and float32 outputs are close.
   I couldn't find a metric good enough to estimate the accuracy of bfloat16 inference. Any suggestions?
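
To make the point about element-wise error concrete, here is a minimal sketch in plain Python (no TVM involved). It assumes round-to-nearest-even conversion from float32 to bfloat16 and ignores NaN handling. The helper names `to_bf16`, `bf16_multiply` and `rel_err` are hypothetical and only serve the illustration. bfloat16 keeps 7 explicit mantissa bits, so each rounding can introduce a relative error of up to about 2^-8 (roughly 0.4%), far above the `rtol=1e-5` used by the neighbouring fp16 test.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    The result is returned as a Python float. NaN inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Add half of the dropped range, plus the lowest kept bit for ties-to-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_multiply(a: float, b: float) -> float:
    """Multiply with inputs and output rounded to bfloat16 (three roundings)."""
    return to_bf16(to_bf16(a) * to_bf16(b))


def rel_err(got: float, ref: float) -> float:
    return abs(got - ref) / abs(ref)


# Fixed inputs instead of random ones so the sketch is reproducible.
pairs = [(0.3, 0.7), (1.1, 2.3), (-0.9, 0.123), (5.5, 0.011)]
worst = max(rel_err(bf16_multiply(a, b), a * b) for a, b in pairs)
print(f"worst relative error: {worst:.2e} (fp16 test tolerance: 1e-05)")
```

With three roundings per product, the relative error stays below about 1.2% but is orders of magnitude above a 1e-5 tolerance, which is why a plain `assert_allclose` with float32-style tolerances is not a useful check here.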







[GitHub] [tvm] comaniac commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803238389



##########
File path: tests/python/relay/test_cpp_build_module.py
##########
@@ -93,6 +93,35 @@ def test_fp16_build():

Review comment:
       Could we also verify the output correctness?







[GitHub] [tvm] comaniac commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803259544



##########
File path: tests/python/relay/test_cpp_build_module.py
##########
@@ -93,6 +93,35 @@ def test_fp16_build():

Review comment:
       I see. Would it make sense to calculate a reference result in FP32 and cast it to bfloat16 for comparison? This is the only way I can think of, so I have no other ideas if that doesn't work.
   
   @masahi @AndrewZhaoLuo do you have comments?







[GitHub] [tvm] Hzfengsy commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803273135



##########
File path: src/relay/op/nn/nn.cc
##########
@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                      << ", weights shape = " << weights->shape);
     return false;
   }
-  if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
+  if (!(predictions->dtype == weights->dtype &&
+        (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {

Review comment:
       Shall we let `is_float()` be true for `bfloat16` exprs?

##########
File path: src/relay/backend/utils.h
##########
@@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
     os << "int";
   } else if (dtype.is_uint()) {
     os << "uint";
+  } else if (dtype.is_bfloat16()) {
+    os << "bfloat";

Review comment:
       ```suggestion
       os << "bfloat16";
   ```







[GitHub] [tvm] yangulei commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
yangulei commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803333719



##########
File path: src/relay/op/nn/nn.cc
##########
@@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                      << ", weights shape = " << weights->shape);
     return false;
   }
-  if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
+  if (!(predictions->dtype == weights->dtype &&
+        (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {

Review comment:
       I prefer this way too, since they **are** all floating-point datatypes.
   However, there are some practical inconsistencies so far. For example, if we let `is_float() == true` for `bfloat16`, then we can no longer distinguish `bfloat16` from `float16`, as they both satisfy the condition `is_float() == true && bits() == 16`.
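
The ambiguity can be illustrated with a toy model. This is a sketch only: the `DType` class and `looks_like_float16` helper below are hypothetical stand-ins, not TVM's `DataType` API. The type-code constants follow the DLPack convention, where IEEE floats and bfloat carry different codes.

```python
from dataclasses import dataclass

# DLPack type codes (kDLInt, kDLUInt, kDLFloat, kDLBfloat).
K_DL_INT, K_DL_UINT, K_DL_FLOAT, K_DL_BFLOAT = 0, 1, 2, 4


@dataclass(frozen=True)
class DType:
    """Toy stand-in for a (code, bits, lanes) data type triple."""

    code: int
    bits: int
    lanes: int = 1

    def is_float(self) -> bool:
        # Behaviour described in the thread: true only for IEEE floats.
        return self.code == K_DL_FLOAT

    def is_bfloat16(self) -> bool:
        return self.code == K_DL_BFLOAT and self.bits == 16

    def is_float_including_bf16(self) -> bool:
        # Hypothetical widened predicate, as proposed in the review.
        return self.is_float() or self.is_bfloat16()


def looks_like_float16(dtype: DType, is_float) -> bool:
    """A call-site check of the form `is_float() && bits() == 16`."""
    return is_float(dtype) and dtype.bits == 16


float16 = DType(K_DL_FLOAT, 16)
bfloat16 = DType(K_DL_BFLOAT, 16)

print(looks_like_float16(float16, DType.is_float))                  # True
print(looks_like_float16(bfloat16, DType.is_float))                 # False
print(looks_like_float16(bfloat16, DType.is_float_including_bf16))  # True
```

Under the widened predicate, any existing call site that identifies `float16` by "float and 16 bits" would also match `bfloat16`, so such call sites would need an explicit type-code check.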







[GitHub] [tvm] yangulei commented on pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
yangulei commented on pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#issuecomment-1042474807


   Hi @billishyahao ,
   You can check out [my fork](https://github.com/yangulei/tvm/tree/dev_byoc). The [readme file](https://github.com/yangulei/tvm/blob/dev_byoc/bench_oneDNN_BYOC/README.md) includes a brief introduction on how to set up the environment and run the benchmarks. Be aware that this is not the final branch for upstreaming. You need a CPU with AVX-512 enabled to support bfloat16 functionally, and a CPU with AMX enabled to support bfloat16 natively; see the [oneDNN documentation](https://oneapi-src.github.io/oneDNN/dev_guide_data_types.html) for more details.
   FYI: we plan to upstream a tutorial about BYOC-oneDNN after the PRs about this are merged.







[GitHub] [tvm] yangulei commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
yangulei commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r814422617



##########
File path: include/tvm/tir/op.h
##########
@@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
                                   Span span = Span());
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                   \
-  inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
-    static const Op& op = Op::Get("tir." #OpName);         \
-    return tir::Call(x.dtype(), op, {x}, span);            \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                       \
+  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {     \
+    static const Op& op = Op::Get("tir." #OpName);             \
+    if (x.dtype().is_bfloat16()) {                             \
+      DataType srcType = x.dtype();                            \
+      DataType dstType(kDLFloat, 32, srcType.lanes());         \
+      PrimExpr castX = tir::Cast(dstType, {x}, span);          \
+      PrimExpr result = tir::Call(dstType, op, {castX}, span); \
+      return tir::Cast(srcType, {result}, span);               \

Review comment:
       OK, thanks. I will refine the code style in the next PR.
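
The macro change in the hunk above follows a widen, compute, narrow pattern. A minimal Python emulation of that pattern is sketched below, assuming round-to-nearest-even when narrowing and ignoring NaN handling. The helpers `round_f32`, `to_bf16` and `bf16_unary` are hypothetical names for illustration and say nothing about how TVM lowers the casts.

```python
import math
import struct


def round_f32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (ties to even); NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_unary(op, x: float) -> float:
    """Apply a unary op to a bfloat16 value by going through float32."""
    wide = round_f32(x)            # cast bfloat16 -> float32 (exact)
    result = round_f32(op(wide))   # evaluate the intrinsic in float32
    return to_bf16(result)         # cast float32 -> bfloat16


print(bf16_unary(math.exp, 0.0))   # 1.0
print(bf16_unary(math.sqrt, 4.0))  # 2.0
print(bf16_unary(math.exp, 1.0))   # 2.71875
```

The widening step is lossless because every bfloat16 value is also a float32 value, so the only precision loss comes from the final narrowing cast.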







[GitHub] [tvm] masahi commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r814241095



##########
File path: tests/python/relay/test_cpp_build_module.py
##########
@@ -93,6 +93,35 @@ def test_fp16_build():

Review comment:
       Yeah, even fp16 is not trivial for accuracy checking. I can imagine bfloat16 is just as hard.







[GitHub] [tvm] masahi commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r814240506



##########
File path: include/tvm/tir/op.h
##########
@@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s

Review comment:
       Just do `tir::Cast("bfloat16", {result}, span)`. We use `snake_case` for variable names.
   
   Can be fixed in a follow up.







[GitHub] [tvm] yangulei commented on a change in pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
yangulei commented on a change in pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#discussion_r803262987



##########
File path: tests/python/relay/test_cpp_build_module.py
##########
@@ -93,6 +93,35 @@ def test_fp16_build():

Review comment:
       The errors I mentioned above use FP32 results as the reference, and the bfloat16 results are cast to FP32 for comparison.
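
A comparison of this kind can be sketched as follows. The sketch assumes the bfloat16 output is available as raw 16-bit patterns (the PR description mentions uint16 for bfloat16 compatibility). The sample values are made up, and the normalisation in `relative_mse` is an assumption, since the thread does not say how its percentages were computed.

```python
import struct


def bf16_bits_to_float(bits16: int) -> float:
    """Widen a bfloat16 bit pattern to float32.

    The conversion is exact: the 16 bits become the high half of the float32.
    """
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]


def mse(got, ref) -> float:
    """Mean squared error between two equally long sequences."""
    return sum((g - r) ** 2 for g, r in zip(got, ref)) / len(ref)


def relative_mse(got, ref) -> float:
    """MSE divided by the mean square of the reference (assumed normalisation)."""
    return mse(got, ref) / (sum(r * r for r in ref) / len(ref))


# Made-up sample data: bfloat16 patterns for 1.0, 2.0 and -0.5.
bf16_out = [0x3F80, 0x4000, 0xBF00]
fp32_ref = [1.001, 1.998, -0.5004]

got = [bf16_bits_to_float(b) for b in bf16_out]
print(got)  # [1.0, 2.0, -0.5]
print(f"MSE = {mse(got, fp32_ref):.3e}, relative MSE = {relative_mse(got, fp32_ref):.3e}")
```

Comparing in float32 avoids a second rounding of the reference, so the reported error reflects only what the bfloat16 pipeline lost.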







[GitHub] [tvm] billishyahao edited a comment on pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
billishyahao edited a comment on pull request #10112:
URL: https://github.com/apache/tvm/pull/10112#issuecomment-1041258684


   > Tested Models (gluoncv):
   > ResNet<18/34/50/101/152>_v1b
   > VGG<11/13/16/19>
   > VGG<11/13/16/19>_bn
   > DenseNet121
   > InceptionV3
   >
   > By tested I mean I confirm it did some transformation on the graph and a forward pass could be run on CPU and matches the fp32 output somewhat. I have nothing on performance metrics or other devices yet.
   
   Hi @yangulei, thanks for the patch. Could you share some details about the transformation work you did on these graphs? Were there any op modifications?







[GitHub] [tvm] masahi merged pull request #10112: [TIR, Relay] improve bfloat16 support

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #10112:
URL: https://github.com/apache/tvm/pull/10112


   

