You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/05 20:57:30 UTC

[tvm] branch main updated: Handle float16 in ConstantNode visitor in LowerToTECompute (#10902)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 138fafffdb Handle float16 in ConstantNode visitor in LowerToTECompute (#10902)
138fafffdb is described below

commit 138fafffdb9e136d74dc0068102e22f27169d4eb
Author: Krzysztof Parzyszek <kp...@quicinc.com>
AuthorDate: Tue Apr 5 15:57:25 2022 -0500

    Handle float16 in ConstantNode visitor in LowerToTECompute (#10902)
    
    Read the bit pattern of a float16 constant as a uint16 value and convert it
    to the corresponding float32 value with __gnu_h2f_ieee.
---
 include/tvm/runtime/builtin_fp16.h     | 36 ++++++++++++++++++++++++++++++++++
 src/relay/backend/te_compiler_cache.cc |  3 +++
 2 files changed, 39 insertions(+)

diff --git a/include/tvm/runtime/builtin_fp16.h b/include/tvm/runtime/builtin_fp16.h
new file mode 100644
index 0000000000..e93aea228c
--- /dev/null
+++ b/include/tvm/runtime/builtin_fp16.h
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file builtin_fp16.h
+ * \brief C-linkage conversions between fp32 and fp16 (fp16 held as raw uint16_t bits).
+ */
+#ifndef TVM_RUNTIME_BUILTIN_FP16_H_
+#define TVM_RUNTIME_BUILTIN_FP16_H_
+
+#include <tvm/runtime/c_runtime_api.h>  // TVM_DLL
+
+#include <cstdint>  // uint16_t
+
+extern "C" {  // C linkage: these symbol names are not C++-mangled
+TVM_DLL uint16_t __gnu_f2h_ieee(float);  // float32 value -> float16 bit pattern
+TVM_DLL float __gnu_h2f_ieee(uint16_t);  // float16 bit pattern -> float32 value
+}  // extern "C"
+
+#endif  // TVM_RUNTIME_BUILTIN_FP16_H_
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 963732be54..e0e7277676 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -29,6 +29,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/op_strategy.h>
+#include <tvm/runtime/builtin_fp16.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
@@ -175,6 +176,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
               return make_const(dtype, static_cast<const int32_t*>(data)[0]);
             } else if (dtype == DataType::Int(64)) {
               return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+            } else if (dtype == DataType::Float(16)) {
+              return make_const(dtype, __gnu_h2f_ieee(static_cast<const uint16_t*>(data)[0]));
             } else if (dtype == DataType::Float(32)) {
               return make_const(dtype, static_cast<const float*>(data)[0]);
             } else if (dtype == DataType::Float(64)) {