You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/07/08 20:04:55 UTC
[incubator-tvm] branch master updated: [Frontend][Relay] Add Parser
2.0 (#5932)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new f9e905a [Frontend][Relay] Add Parser 2.0 (#5932)
f9e905a is described below
commit f9e905a3cc1f5497fa70051393b7e8cfac642fee
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Wed Jul 8 13:04:42 2020 -0700
[Frontend][Relay] Add Parser 2.0 (#5932)
---
CMakeLists.txt | 1 +
include/tvm/ir/span.h | 18 +-
include/tvm/parser/parser.h | 40 +
python/tvm/__init__.py | 3 +
python/tvm/parser/__init__.py | 27 +
python/tvm/parser/_ffi_api.py | 21 +
src/ir/span.cc | 7 +-
src/node/structural_equal.cc | 2 +-
src/parser/diagnostic.h | 176 +++++
src/parser/op_table.h | 97 +++
src/parser/parser.cc | 1408 +++++++++++++++++++++++++++++++++
src/parser/token.h | 362 +++++++++
src/parser/tokenizer.h | 459 +++++++++++
tests/python/relay/test_ir_nodes.py | 8 +-
tests/python/relay/test_ir_parser2.py | 891 +++++++++++++++++++++
15 files changed, 3502 insertions(+), 18 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d7faa8a..aaddebd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -192,6 +192,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
+ src/parser/*.cc
src/printer/*.cc
src/support/*.cc
)
diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h
index 84d6a7b..40f854b 100644
--- a/include/tvm/ir/span.h
+++ b/include/tvm/ir/span.h
@@ -79,22 +79,22 @@ class Span;
*/
class SpanNode : public Object {
public:
- /*! \brief The source name */
+ /*! \brief The source name. */
SourceName source;
- /*! \brief Line number */
- int lineno;
- /*! \brief column offset */
- int col_offset;
+ /*! \brief The line number. */
+ int line;
+ /*! \brief The column offset. */
+ int column;
+
// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("source", &source);
- v->Visit("lineno", &lineno);
- v->Visit("col_offset", &col_offset);
+ v->Visit("line", &line);
+ v->Visit("column", &column);
}
bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
- return equal(source, other->source) && equal(lineno, other->lineno) &&
- equal(col_offset, other->col_offset);
+ return equal(source, other->source) && equal(line, other->line) && equal(column, other->column);
}
static constexpr const char* _type_key = "Span";
diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h
new file mode 100644
index 0000000..9380358
--- /dev/null
+++ b/include/tvm/parser/parser.h
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_PARSER_PARSER_H_
+#define TVM_PARSER_PARSER_H_
+/*!
+ * \file parser.h
+ * \brief A parser for TVM IR.
+ */
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <string>
+
+namespace tvm {
+namespace parser {
+
+/*!
+ * \brief Parse source text into an IRModule.
+ *
+ * NOTE(review): the definition is not visible from this header; the parameter
+ * descriptions below are inferred from the parameter names -- confirm against
+ * the implementation.
+ *
+ * \param file_name The name of the source being parsed.
+ * \param file_content The text to parse.
+ * \return The resulting IRModule.
+ */
+IRModule Parse(std::string file_name, std::string file_content);
+
+} // namespace parser
+} // namespace tvm
+
+#endif // TVM_PARSER_PARSER_H_
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 6cbc6d2..cb1f4d2 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -57,6 +57,9 @@ from . import testing
# tvm.driver
from .driver import build, lower
+# tvm.parser
+from . import parser
+
# others
from . import arith
diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py
new file mode 100644
index 0000000..071c464
--- /dev/null
+++ b/python/tvm/parser/__init__.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The under development unified IR parsing infrastructure."""
+from . import _ffi_api
+
+def parse(source, source_name="from_string"):
+ """Parse `source` by forwarding it to the `ParseModule` FFI function.
+
+ Parameters
+ ----------
+ source
+     The program text to parse.
+ source_name
+     The name handed to the parser for this text (defaults to "from_string").
+
+ Returns
+ -------
+ Whatever `_ffi_api.ParseModule` returns; presumably an IRModule -- TODO confirm.
+ """
+ return _ffi_api.ParseModule(source_name, source)
+
+def parse_expr(source):
+ """Parse `source` by forwarding it to the `ParseExpr` FFI function.
+
+ The source name passed to the parser is always the literal "string".
+
+ Parameters
+ ----------
+ source
+     The expression text to parse.
+
+ Returns
+ -------
+ Whatever `_ffi_api.ParseExpr` returns; presumably a Relay expression -- TODO confirm.
+ """
+ return _ffi_api.ParseExpr("string", source)
+
+def fromtext(source, source_name="from_string"):
+ """Parse `source` after coercing both arguments to `str`.
+
+ This is a thin wrapper over `parse`; `source` and `source_name` are passed
+ through `str()` first, so non-string objects are accepted.
+
+ Parameters
+ ----------
+ source
+     The program text (or an object whose `str()` is the program text).
+ source_name
+     The name handed to the parser for this text (defaults to "from_string").
+
+ Returns
+ -------
+ The result of `parse`.
+ """
+ return parse(str(source), str(source_name))
diff --git a/python/tvm/parser/_ffi_api.py b/python/tvm/parser/_ffi_api.py
new file mode 100644
index 0000000..7fa3b78
--- /dev/null
+++ b/python/tvm/parser/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.parser"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("parser", __name__)
diff --git a/src/ir/span.cc b/src/ir/span.cc
index 565439f..64b42ab 100644
--- a/src/ir/span.cc
+++ b/src/ir/span.cc
@@ -64,8 +64,8 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode)
Span::Span(SourceName source, int lineno, int col_offset) {
auto n = make_object<SpanNode>();
n->source = std::move(source);
- n->lineno = lineno;
- n->col_offset = col_offset;
+ n->line = lineno;
+ n->column = col_offset;
data_ = std::move(n);
}
@@ -78,7 +78,6 @@ TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno,
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
- p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset
- << ")";
+ p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")";
});
} // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 9fcf510..e05cbbb 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -29,7 +29,7 @@
namespace tvm {
-// Define the dispatch functio here since primary user is in this file.
+// Define the dispatch function here since primary user is in this file.
bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
SEqualReducer equal) const {
uint32_t tindex = self->type_index();
diff --git a/src/parser/diagnostic.h b/src/parser/diagnostic.h
new file mode 100644
index 0000000..19f5d20
--- /dev/null
+++ b/src/parser/diagnostic.h
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file diagnostic.h
+ * \brief A new diagnostic interface for TVM error reporting.
+ *
+ * A prototype of the new diagnostic reporting interface for TVM.
+ *
+ * Eventually we hope to promote this file to the top-level and
+ * replace the existing errors.h.
+ */
+
+#ifndef TVM_PARSER_DIAGNOSTIC_H_
+#define TVM_PARSER_DIAGNOSTIC_H_
+
+#include <tvm/ir/span.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+#include <fstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace parser {
+
+/*! \brief A program source in any language.
+ *
+ * Could represent the source from an ML framework or the internal
+ * source of a TVM program.
+ */
+struct Source {
+  /*! \brief The raw source. */
+  std::string source;
+  /*! \brief One (start offset, length) pair per line of the raw source, in line order. */
+  std::vector<std::pair<int, int>> line_map;
+
+  /*! \brief An empty source. */
+  Source() : source(), line_map() {}
+
+  /*! \brief Construct a source from a string, recording where each line starts and its length. */
+  explicit Source(const std::string& source) : source(source) {
+    int index = 0;
+    int length = 0;
+    line_map.push_back({index, length});
+    for (auto c : source) {
+      if (c == '\n') {
+        // Record the length of the line.
+        line_map.back().second = length;
+        // Bump past the newline.
+        index += 1;
+        // Record the start of the next line, and put placeholder for length.
+        line_map.push_back({index, 0});
+        // Reset length to zero.
+        length = 0;
+      } else {
+        length += 1;
+        index += 1;
+      }
+    }
+    // The last line is not terminated by a newline inside the loop, so close it here.
+    line_map.back().second = length;
+  }
+
+  Source(const Source& source) : source(source.source), line_map(source.line_map) {}
+
+  /*! \brief Generate an error message at a specific line and column with the
+   * annotated message.
+   *
+   * The error is written directly to the `out` std::ostream: a header line, the
+   * offending source line, and a marker line with `^` under the column and `~`
+   * under the two characters on either side of it.
+   *
+   * \param out The output ostream.
+   * \param line The one-indexed line at which to report a diagnostic.
+   * \param column The one-indexed column at which to report a diagnostic.
+   * \param msg The message to attach.
+   */
+  void ReportAt(std::ostream& out, int line, int column, const std::string& msg) const {
+    // `line` is one-indexed, so the valid range is [1, line_map.size()].
+    CHECK(line >= 1 && line <= static_cast<int64_t>(line_map.size()))
+        << "requested line: " << line << " line_map size: " << line_map.size()
+        << " source: " << source;
+
+    // Adjust for zero indexing, now have (line_start, line_length);
+    auto range = line_map.at(line - 1);
+    int line_start = range.first;
+    int line_length = range.second;
+    out << "file:" << line << ":" << column << ": parse error: " << msg << std::endl;
+    out << " " << source.substr(line_start, line_length) << std::endl;
+    out << " ";
+    std::string marker;
+    for (int i = 1; i <= line_length; i++) {
+      // Use the absolute distance to the column; comparing the signed difference
+      // against 3 would match every position on one side of the column.
+      int distance = (i < column) ? (column - i) : (i - column);
+      if (distance == 0) {
+        marker += "^";
+      } else if (distance < 3) {
+        marker += "~";
+      } else {
+        marker += " ";
+      }
+    }
+    out << marker;
+    out << std::endl;
+  }
+};
+
+/*! \brief The diagnostic level, controls the printing of the message. */
+enum DiagnosticLevel {
+ Bug,
+ Error,
+ Warning,
+ Note,
+ Help,
+};
+
+/*! \brief A diagnostic message. */
+struct Diagnostic {
+ /*! \brief The level. */
+ DiagnosticLevel level;
+ /*! \brief The span at which to report an error. */
+ Span span;
+ /*! \brief The diagnostic message. */
+ std::string message;
+
+ Diagnostic(int line, int column, const std::string& message)
+ : level(DiagnosticLevel::Error), span(SourceName(), line, column), message(message) {}
+};
+
+/*! \brief A diagnostic context for recording errors against a source file.
+ * TODO(@jroesch): convert source map and improve in follow up PR, the parser
+ * assumes a single global file for now.
+ */
+struct DiagnosticContext {
+  /*! \brief The source to report against. */
+  Source source;
+
+  /*! \brief The set of diagnostics to report. */
+  std::vector<Diagnostic> diagnostics;
+
+  explicit DiagnosticContext(const Source& source) : source(source) {}
+
+  /*! \brief Emit a diagnostic. */
+  void Emit(const Diagnostic& diagnostic) { diagnostics.push_back(diagnostic); }
+
+  // TODO(@jroesch): eventually modularize the rendering interface to provide control of how to
+  // format errors.
+  /*! \brief Write every recorded diagnostic to `ostream`, then abort if there were any.
+   *
+   * \param ostream The stream the diagnostics are reported to.
+   */
+  void Render(std::ostream& ostream) {
+    // Iterate by reference: each Diagnostic owns a std::string and a Span.
+    for (const auto& diagnostic : diagnostics) {
+      source.ReportAt(ostream, diagnostic.span->line, diagnostic.span->column, diagnostic.message);
+    }
+
+    // Any recorded diagnostic is treated as fatal once it has been printed.
+    if (!diagnostics.empty()) {
+      LOG(FATAL) << "parse error occured";
+    }
+  }
+};
+
+} // namespace parser
+} // namespace tvm
+#endif // TVM_PARSER_DIAGNOSTIC_H_
diff --git a/src/parser/op_table.h b/src/parser/op_table.h
new file mode 100644
index 0000000..5af10a0
--- /dev/null
+++ b/src/parser/op_table.h
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file op_table.h
+ * \brief An operator table for parsing.
+ *
+ * Provides symbolic token sequences to map to TVM operators, with a given associativity and arity.
+ */
+
+#ifndef TVM_PARSER_OP_TABLE_H_
+#define TVM_PARSER_OP_TABLE_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+#include <fstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "./tokenizer.h"
+
+namespace tvm {
+namespace parser {
+
+/*! \brief One operator-table entry: a token sequence together with the operator it maps to. */
+struct Rule {
+  /*! \brief The token sequence that spells the operator. */
+  std::vector<TokenType> tokens;
+  /*! \brief The precedence of the operator. */
+  int precedence;
+  /*! \brief The number of operands. */
+  int arity;
+  /*! \brief The TVM operator the tokens map to. */
+  tvm::Op op;
+  /*! \brief Whether the operator is left associative. */
+  bool left_assoc;
+
+  Rule() : tokens(), precedence(0), arity(0), op(tvm::Op()), left_assoc(false) {}
+
+  Rule(std::vector<TokenType> tokens, tvm::Op op, int precedence, int arity = 2,
+       bool left_assoc = false)
+      : tokens(tokens), precedence(precedence), arity(arity), op(op), left_assoc(left_assoc) {}
+
+  Rule(const Rule& rule)
+      : tokens(rule.tokens),
+        precedence(rule.precedence),
+        arity(rule.arity),
+        op(rule.op),
+        left_assoc(rule.left_assoc) {}
+};
+
+/*! \brief The set of operator rules, also indexed by the concatenated text of each rule's tokens. */
+struct OperatorTable {
+  /*! \brief The rules in the order they were supplied. */
+  std::vector<Rule> rules;
+  /*! \brief Lookup from the concatenation of ToString(token) over a rule's tokens to that rule. */
+  std::unordered_map<std::string, Rule> this_is_a_hack;
+
+  explicit OperatorTable(std::vector<Rule> rules) : rules(rules), this_is_a_hack() {
+    // Iterate by reference: copying a Rule copies its token vector on every iteration.
+    for (const auto& rule : this->rules) {
+      std::stringstream key;
+      for (const auto& token : rule.tokens) {
+        key << ToString(token);
+      }
+      // insert() keeps the first rule registered for a given key.
+      this->this_is_a_hack.insert({key.str(), rule});
+    }
+  }
+};
+
+/*! \brief Build the default operator table for the arithmetic and comparison operators.
+ *
+ * Declared `inline` because the definition lives in a header: without it, every
+ * translation unit that includes this file would emit its own definition and
+ * violate the one-definition rule at link time.
+ *
+ * \return The table of default rules.
+ */
+inline OperatorTable DefaultOpTable() {
+  return OperatorTable(
+      {Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true),
+       Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true),
+       Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true),
+       Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true),
+       Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true),
+       Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true),
+       Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true),
+       Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true),
+       Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true),
+       Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)});
+}
+
+} // namespace parser
+} // namespace tvm
+#endif // TVM_PARSER_OP_TABLE_H_
diff --git a/src/parser/parser.cc b/src/parser/parser.cc
new file mode 100644
index 0000000..0aaa698
--- /dev/null
+++ b/src/parser/parser.cc
@@ -0,0 +1,1408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file parser.cc
+ * \brief A parser for TVM IR.
+ */
+#include <tvm/ir/module.h>
+#include <tvm/node/reflection.h>
+#include <tvm/relay/adt.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+
+#include "./diagnostic.h"
+#include "./op_table.h"
+#include "./tokenizer.h"
+
+namespace tvm {
+namespace parser {
+
+using namespace relay;
+using Expr = relay::Expr;
+
+/*! \brief A wrapper structure for capturing the result of parsing
+ * a global definition *before* we add it to the IRModule.
+ *
+ * This enables the parser to parse everything in one pass before
+ * constructing the IRModule.
+ */
+struct GlobalFunc {
+ GlobalVar global;
+ Function function;
+ GlobalFunc() : global(), function() {}
+ GlobalFunc(GlobalVar global, Function function) : global(global), function(function) {}
+ GlobalFunc(const GlobalFunc& gfunc) {
+ this->global = gfunc.global;
+ this->function = gfunc.function;
+ }
+};
+
+/*! \brief A wrapper structure for capturing all top-level definitions
+ * when parsing a module.
+ */
+struct Definitions {
+ /*! \brief The set of global functions. */
+ std::vector<GlobalFunc> funcs;
+ /*! \brief The set of type definitions. */
+ std::vector<TypeData> types;
+ // TODO(@jroesch): contain meta-table below
+};
+
+/*! \brief A structure representing the semantic versioning information
+ * for a Relay program.
+ */
+class SemVer {
+ public:
+ int major_version;
+ int minor_version;
+ int patch_version;
+
+ SemVer() : major_version(0), minor_version(0), patch_version(0) {}
+ SemVer(int major_version, int minor_version, int patch_version)
+ : major_version(major_version), minor_version(minor_version), patch_version(patch_version) {}
+ SemVer(const SemVer& other)
+ : major_version(other.major_version),
+ minor_version(other.minor_version),
+ patch_version(other.patch_version) {}
+};
+
+/*! \brief A reference to a "meta-expression".
+ *
+ * In the text format we allow referencing metadata which
+ * uses a compact serialization that follows the main
+ * program body.
+ *
+ * We can reference this table using an expression of
+ * the form `meta[Type][index]`.
+ *
+ * We must later resolve these references to actual in-memory
+ * AST nodes but this requires first parsing the full program
+ * then expanding these temporary AST nodes into their corresponding
+ * nodes.
+ *
+ * For example the nth large constant will be pretty-printed as meta[relay.Constant][n]
+ * with its compact binary serialization residing in the metadata section at the end
+ * of the program.
+ */
+class MetaRefExprNode : public TempExprNode {
+ public:
+ /*! \brief The type key of the meta expression. */
+ std::string type_key;
+ /*! \brief The index into the type key's table. */
+ uint64_t node_index;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {}
+
+ // TODO(@jroesch): we probably will need to manually
+ // expand these with a pass.
+ Expr Realize() const final { return Expr(); }
+
+ static constexpr const char* _type_key = "relay.MetaRefExpr";
+ TVM_DECLARE_FINAL_OBJECT_INFO(MetaRefExprNode, TempExprNode);
+};
+
+/*! \brief Reference wrapper for MetaRefExprNode. */
+class MetaRefExpr : public TempExpr {
+ public:
+ /*!
+ * \brief The constructor for MetaRefExpr
+ * \param type_key The type key of the object in the meta section.
+ * \param node_index The index into that type key's table.
+ */
+ TVM_DLL MetaRefExpr(std::string type_key, uint64_t node_index);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(MetaRefExpr, TempExpr, MetaRefExprNode);
+};
+
+/*! \brief Allocate a MetaRefExprNode holding the given type key and index. */
+MetaRefExpr::MetaRefExpr(std::string type_key, uint64_t node_index) {
+  auto rnode = make_object<MetaRefExprNode>();
+  // `type_key` is already a by-value copy owned by this call, so move it into the node.
+  rnode->type_key = std::move(type_key);
+  rnode->node_index = node_index;
+  data_ = std::move(rnode);
+}
+
+/*! \brief A simple wrapper around a mapping from raw string names
+ * to a TVM variable, type variable or other binder type.
+ */
+template <typename T>
+struct Scope {
+ /*! \brief The internal map. */
+ std::unordered_map<std::string, T> name_map;
+};
+
+/*! \brief A stack of scopes.
+ *
+ * In order to properly handle scoping we must maintain a stack of scopes.
+ *
+ * A stack allows users to write programs which contain repeated variable
+ * names and to properly handle both nested scopes and removal of variables
+ * when they go out of scope.
+ *
+ * This is the classic approach to lexical scoping.
+ */
+template <typename T>
+class ScopeStack {
+ private:
+  /*! \brief The scopes, innermost (most recently pushed) last. */
+  std::vector<Scope<T>> scope_stack;
+
+ public:
+  /*! \brief Adds a variable binding to the current scope.
+   *
+   * It is a fatal error to call this before any scope has been pushed.
+   */
+  void Add(const std::string& name, const T& value) {
+    if (this->scope_stack.empty()) {
+      LOG(FATAL) << "internal issue";
+    }
+    // insert() does not overwrite: if `name` is already bound in the current scope
+    // the existing binding is kept.
+    this->scope_stack.back().name_map.insert({name, value});
+  }
+
+  /*! \brief Looks up a variable name in the scope stack returning the matching variable
+   * in most recent scope, or a default constructed T when the name is not bound. */
+  T Lookup(const std::string& name) {
+    // Walk from the innermost scope outwards so inner bindings shadow outer ones.
+    for (auto scope = this->scope_stack.rbegin(); scope != this->scope_stack.rend(); ++scope) {
+      auto it = scope->name_map.find(name);
+      if (it != scope->name_map.end()) {
+        return it->second;
+      }
+    }
+    return T();
+  }
+
+  /*! \brief Adds a fresh scope. */
+  void PushStack() { this->scope_stack.push_back(Scope<T>()); }
+
+  /*! \brief Removes the most recent scope. */
+  void PopStack() {
+    // pop_back() on an empty vector is undefined behaviour, so fail loudly instead.
+    CHECK(!this->scope_stack.empty()) << "internal issue: popped a scope that was never pushed";
+    this->scope_stack.pop_back();
+  }
+};
+
+/*! \brief A table of interning strings as global function and type names. */
+template <typename T>
+struct InternTable {
+ /*! \brief The internal table mapping strings to a unique allocation. */
+ std::unordered_map<std::string, T> table;
+
+ /*! \brief Add the unique allocation. */
+ void Add(const std::string& name, const T& t) {
+ auto it = table.find(name);
+ if (it != table.end()) {
+ LOG(FATAL) << "duplicate name";
+ } else {
+ table.insert({name, t});
+ }
+ }
+
+ /*! \brief Return the unique allocation. */
+ Optional<T> Get(const std::string& name) {
+ auto it = table.find(name);
+ if (it != table.end()) {
+ return Optional<T>(it->second);
+ } else {
+ return Optional<T>();
+ }
+ }
+};
+
+/*! \brief The parser class is the main interface to the parser.
+ * The parser is not currently exposed beyond this .cc file.
+ *
+ * The parser is initialized with a diagnostic context, an
+ * operator table, and a token stream.
+ *
+ * The rest of the internal state is used to map the human readable
+ * form to in-memory IR representation.
+ *
+ * The main entry points to the parser are a set of parsing methods
+ * such as `ParseModule` and `ParseExpr`.
+ *
+ * As with traditional recursive descent parsers the parsing methods
+ * are factored recursively just as one would do with a formal language
+ * grammar.
+ *
+ * You can view a recursive descent parser as a human friendly way to specify
+ * a state machine, and thus this factoring is necessary as the 'state' of this
+ * machine is the combination of the current parsing method and the next token.
+ *
+ * Parsing proceeds by matching a token and then dispatching to the appropriate
+ * method to parse the next tokens in the stream.
+ *
+ * For example if we are parsing a type and encounter a "Tensor" token we switch
+ * into a mode for parsing `[`, a shape, a comma, a data type and then a `]`.
+ *
+ * Certain matches like this are unambiguous and proceed in a straight line fashion
+ * once the initial token is found. Other parsing is more complex and requires some
+ * tricks to correctly parse.
+ *
+ * For example when we find a '(' in an expression context, it may be part of
+ * a tuple, the arguments to a call, or a parenthesized expression. The code below
+ * disambiguates these cases by factoring expression parsing into a series of methods
+ * which encode the parsing context and thus how to interpret the parenthesis.
+ *
+ * For more information one should be able to read the code in order starting with
+ * `ParseModule` or `ParseExpr`.
+ */
+class Parser {
+ public:
+ /*! \brief The version that the parser is parsing. */
+ SemVer version;
+
+ /*! \brief The diagnostic context used for error reporting. */
+ DiagnosticContext diag_ctx;
+
+ /*! \brief The current position in the token stream. */
+ int pos;
+
+ /*! \brief The token stream for the parser. */
+ std::vector<Token> tokens;
+
+ /*! \brief The configured operator table. */
+ OperatorTable op_table;
+
+ /*! \brief Configure the whitespace mode, right now we ignore all whitespace. */
+ bool ignore_whitespace;
+
+ /*! \brief A global mapping for GlobalVar. */
+ InternTable<GlobalVar> global_names;
+
+ /*! \brief A global mapping for type definitions. */
+ InternTable<GlobalTypeVar> type_names;
+
+ /*! \brief A global mapping for constructor names. */
+ InternTable<Constructor> ctors;
+
+ /*! \brief A mapping from graph variable to expression, i.e., `%0 = expr`. */
+ std::unordered_map<int, Expr> graph_ctx;
+
+ /*! \brief The set of type scopes used for generics. */
+ ScopeStack<TypeVar> type_scopes;
+
+ /*! \brief The set of expression scopes used for lexical scope. */
+ ScopeStack<Var> expr_scopes;
+
+ Parser(std::vector<Token> tokens, OperatorTable op_table, Source source)
+ : diag_ctx(source), pos(0), tokens(tokens), op_table(op_table), ignore_whitespace(true) {}
+
+ /*! \brief Examine the next token in the stream, the current parser is configured to be
+ * whitespace insensitive so we will skip all whitespace or comment tokens. */
+ Token Peek() {
+ // For now we ignore all whitespace tokens and comments.
+ // We can tweak this behavior later to enable white space sensitivity in the parser.
+ while (pos < static_cast<int64_t>(tokens.size()) && ignore_whitespace &&
+ (tokens.at(pos)->token_type == TokenType::Whitespace ||
+ tokens.at(pos)->token_type == TokenType::Newline ||
+ tokens.at(pos)->token_type == TokenType::LineComment ||
+ tokens.at(pos)->token_type == TokenType::Comment)) {
+ pos++;
+ }
+
+ if (pos < static_cast<int64_t>(tokens.size())) {
+ return Token(this->tokens.at(pos));
+ } else {
+ return Token::Null();
+ }
+ }
+
+ /*! \brief Lookahead by N tokens.
+ * \param n The number of tokens to lookahead.
+ * \return The Nth token.
+ */
+ Token Lookahead(int n) {
+ CHECK_GE(n, 1) << "lookahead is only valid when n >= 1";
+
+ // We intend to skip n - 1 tokens, then return the nth.
+ auto old_pos = pos;
+ for (int i = 0; i < n - 1; i++) {
+ Peek();
+ pos++;
+ }
+
+ auto tok = Peek();
+ pos = old_pos;
+ return tok;
+ }
+
+ /*! \brief Consume a token, this method is the lowest level way to consume a token
+  * and will not ignore white space or look ahead in any way.
+  *
+  * If the stream is exhausted, or the token at the current position does not have the
+  * requested type, a diagnostic is emitted and rendered to std::cout.
+  *
+  * \param token_type The token type to match.
+  */
+ void Consume(const TokenType& token_type) {
+   const auto num_tokens = static_cast<int64_t>(tokens.size());
+   // Check the position before indexing: Peek() can leave `pos` at the end of the
+   // stream (e.g. when Match is called on exhausted input).
+   if (pos >= num_tokens || tokens[pos]->token_type != token_type) {
+     // Peek skips whitespace and comments (possibly advancing `pos`) and yields
+     // Token::Null() when nothing is left, so it is safe to call at the end of the stream.
+     Token found = Peek();
+     std::string message =
+         "expected a " + Pretty(token_type) + " found " + Pretty(found->token_type);
+     // Report at the offending token. When the stream is exhausted fall back to the
+     // last token, or to the first position when there are no tokens at all.
+     int line = 1;
+     int column = 1;
+     if (pos < num_tokens) {
+       line = tokens[pos]->line;
+       column = tokens[pos]->column;
+     } else if (num_tokens > 0) {
+       line = tokens.back()->line;
+       column = tokens.back()->column;
+     }
+     this->diag_ctx.Emit({line, column, message});
+     this->diag_ctx.Render(std::cout);
+   }
+   pos++;
+ }
+
+ /*! Match a token in the stream, this will first invoke Peek, ignoring tokens such
+ * as whitespace or comments returning the first meaningful token.
+ *
+ * We then try and consume the requested token, this will trigger an error if the
+ * current token does not match the token_type.
+ */
+ Token Match(const TokenType& token_type) {
+ auto tok = Peek();
+ Consume(token_type);
+ return tok;
+ }
+
+ /*! Conditionally consume a token when it matches, this will never trigger an error
+ * as we guard against consuming the token before we do.
+ *
+ * Useful for matching optional tokens, effectively looks ahead by one token.
+ */
+ bool WhenMatch(const TokenType& token_type) {
+ if (Peek()->token_type == token_type) {
+ Consume(token_type);
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ /* \brief Add a graph binding to the parsing context
+ *
+ * For example if we parse %0 = add(...), map 0 -> add(...), etc.
+ */
+ void AddGraphBinding(const Token& token, const Expr& expr) {
+ auto graph_no = token.ToNumber();
+ this->graph_ctx.insert({graph_no, expr});
+ }
+
+ /* \brief Lookup a previously bound graph variable.
+ *
+ * Note: we take tokens in all lookup methods so that
+ * we can do error reporting based on token location.
+ */
+ Expr LookupGraphBinding(const Token& token) {
+ auto graph_no = token.ToNumber();
+ return this->graph_ctx.at(graph_no);
+ }
+
+ /*! \brief Bind a local variable in the expression scope.
+ *
+ * "x" -> Var("x"), these are needed to map from the raw string names
+ * to unique variable nodes.
+ */
+ Var BindVar(const std::string& name, const relay::Type& type_annotation) {
+ auto var = Var(name, type_annotation);
+ this->expr_scopes.Add(name, var);
+ return var;
+ }
+
+ /*! \brief Bind a type variable in the type scope.
+ *
+ * "A" -> TypeVar("A", ...), these are needed to map from raw string names
+ * to unique type variable nodes.
+ */
+ TypeVar BindTypeVar(const std::string& name, const TypeKind type_kind) {
+ auto type_var = TypeVar(name, type_kind);
+ this->type_scopes.Add(name, type_var);
+ return type_var;
+ }
+
+ /*! \brief Lookup a variable in the expression scope.
+ *
+ * Note: all lookup methods take tokens intentionally for error reporting information.
+ */
+ Var LookupLocal(const Token& local) {
+ auto var = this->expr_scopes.Lookup(local.ToString());
+ if (!var.defined()) {
+ diag_ctx.Emit(
+ {local->line, local->column, "this local variable has not been previously declared"});
+ }
+ return var;
+ }
+
+ /*! \brief Lookup a variable in the type scope.
+ *
+ * Note: all lookup methods take tokens intentionally for error reporting information.
+ */
+ TypeVar LookupTypeVar(const Token& ident) {
+ auto var = this->type_scopes.Lookup(ident.ToString());
+ if (!var.defined()) {
+ diag_ctx.Emit(
+ {ident->line, ident->column,
+ "this type variable has not been previously declared anywhere, perhaps a typo?"});
+ }
+ return var;
+ }
+
+ /*! \brief Add an expression scope to the scope stack. */
+ void PushScope() { this->expr_scopes.PushStack(); }
+
+ /*! \brief Remove N expression scopes from the scope stack. */
+ void PopScopes(int n) {
+ for (int i = 0; i < n; i++) {
+ this->expr_scopes.PopStack();
+ }
+ }
+
+ /*! \brief Add a type scope to the scope stack. */
+ void PushTypeScope() {
+   // Delegates to the type scope stack.
+   this->type_scopes.PushStack();
+ }
+
+ /*! \brief Remove N type scopes from the scope stack. */
+ void PopTypeScopes(int n) {
+   // `n` is a by-value copy, so it can be counted down directly; a
+   // non-positive `n` pops nothing.
+   while (n-- > 0) {
+     this->type_scopes.PopStack();
+   }
+ }
+
+ /*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
+ NDArray NumberToNDArray(const Token& token) {
+   // Both numeric kinds are materialized as a rank-0 array on CPU device 0.
+   DLContext cpu = {DLDeviceType::kDLCPU, 0};
+   switch (token->token_type) {
+     case TokenType::Integer: {
+       auto scalar = NDArray::Empty({}, String2DLDataType("int32"), cpu);
+       // revisit this, literal node issue.
+       int64_t wide = Downcast<tvm::Integer>(token->data);
+       // The 64-bit literal is narrowed to int32 without a range check.
+       int32_t* slot = reinterpret_cast<int32_t*>(scalar->data);
+       slot[0] = static_cast<int32_t>(wide);
+       return scalar;
+     }
+     case TokenType::Float: {
+       auto scalar = NDArray::Empty({}, String2DLDataType("float32"), cpu);
+       // revisit this, literal node issue.
+       // TODO(@jroesch): bounds checking
+       float narrow = Downcast<tvm::FloatImm>(token->data)->value;
+       float* slot = reinterpret_cast<float*>(scalar->data);
+       slot[0] = narrow;
+       return scalar;
+     }
+     default: {
+       LOG(FATAL) << "internal error: should only call this function on numeric tokens";
+       return NDArray();
+     }
+   }
+ }
+
+ /*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */
+ NDArray BooleanToNDarray(bool value) {
+   // A rank-0 "bool" array on CPU device 0 holding `value`.
+   DLContext cpu = {DLDeviceType::kDLCPU, 0};
+   auto scalar = NDArray::Empty({}, String2DLDataType("bool"), cpu);
+   bool* slot = reinterpret_cast<bool*>(scalar->data);
+   *slot = value;
+   return scalar;
+ }
+
+ [[noreturn]] void ParseError(const Token& token, const std::string& msg) {
+   // `token` is currently unused; only `msg` reaches the exception.
+   std::runtime_error failure(msg);
+   throw failure;
+ }
+
+ /*! \brief A parsing helper for a bracketed expression <start> <parser> <stop>. */
+ template <typename R>
+ R Bracket(TokenType open, TokenType close, std::function<R()> parser) {
+   // The opening token must be matched before the inner parser runs, and the
+   // closing token only after it has produced its value.
+   Match(open);
+   R inner = parser();
+   Match(close);
+   return inner;
+ }
+
+ /*! \brief Parse `(` parser() `)`. */
+ template <typename R>
+ R Parens(std::function<R()> parser) {
+   // Bracket specialised to the `(` ... `)` delimiter pair.
+   return Bracket<R>(TokenType::OpenParen, TokenType::CloseParen, parser);
+ }
+
+ /*! \brief Parse `{` parser() `}`. */
+ template <typename R>
+ R Block(std::function<R()> parser) {
+   // Bracket specialised to the `{` ... `}` delimiter pair.
+   return Bracket<R>(TokenType::LCurly, TokenType::RCurly, parser);
+ }
+
+ /*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
+ * ending with a stop token.
+ *
+ * The simple form being <start> (<parse()> <separator>)* <stop>.
+ *
+ * This also provides a fourth argument which is allowed to run when the sequence which matches
+ * the inner sequence can not proceed.
+ *
+ * This is useful for parsing things like attributes which don't match the standard expression
+ * parsers but are contained within the stop token.
+ */
+ template <typename T>
+ Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
+                        std::function<void()> before_stop = nullptr) {
+   Match(start);
+
+   // Empty sequence: <start> <stop>.
+   if (WhenMatch(stop)) {
+     return Array<T>();
+   }
+
+   Array<T> elements = {parse()};
+
+   // If only the stop token remains, give the leftover parser a chance to run.
+   if (Peek()->token_type == stop && before_stop) {
+     before_stop();
+   }
+
+   // Single element: <start> <elem> <stop>.
+   if (WhenMatch(stop)) {
+     return elements;
+   }
+
+   // Anything other than a separator at this point is malformed input; say
+   // what was expected and where, instead of a bare "issue".
+   if (!WhenMatch(sep)) {
+     auto found = Peek();
+     LOG(FATAL) << "parse error at line " << found->line << ", column " << found->column
+                << ": expected " << Pretty(sep) << " or " << Pretty(stop)
+                << " while parsing a sequence, but found " << Pretty(found->token_type);
+     return Array<T>(nullptr);
+   }
+
+   // If only the stop token remains after the first separator, invoke the leftover parser.
+   if (Peek()->token_type == stop && before_stop) {
+     before_stop();
+   }
+
+   // Remaining elements. The separator after each element is optional here,
+   // which also admits a trailing separator before the stop token.
+   while (!WhenMatch(stop)) {
+     elements.push_back(parse());
+     WhenMatch(sep);
+   }
+   return elements;
+ }
+
+ /*! \brief Parse a full IRModule. */
+ IRModule ParseModule() {
+   // Parse the semver header at the top of the module.
+   this->version = ParseSemVer();
+   // Parse the definitions.
+   auto defs = ParseDefinitions();
+   // Parse the metadata section at the end; its result is not used yet.
+   ParseMetadata();
+   Match(TokenType::EndOfFile);
+
+   // Index the parsed type definitions by their global type variable.
+   Map<tvm::GlobalTypeVar, TypeData> types;
+   for (auto type_def : defs.types) {
+     types.Set(type_def->header, type_def);
+   }
+
+   // Start from a module containing only the types, then add each function.
+   auto mod = IRModule({}, types);
+   for (auto func : defs.funcs) {
+     mod->Add(func.global, func.function);
+   }
+
+   return mod;
+ }
+
+ /*! \brief Parse the semantic versioning header. */
+ SemVer ParseSemVer() {
+   // TODO(@jroesch): convert semver to module level attribute.
+   auto id = Peek();
+   // An optional header lexed as Identifier("v0") Period Float is consumed and
+   // otherwise ignored.
+   if (id->token_type == TokenType::Identifier && id.ToString() == "v0") {
+     Match(TokenType::Identifier);
+     Consume(TokenType::Period);
+     Consume(TokenType::Float);
+   }
+   // TODO(@jroesch): the current lexing makes it hard to parse this
+   // in a way that doesnt feel like a hack.
+   //
+   // We should move to module level attributes instead
+   // so we can tag modules with top-level data.
+   //
+   // #[text_version = "0.0.4"]
+   //
+   // For now we only support current version.
+   return SemVer(0, 0, 4);
+ }
+
+ /*! \brief Parse zero or more Relay definitions. */
+ Definitions ParseDefinitions() {
+   Definitions defs;
+
+   // Keep consuming definitions until the next token does not start one.
+   while (true) {
+     auto next = Peek();
+     switch (next->token_type) {
+       case TokenType::Defn: {
+         Consume(TokenType::Defn);
+         auto global_name = Match(TokenType::Global).ToString();
+         auto global = GlobalVar(global_name);
+         // Register the name before parsing the body.
+         global_names.Add(global_name, global);
+         auto func = ParseFunctionDef();
+         defs.funcs.push_back(GlobalFunc(global, func));
+         continue;
+       }
+       case TokenType::TypeDef: {
+         defs.types.push_back(ParseTypeDef());
+         continue;
+       }
+       case TokenType::Extern: {
+         Consume(TokenType::Extern);
+         auto type_def = ParseTypeDef();
+         if (type_def->constructors.size()) {
+           diag_ctx.Emit(
+               {next->line, next->column, "an external type may not have any constructors"});
+         }
+         defs.types.push_back(type_def);
+         // Keep looping: without this the case fell through to `default` and
+         // every definition after the first extern type was left unparsed.
+         continue;
+       }
+       default:
+         return defs;
+     }
+   }
+ }
+
+ /*! \brief Parse a single Relay type definition: `type Name`, optional `[generics]`, optional `{ constructors }`. */
+ TypeData ParseTypeDef() {
+   // `type` keyword, then the name of the ADT, which is registered right away.
+   Match(TokenType::TypeDef);
+   auto adt_name = Match(TokenType::Identifier).ToString();
+   auto adt_handle = tvm::GlobalTypeVar(adt_name, TypeKind::kAdtHandle);
+   type_names.Add(adt_name, adt_handle);
+
+   // Optional generic parameter list; it gets its own type scope.
+   Array<TypeVar> generics;
+   const bool has_generics = Peek()->token_type == TokenType::LSquare;
+   if (has_generics) {
+     PushTypeScope();
+     generics =
+         ParseSequence<TypeVar>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() {
+           auto param_name = Match(TokenType::Identifier).ToString();
+           return BindTypeVar(param_name, TypeKind::kType);
+         });
+   }
+
+   // Optional constructor list.
+   Array<tvm::Constructor> constructors;
+   if (Peek()->token_type == TokenType::LCurly) {
+     constructors = ParseSequence<tvm::Constructor>(
+         TokenType::LCurly, TokenType::Comma, TokenType::RCurly, [&]() {
+           auto ctor_name = Match(TokenType::Identifier).ToString();
+
+           // The field list is optional; without it the constructor has no inputs.
+           Array<Type> field_types;
+           if (Peek()->token_type == TokenType::OpenParen) {
+             field_types =
+                 ParseSequence<Type>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen,
+                                     [&]() { return ParseType(); });
+           }
+           Constructor ctor = tvm::Constructor(ctor_name, field_types, adt_handle);
+
+           CHECK(ctor.defined());
+
+           // Record the constructor by name in the parser-wide table.
+           this->ctors.Add(ctor_name, ctor);
+           return ctor;
+         });
+   }
+
+   // Leave the generics' type scope, if one was entered.
+   if (has_generics) {
+     PopTypeScopes(1);
+   }
+
+   return TypeData(adt_handle, generics, constructors);
+ }
+
+ std::string HackTokensAsString(int n) {
+   // Concatenate the kind names of up to `n` tokens starting at `pos`,
+   // clamped to the tokens that remain.
+   const int remaining = static_cast<int>(tokens.size() - pos);
+   const int count = std::min(remaining, n);
+   std::stringstream key;
+   for (int offset = 0; offset < count; ++offset) {
+     key << ToString(tokens.at(pos + offset)->token_type);
+   }
+   return key.str();
+ }
+
+ /*! \brief Try to match an operator at the current position.
+  *
+  * Operators are looked up by the concatenated kind names of the next 4, 3, 2
+  * and then 1 tokens, so the longest spelling wins.
+  *
+  * \return A vector holding the matched rule, or an empty vector if the
+  * upcoming tokens do not spell an operator.
+  */
+ std::vector<Rule> ParseOp() {
+   std::vector<Rule> matched;
+   Peek();
+   for (int i = 4; i > 0; i--) {
+     auto key = HackTokensAsString(i);
+     auto it = this->op_table.this_is_a_hack.find(key);
+     if (it != this->op_table.this_is_a_hack.end()) {
+       pos = pos + i;
+       matched.push_back(it->second);
+       // Stop at the first (longest) match. Continuing would retry shorter keys
+       // at the already-advanced position and could consume tokens belonging to
+       // the right-hand operand.
+       break;
+     }
+   }
+
+   return matched;
+ }
+
+ /*! \brief Parse a single Relay expression. */
+ Expr ParseExpr() {
+ // Parses one or more expressions separated by `;` and folds them into a
+ // single expression. Runs inside ConsumeWhitespace.
+ return ConsumeWhitespace<Expr>([this] {
+ std::vector<Expr> exprs;
+
+ while (true) {
+ auto next = Peek();
+ switch (next->token_type) {
+ // For graph or let, match first rhs, then invoke ParseBindingExpr
+ // ParseBindingExpression then parse_lhs() parse_rhs() ';' continue
+ case TokenType::LCurly: {
+ // NB: Might need to optimize to remove deep recursion.
+ // Stack should only grow proportionally to the number of
+ // nested scopes.
+ // Note: this case returns directly, so expressions already collected
+ // in `exprs` are not part of the returned value.
+ return Bracket<Expr>(TokenType::LCurly, TokenType::RCurly, [&]() {
+ PushScope();
+ auto expr = ParseExpr();
+ PopScopes(1);
+ return expr;
+ });
+ }
+ case TokenType::Let:
+ exprs.push_back(ParseBindingExpr());
+ break;
+ case TokenType::Match:
+ case TokenType::PartialMatch: {
+ // `Match` is the total form; `PartialMatch` is the non-total form.
+ bool is_total = next->token_type == TokenType::Match;
+ Consume(next->token_type);
+ exprs.push_back(ParseMatch(is_total));
+ break;
+ }
+ case TokenType::If: {
+ exprs.push_back(ParseIf());
+ break;
+ }
+ case TokenType::Graph:
+ // A graph token followed by `=` starts a binding; otherwise it is a use
+ // of a graph variable and is parsed as an ordinary expression below.
+ if (Lookahead(2)->token_type == TokenType::Equal) {
+ exprs.push_back(ParseBindingExpr());
+ break;
+ }
+ // intentional fall through here.
+ default: {
+ exprs.push_back(ParseExprBinOp());
+ break;
+ }
+ }
+
+ // A `;` means another expression follows; anything else ends the sequence.
+ if (!WhenMatch(TokenType::Semicolon)) {
+ break;
+ }
+ }
+
+ CHECK_GE(exprs.size(), 1);
+
+ if (exprs.size() == 1) {
+ return exprs[0];
+ } else {
+ // Fold right-to-left: the last expression is the body, and each earlier
+ // one becomes the value of a let that binds a variable with an empty name.
+ auto body = exprs.back();
+ exprs.pop_back();
+ while (exprs.size()) {
+ auto value = exprs.back();
+ exprs.pop_back();
+ body = relay::Let(Var("", IncompleteType()), value, body);
+ }
+ return body;
+ }
+ });
+ }
+
+ /*! \brief Parse a "binding expression"; an expression where
+ * a graph or let variable is bound.
+ *
+ * In order to avoid stack overflow this is implemented in a special
+ * iterative way to keep stack depth constant in a long chain of bindings.
+ */
+ Expr ParseBindingExpr() {
+ // We use a loop here so that the stack depth
+ // does not grow linearly with a sequence of
+ // graph or let bindings.
+ //
+ // Assuming we start at call depth k, we will
+ // enter k + c call frames to parse the RHS
+ // of the bindings where `c` is the depth
+ // of recursion needed by RHS.
+ //
+ // If RHS is a call expression then c=1.
+ //
+ // Once we have parsed the RHS we will be
+ // back at depth K, and will return to
+ // this loop header to parse another
+ // graph or let binding.
+ //
+ // This ensures for n sequential bindings
+ // the call depth will be the same before
+ // and after parsing the n bindings.
+ std::vector<std::pair<Var, Expr>> bindings;
+ // Number of expression scopes pushed by let bindings in this call.
+ int scopes = 0;
+
+ while (true) {
+ auto next = Peek();
+ if (next->token_type == TokenType::Graph && Lookahead(2)->token_type == TokenType::Equal) {
+ // Graph binding: `<graph-var> = <expr>;`. It is recorded through
+ // AddGraphBinding and does not add an entry to `bindings`.
+ Match(TokenType::Graph);
+ Match(TokenType::Equal);
+ auto val = this->ParseExprBinOp();
+ Match(TokenType::Semicolon);
+ AddGraphBinding(next, val);
+ } else if (next->token_type == TokenType::Let) {
+ // Parse the 'let'.
+ Consume(TokenType::Let);
+
+ // Parse the local '%<id>'.
+ auto local_tok = Match(TokenType::Local);
+ auto string = local_tok.ToString();
+
+ // Parse the optional type annotation (':' <type>).
+ Type type;
+ if (WhenMatch(TokenType::Colon)) {
+ type = ParseType();
+ }
+
+ // The variable is bound before its value is parsed.
+ auto var = BindVar(string, type);
+
+ // Parse the '=';
+ Match(TokenType::Equal);
+
+ // Parse the body, and the ';'.
+ auto val = this->ParseExprBinOp();
+ Consume(TokenType::Semicolon);
+
+ // Add the bindings to the local data structure.
+ bindings.push_back({var, val});
+ scopes++;
+ PushScope();
+ } else {
+ // This is the only case we will increase the stack
+ // depth.
+ //
+ // If we parse a program which is a sequence of N bindings
+ // followed by a single body expression we will end up with
+ // a call depth of 3, the first call to ParseExpr, then
+ // ParseBindingExpr, then finally ParseExpr once more.
+
+ auto body = this->ParseExpr();
+
+ // Remove the same number of scopes we added.
+ PopScopes(scopes);
+
+ if (bindings.size() == 0) {
+ return body;
+ } else {
+ // We can now build the let binding up backwards.
+ for (auto binding = bindings.rbegin(); binding != bindings.rend(); binding++) {
+ body = relay::Let(binding->first, binding->second, body);
+ }
+ return body;
+ }
+ }
+ }
+ }
+
+ /*! \brief Parse a function definition without a leading keyword or identifier.
+ *
+ * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN: UN) -> Ret { body }.
+ */
+ Function ParseFunctionDef() {
+   // One expression scope and one type scope cover the whole definition
+   // (generics, parameters, return type and body).
+   PushScope();
+   PushTypeScope();
+
+   Array<TypeVar> generics;
+   if (Peek()->token_type == TokenType::LSquare) {
+     // The generics are bound in the type scope pushed above. Pushing a second
+     // scope here, as the code used to, left one scope behind on every generic
+     // function because only a single scope is popped at the end.
+     generics =
+         ParseSequence<TypeVar>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare, [&]() {
+           auto type_var_name = Match(TokenType::Identifier).ToString();
+           return BindTypeVar(type_var_name, TypeKind::kType);
+         });
+   }
+
+   // Parameters: `%name` with an optional `: type` annotation.
+   auto params =
+       ParseSequence<Var>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen, [&]() {
+         auto token = Match(TokenType::Local);
+         auto string = token.ToString();
+         Type type;
+         if (WhenMatch(TokenType::Colon)) {
+           type = ParseType();
+         }
+         return BindVar(string, type);
+       });
+
+   // Optional return type, introduced by `->` (lexed as Minus then RAngle).
+   Type ret_type;
+   if (WhenMatch(TokenType::Minus)) {
+     Match(TokenType::RAngle);
+     ret_type = ParseType();
+   }
+
+   auto body = Block<Expr>([&]() { return ParseExpr(); });
+
+   // Balance the two pushes made on entry.
+   PopTypeScopes(1);
+   PopScopes(1);
+
+   return relay::Function(params, body, ret_type, generics);
+ }
+
+ /*! \brief Parse an if-expression. */
+ Expr ParseIf() {
+   // Shape: `if` `(` cond `)` `{` then `}` `else` `{` otherwise `}`.
+   Consume(TokenType::If);
+
+   std::function<Expr()> parse_expr = [&] { return ParseExpr(); };
+
+   Expr cond = Parens<Expr>(parse_expr);
+   Expr then_value = Block<Expr>(parse_expr);
+   Match(TokenType::Else);
+   Expr else_value = Block<Expr>(parse_expr);
+
+   return relay::If(cond, then_value, else_value);
+ }
+
+ /* This factors parsing a list of patterns for both tuples, and constructors. */
+ Array<Pattern> ParsePatternList() {
+   // `(` pattern (`,` pattern)* `)`
+   auto parse_one = [&] { return ParsePattern(); };
+   return ParseSequence<Pattern>(TokenType::OpenParen, TokenType::Comma, TokenType::CloseParen,
+                                 parse_one);
+ }
+
+ /*! \brief Parses a pattern for a match expression.
+ *
+ * A pattern is either a wildcard `_`, a local `%name`,
+ * a constructor `C(p1, ..., pn)` or tuple `(p1, ..., pn)`.
+ *
+ * This function recursively parses a pattern.
+ */
+ Pattern ParsePattern() {
+   auto next = Peek();
+   const auto kind = next->token_type;
+
+   // `_` matches anything.
+   if (kind == TokenType::Underscore) {
+     Match(TokenType::Underscore);
+     return PatternWildcard();
+   }
+
+   // `%name` with an optional `: type` annotation binds a new variable.
+   if (kind == TokenType::Local) {
+     auto name_tok = Match(TokenType::Local);
+     Type annotation;
+     if (WhenMatch(TokenType::Colon)) {
+       annotation = ParseType();
+     }
+     return PatternVar(BindVar(name_tok.ToString(), annotation));
+   }
+
+   // An identifier must name a known constructor; the field patterns are optional.
+   if (kind == TokenType::Identifier) {
+     auto name_tok = Match(TokenType::Identifier);
+     auto ctor = ctors.Get(name_tok.ToString());
+     CHECK(ctor) << "undefined identifier";
+     Array<Pattern> fields;
+     if (Peek()->token_type == TokenType::OpenParen) {
+       fields = ParsePatternList();
+     }
+     return PatternConstructor(ctor.value(), fields);
+   }
+
+   // Everything else is parsed as a tuple pattern.
+   return PatternTuple(ParsePatternList());
+ }
+
+ /*! \brief Parse one match arm: pattern `=>` expression. */
+ Clause ParseMatchArm() {
+   // Variables bound by the pattern live in a scope of their own that ends with the arm.
+   PushScope();
+   Pattern lhs = ParsePattern();
+   // `=>` is lexed as Equal followed by RAngle.
+   Match(TokenType::Equal);
+   Consume(TokenType::RAngle);
+   Expr rhs = ParseExpr();
+   PopScopes(1);
+   return Clause(lhs, rhs);
+ }
+
+ /*! \brief Parse the scrutinee and the `{ arm, ... }` list of a match expression.
+  *
+  * \param is_total Passed through to relay::Match as its completeness flag.
+  */
+ Expr ParseMatch(bool is_total) {
+   Expr subject = ParseExpr();
+
+   auto parse_arm = [&] { return ParseMatchArm(); };
+   Array<Clause> arms = ParseSequence<Clause>(TokenType::LCurly, TokenType::Comma,
+                                              TokenType::RCurly, parse_arm);
+
+   return relay::Match(subject, arms, is_total);
+ }
+
+ // Parses a chain of call expressions joined by binary operators, using an
+ // operand stack (`exprs`) and an operator stack (`ops`) and consulting each
+ // rule's `precedence` and `left_assoc` fields.
+ Expr ParseExprBinOp() {
+ return ConsumeWhitespace<Expr>([this] {
+ // We must parse at least one expression, the default
+ // case is that there is no operator and we will fall
+ // through.
+ std::vector<Expr> exprs;
+ exprs.push_back(ParseCallExpr());
+
+ // Now we parse an optional op.
+ std::vector<Rule> ops;
+
+ // We will now parse 0 or more operator occurrences.
+ while (true) {
+ auto opt_op = ParseOp();
+
+ // If we didn't parse one we done.
+ if (opt_op.size() == 0) {
+ break;
+ }
+
+ // Read the operation we parsed;
+ auto op = opt_op[0];
+
+ Expr right = ParseCallExpr();
+
+ // If the operator stack is empty
+ // we parse an operator and expression
+ // and push them to stacks, then
+ // continue.
+ if (ops.size() == 0) {
+ ops.push_back(op);
+ exprs.push_back(right);
+ continue;
+ }
+
+ // Higher precedence than the stack top, or equal precedence and not
+ // left-associative: defer by pushing.
+ if (op.precedence > ops.back().precedence ||
+ (op.precedence == ops.back().precedence && op.left_assoc == false)) {
+ ops.push_back(op);
+ exprs.push_back(right);
+ continue;
+ }
+
+ // Otherwise reduce: pop operators (and their two operands) while the new
+ // operator has lower precedence, or equal precedence and is left-associative.
+ while (ops.size() && (op.precedence < ops.back().precedence ||
+ (op.precedence == ops.back().precedence && op.left_assoc == true))) {
+ Rule new_op = ops.back();
+ ops.pop_back();
+ // These shadow the outer `right`; they are the operands being reduced.
+ Expr right = exprs.back();
+ exprs.pop_back();
+ Expr left = exprs.back();
+ exprs.pop_back();
+ exprs.push_back(relay::Call(new_op.op, {left, right}));
+ }
+
+ exprs.push_back(right);
+ ops.push_back(op);
+ }
+
+ // Input exhausted: reduce whatever is left on the stacks.
+ while (ops.size()) {
+ Rule new_op = ops.back();
+ ops.pop_back();
+ Expr right = exprs.back();
+ exprs.pop_back();
+ Expr left = exprs.back();
+ exprs.pop_back();
+ exprs.push_back(relay::Call(new_op.op, {left, right}));
+ }
+
+ CHECK_EQ(ops.size(), 0);
+ CHECK_EQ(exprs.size(), 1);
+ return exprs[0];
+ });
+ }
+
+ /*! \brief Placeholder for attribute parsing.
+  *
+  * Creates an object for `type_key` from an empty keyword map, then logs a
+  * fatal message; no attribute tokens are read.
+  */
+ Attrs ParseAttrs(const std::string& type_key) {
+   Map<String, ObjectRef> empty_kwargs;
+   auto created = tvm::ReflectionVTable::Global()->CreateObject(type_key, empty_kwargs);
+   LOG(FATAL) << Attrs();
+   return Attrs();
+ }
+
+ /*! \brief Parse one parenthesised argument list applied to `op`.
+  *
+  * \return The call expression, or an undefined Expr when the next token is not `(`.
+  */
+ Expr ParseCallArgs(Expr op) {
+   if (Peek()->token_type != TokenType::OpenParen) {
+     return Expr();
+   }
+
+   Attrs call_attrs;
+
+   std::function<Expr()> parse_arg = [&] { return ParseExpr(); };
+
+   // Runs just before the closing paren: if the lookahead is `identifier =`
+   // and the callee is an operator, parse attributes for that operator.
+   std::function<void()> parse_trailing_attrs = [&] {
+     const bool ident_first = Lookahead(1)->token_type == TokenType::Identifier;
+     const bool equal_second = Lookahead(2)->token_type == TokenType::Equal;
+     if (!(ident_first && equal_second)) {
+       return;
+     }
+     if (auto op_node = op.as<OpNode>()) {
+       call_attrs = ParseAttrs(op_node->attrs_type_key);
+     }
+   };
+
+   Array<Expr> args = ParseSequence<Expr>(TokenType::OpenParen, TokenType::Comma,
+                                          TokenType::CloseParen, parse_arg, parse_trailing_attrs);
+   return Expr(Call(op, args, call_attrs, {}));
+ }
+
+ /*! \brief Parse an atomic expression followed by zero or more argument lists. */
+ Expr ParseCallExpr() {
+   return ConsumeWhitespace<Expr>([this] {
+     Expr callee = ParseAtomicExpr();
+
+     // NB(@jroesch): this seems like a hack but in order to parse curried functions
+     // and avoid complex grammar we will parse multiple call lists in a row.
+     // Each successful application becomes the callee of the next one.
+     for (Expr applied = ParseCallArgs(callee); applied.defined();
+          applied = ParseCallArgs(callee)) {
+       callee = applied;
+     }
+
+     // A bare constructor is wrapped in a zero-argument call.
+     if (callee.as<ConstructorNode>()) {
+       return Expr(Call(callee, {}));
+     }
+     return callee;
+   });
+ }
+
+ // Parses an atomic expression: a numeric or boolean literal, a local, global
+ // or graph variable, an identifier (constructor or operator), an anonymous
+ // function, or a parenthesised expression / tuple.
+ Expr ParseAtomicExpr() {
+ return ConsumeWhitespace<Expr>([this] {
+ auto next = Peek();
+ switch (next->token_type) {
+ case TokenType::Integer:
+ case TokenType::Float: {
+ Consume(next->token_type);
+ auto number = NumberToNDArray(next);
+ Expr e = Constant(number);
+ return e;
+ }
+ case TokenType::Boolean: {
+ Consume(TokenType::Boolean);
+ // The boolean token's payload is read as an Integer.
+ int value = Downcast<tvm::Integer>(next->data);
+ auto boolean = BooleanToNDarray(value);
+ Expr e = Constant(boolean);
+ return e;
+ }
+ case TokenType::Local: {
+ Consume(TokenType::Local);
+ return Expr(LookupLocal(next));
+ }
+ case TokenType::Global: {
+ auto string = next.ToString();
+ Consume(TokenType::Global);
+ auto global = global_names.Get(string);
+ // A global that has not been seen yet is created and registered on first use.
+ if (!global) {
+ auto global_var = GlobalVar(string);
+ global_names.Add(string, global_var);
+ return Expr(global_var);
+ } else {
+ return Expr(global.value());
+ }
+ }
+ case TokenType::Identifier: {
+ auto string = next.ToString();
+ Consume(TokenType::Identifier);
+ // Constructors take priority; any other identifier is looked up as an operator.
+ auto ctor = ctors.Get(string);
+ if (ctor) {
+ return Expr(ctor.value());
+ } else {
+ return Expr(Op::Get(string));
+ }
+ }
+ case TokenType::Graph: {
+ Consume(TokenType::Graph);
+ return LookupGraphBinding(next);
+ }
+ case TokenType::Fn: {
+ Consume(TokenType::Fn);
+ return Expr(ParseFunctionDef());
+ }
+ case TokenType::OpenParen: {
+ Consume(TokenType::OpenParen);
+ // parse '(' ')'
+ if (WhenMatch(TokenType::CloseParen)) {
+ return Expr(Tuple(Array<Expr>()));
+ } else {
+ auto expr = ParseExpr();
+ // parse '(' expr ')'
+ if (WhenMatch(TokenType::CloseParen)) {
+ return expr;
+ // parse '( expr ',' * ')'
+ } else if (WhenMatch(TokenType::Comma)) {
+ Array<Expr> exprs = {expr};
+ while (true) {
+ if (WhenMatch(TokenType::CloseParen)) {
+ break;
+ } else {
+ auto expr = ParseExpr();
+ // The comma between tuple elements is optional here.
+ WhenMatch(TokenType::Comma);
+ exprs.push_back(expr);
+ }
+ }
+ return static_cast<Expr>(Tuple(exprs));
+ }
+ }
+ // NOTE(review): if neither `)` nor `,` follows the inner expression,
+ // control falls through into `default` below, and the diagnostic is
+ // reported against the opening paren token held in `next`.
+ }
+ default: {
+ std::stringstream msg;
+ msg << "expected an expression found " << Pretty(next->token_type);
+ diag_ctx.Emit({next->line, next->column, msg.str()});
+ diag_ctx.Render(std::cout);
+ return Expr();
+ }
+ }
+ });
+ }
+
+ /*! \brief Parse a shape. */
+ Array<tvm::PrimExpr> ParseShape() {
+   // `(` integer (`,` integer)* `)`; each dimension is the integer token's payload.
+   auto parse_dim = [&]() {
+     auto dim_tok = Match(TokenType::Integer);
+     return Downcast<tvm::PrimExpr>(dim_tok->data);
+   };
+   return ParseSequence<tvm::PrimExpr>(TokenType::OpenParen, TokenType::Comma,
+                                       TokenType::CloseParen, parse_dim);
+ }
+
+ /*! \brief Parse a function type. */
+ Type ParseFunctionType() {
+   // `(` type, ... `)` `->` type. The arrow is lexed as Minus followed by RAngle.
+   auto parse_arg_type = [&]() { return ParseType(); };
+   auto arg_types = ParseSequence<Type>(TokenType::OpenParen, TokenType::Comma,
+                                        TokenType::CloseParen, parse_arg_type);
+
+   Match(TokenType::Minus);
+   Match(TokenType::RAngle);
+   auto result_type = ParseType();
+
+   return relay::FuncType(arg_types, result_type, {}, {});
+ }
+
+ // Parses a user defined ADT or type variable.
+ Type ParseNonPrimitiveType(const Token& tok) {
+   auto name = tok.ToString();
+
+   // Known global type names win; anything else is looked up as a type variable.
+   Type head;
+   auto known_global = type_names.Get(name);
+   if (known_global) {
+     head = known_global.value();
+   } else {
+     head = LookupTypeVar(tok);
+   }
+
+   CHECK(head.defined()) << "internal error: head type must be defined";
+
+   // Optional `[` type, ... `]` argument list.
+   Array<Type> type_args;
+   if (Peek()->token_type == TokenType::LSquare) {
+     type_args = ParseSequence<Type>(TokenType::LSquare, TokenType::Comma, TokenType::RSquare,
+                                     [&]() { return ParseType(); });
+   }
+
+   // With no arguments the head type is returned as-is rather than as a TypeCall.
+   if (!type_args.size()) {
+     return head;
+   }
+   return TypeCall(head, type_args);
+ }
+
+ /*! \brief Parses a TVM type.
+ *
+ * This matches either a `Tensor[shape, dtype]`, a user defined ADT, a tuple type,
+ * a scalar type or an incomplete type `_`.
+ */
+ Type ParseType() {
+   auto tok = Peek();
+
+   // Tuple type: `(` type, ... `)`.
+   if (tok->token_type == TokenType::OpenParen) {
+     auto fields = ParseSequence<relay::Type>(TokenType::OpenParen, TokenType::Comma,
+                                              TokenType::CloseParen, [&]() { return ParseType(); });
+     return relay::TupleType(fields);
+   }
+
+   // Function type: `fn` followed by the signature.
+   if (WhenMatch(TokenType::Fn)) {
+     return ParseFunctionType();
+   }
+
+   if (WhenMatch(TokenType::Identifier)) {
+     const std::string name = tok.ToString();
+
+     // `Tensor[` shape `,` dtype `]`.
+     if (name == "Tensor") {
+       Match(TokenType::LSquare);
+       auto shape = ParseShape();
+       Match(TokenType::Comma);
+       auto dtype_tok = Match(TokenType::Identifier);
+       auto dtype = DataType(String2DLDataType(dtype_tok.ToString()));
+       Match(TokenType::RSquare);
+       return TensorType(shape, dtype);
+     }
+
+     // A bare dtype name (prefix int / float / uint / bool) is a rank-0 tensor type.
+     auto starts_with = [&name](const char* prefix) { return name.rfind(prefix, 0) == 0; };
+     if (starts_with("int") || starts_with("float") || starts_with("uint") ||
+         starts_with("bool")) {
+       // Need to do better error handling here.
+       auto dtype = DataType(String2DLDataType(name));
+       return TensorType({}, dtype);
+     }
+
+     return ParseNonPrimitiveType(tok);
+   }
+
+   // `_` stands for a type that is left incomplete.
+   if (WhenMatch(TokenType::Underscore)) {
+     return IncompleteType();
+   }
+
+   // Nothing matched: report and return an undefined type.
+   std::stringstream msg;
+   msg << "failed to parse type found ";
+   msg << tok;
+   diag_ctx.Emit({tok->line, tok->column, msg.str()});
+   diag_ctx.Render(std::cout);
+   return Type();
+ }
+
+ template <typename R>
+ R ConsumeWhitespace(std::function<R()> func) {
+   // Turn on whitespace skipping for the duration of `func`, restoring the
+   // previous setting afterwards.
+   const auto saved = this->ignore_whitespace;
+   this->ignore_whitespace = true;
+   // Skip any whitespace tokens at the current position first.
+   for (; tokens[pos]->token_type == TokenType::Whitespace; ++pos) {
+   }
+   R result = func();
+   this->ignore_whitespace = saved;
+   return result;
+ }
+
+ // TODO(@jroesch): this is the final remaining feature.
+ ObjectRef ParseMetadata() {
+   // Not implemented yet: consumes no tokens and yields a default ObjectRef.
+   ObjectRef nothing;
+   return nothing;
+ }
+
+ /*! \brief A helper for debugging the parser, displays the next N tokens in the token stream. */
+ void DisplayNextN(int n) {
+   std::cout << "remaining tokens: " << std::endl;
+   // Clamp the window to the end of the token stream.
+   const auto end = std::min(pos + n, static_cast<int>(tokens.size()));
+   for (int idx = pos; idx < end; ++idx) {
+     std::cout << tokens[idx] << std::endl;
+   }
+ }
+
+ // A function for debugging the operator parser.
+ void DebugStack(const std::vector<Expr>& exprs, const std::vector<Rule>& rules) {
+   // First line: the operand stack, bottom to top.
+   std::cout << "Expr Stack: ";
+   for (size_t i = 0; i < exprs.size(); ++i) {
+     std::cout << exprs[i] << ", ";
+   }
+   std::cout << std::endl;
+
+   // Second line: the operator of each rule on the operator stack.
+   std::cout << "Op Stack: ";
+   for (size_t i = 0; i < rules.size(); ++i) {
+     std::cout << rules[i].op << ", ";
+   }
+   std::cout << std::endl;
+ }
+};
+
+IRModule ParseModule(std::string file_name, std::string file_content) {
+  // `file_name` is not consulted here; only the contents are tokenized and parsed.
+  auto token_stream = Tokenize(file_content);
+  Parser module_parser(token_stream, DefaultOpTable(), Source(file_content));
+  return module_parser.ParseModule();
+}
+
+Expr ParseExpr(std::string file_name, std::string file_content) {
+  // `file_name` is not consulted here; only the contents are tokenized and parsed.
+  auto token_stream = Tokenize(file_content);
+  Parser expr_parser(token_stream, DefaultOpTable(), Source(file_content));
+  // A stand-alone expression needs a scope to bind into.
+  expr_parser.PushScope();
+  Expr parsed = expr_parser.ParseExpr();
+  // Nothing may follow the expression.
+  expr_parser.Match(TokenType::EndOfFile);
+  return parsed;
+}
+
+// Registers the free function ParseModule under the global name "parser.ParseModule".
+TVM_REGISTER_GLOBAL("parser.ParseModule")
+ .set_body_typed([](std::string file_name, std::string file_content) {
+ return ParseModule(file_name, file_content);
+ });
+
+// Registers the free function ParseExpr under the global name "parser.ParseExpr".
+TVM_REGISTER_GLOBAL("parser.ParseExpr")
+ .set_body_typed([](std::string file_name, std::string file_content) {
+ return ParseExpr(file_name, file_content);
+ });
+
+} // namespace parser
+} // namespace tvm
diff --git a/src/parser/token.h b/src/parser/token.h
new file mode 100644
index 0000000..d7aac23
--- /dev/null
+++ b/src/parser/token.h
@@ -0,0 +1,362 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file token.h
+ * \brief The definition of tokens for the TVM parser.
+ */
+
+#ifndef TVM_PARSER_TOKEN_H_
+#define TVM_PARSER_TOKEN_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+#include <fstream>
+#include <string>
+#include <utility>
+
+namespace tvm {
+namespace parser {
+
+using namespace runtime;
+
+/*! \brief The kinds of token handled by the parser.
+ *
+ * ToString() and Pretty() below map each enumerator to a name and to a
+ * user-facing rendering respectively.
+ */
+enum TokenType {
+ CommentStart,
+ CommentEnd,
+ LineComment,
+ Comment,
+ Whitespace,
+ Newline,
+ StringLiteral,
+ Identifier,
+ Local,
+ Global,
+ Op,
+ Graph,
+ OpenParen,
+ CloseParen,
+ AtSymbol,
+ Percent,
+ Comma,
+ Period,
+ Equal,
+ Semicolon,
+ Colon,
+ Integer,
+ Float,
+ Division,
+ Boolean,
+ Plus,
+ Star,
+ Minus,
+ RAngle,
+ LAngle,
+ RCurly,
+ LCurly,
+ RSquare,
+ LSquare,
+ Bang,
+ At,
+ Question,
+ If,
+ Else,
+ Underscore,
+ Let,
+ Fn,
+ Defn,
+ TypeDef,
+ Extern,
+ Match,
+ PartialMatch,
+ Unknown,
+ EndOfFile,
+ Null,
+};
+
+/*! \brief Return the name of a token kind.
+ *
+ * Each name matches its enumerator's spelling, except Whitespace, which is
+ * rendered "WhiteSpace".
+ */
+std::string ToString(const TokenType& token_type) {
+  switch (token_type) {
+    case TokenType::CommentStart: return "CommentStart";
+    case TokenType::CommentEnd: return "CommentEnd";
+    case TokenType::LineComment: return "LineComment";
+    case TokenType::Comment: return "Comment";
+    case TokenType::Whitespace: return "WhiteSpace";
+    case TokenType::Newline: return "Newline";
+    case TokenType::StringLiteral: return "StringLiteral";
+    case TokenType::Identifier: return "Identifier";
+    case TokenType::Local: return "Local";
+    case TokenType::Global: return "Global";
+    case TokenType::Graph: return "Graph";
+    case TokenType::Op: return "Op";
+    case TokenType::OpenParen: return "OpenParen";
+    case TokenType::CloseParen: return "CloseParen";
+    case TokenType::AtSymbol: return "AtSymbol";
+    case TokenType::Percent: return "Percent";
+    case TokenType::Comma: return "Comma";
+    case TokenType::Colon: return "Colon";
+    case TokenType::Semicolon: return "Semicolon";
+    case TokenType::Period: return "Period";
+    case TokenType::Equal: return "Equal";
+    case TokenType::Integer: return "Integer";
+    case TokenType::Float: return "Float";
+    case TokenType::Plus: return "Plus";
+    case TokenType::Star: return "Star";
+    case TokenType::Minus: return "Minus";
+    case TokenType::Division: return "Division";
+    case TokenType::RAngle: return "RAngle";
+    case TokenType::LAngle: return "LAngle";
+    case TokenType::RCurly: return "RCurly";
+    case TokenType::LCurly: return "LCurly";
+    case TokenType::RSquare: return "RSquare";
+    case TokenType::LSquare: return "LSquare";
+    case TokenType::Bang: return "Bang";
+    case TokenType::Underscore: return "Underscore";
+    case TokenType::At: return "At";
+    case TokenType::Let: return "Let";
+    case TokenType::If: return "If";
+    case TokenType::Else: return "Else";
+    case TokenType::Fn: return "Fn";
+    case TokenType::Defn: return "Defn";
+    case TokenType::TypeDef: return "TypeDef";
+    case TokenType::Extern: return "Extern";
+    case TokenType::Match: return "Match";
+    case TokenType::PartialMatch: return "PartialMatch";
+    case TokenType::Question: return "Question";
+    case TokenType::Boolean: return "Boolean";
+    case TokenType::Unknown: return "Unknown";
+    case TokenType::EndOfFile: return "EndOfFile";
+    case TokenType::Null: return "Null";
+    // Older compilers warn even though the above code is exhaustive.
+    default:
+      LOG(FATAL) << "unreachable code";
+      return "";
+  }
+}
+
+/*! \brief Return a user-facing rendering of a token kind for diagnostics.
+ *
+ * Punctuation and keywords are shown as their source spelling in backticks;
+ * token classes (identifier, integer, ...) are shown as a short description.
+ */
+std::string Pretty(const TokenType& token_type) {
+  switch (token_type) {
+    case TokenType::CommentStart: return "`/*`";
+    case TokenType::CommentEnd: return "`*/`";
+    case TokenType::LineComment: return "`//`";
+    case TokenType::Comment: return "comment";
+    case TokenType::Whitespace: return "whitespace";
+    case TokenType::Newline: return "newline";
+    case TokenType::StringLiteral: return "string literal";
+    case TokenType::Identifier: return "identifier";
+    case TokenType::Local: return "local variable";
+    case TokenType::Global: return "global variable";
+    case TokenType::Graph: return "graph variable";
+    case TokenType::Op: return "operator";
+    case TokenType::OpenParen: return "`(`";
+    case TokenType::CloseParen: return "`)`";
+    case TokenType::AtSymbol: return "`@`";
+    case TokenType::Percent: return "`%`";
+    case TokenType::Comma: return "`,`";
+    case TokenType::Colon: return "`:`";
+    case TokenType::Semicolon: return "`;`";
+    case TokenType::Period: return "`.`";
+    case TokenType::Equal: return "`=`";
+    case TokenType::Integer: return "integer";
+    case TokenType::Float: return "float";
+    case TokenType::Plus: return "`+`";
+    case TokenType::Star: return "`*`";
+    case TokenType::Minus: return "`-`";
+    case TokenType::Division: return "`/`";
+    // RAngle is the `>` used to finish `->` and `=>`; these two renderings
+    // were previously swapped.
+    case TokenType::RAngle: return "`>`";
+    case TokenType::LAngle: return "`<`";
+    case TokenType::RCurly: return "`}`";
+    case TokenType::LCurly: return "`{`";
+    case TokenType::RSquare: return "`]`";
+    case TokenType::LSquare: return "`[`";
+    case TokenType::Bang: return "`!`";
+    case TokenType::Underscore: return "`_`";
+    case TokenType::At: return "`@`";
+    case TokenType::Let: return "`let`";
+    case TokenType::If: return "`if`";
+    case TokenType::Else: return "`else`";
+    case TokenType::Fn: return "`fn`";
+    case TokenType::Defn: return "`def`";
+    case TokenType::TypeDef: return "`type`";
+    case TokenType::Extern: return "`extern`";
+    case TokenType::Boolean: return "boolean";
+    case TokenType::Match: return "`match`";
+    case TokenType::PartialMatch: return "`match?`";
+    case TokenType::Question: return "`?`";
+    case TokenType::Unknown: return "unknown";
+    case TokenType::EndOfFile: return "end of file";
+    case TokenType::Null: return "null";
+    // Older compilers warn even though the above code is exhaustive.
+    default:
+      LOG(FATAL) << "unreachable code";
+      return "";
+  }
+}
+
+class Token;
+
+/*! \brief Payload of a Token: a source position, a kind, and optional attached data. */
+class TokenNode : public Object {
+ public:
+ // Source line recorded for the token.
+ int line;
+ // Source column recorded for the token.
+ int column;
+ // The kind of token.
+ TokenType token_type;
+ // Attached value; Token::ToNumber reads it as an Integer and Token::ToString as a String.
+ mutable runtime::ObjectRef data;
+
+ // No fields are exposed to attribute visitors.
+ void VisitAttrs(AttrVisitor* v) {}
+
+ static constexpr const char* _type_key = "parser.Token";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TokenNode, Object);
+};
+
+// Printer for TokenNode: Token(line=..., column=..., token_type=..., data=...).
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<TokenNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const TokenNode*>(ref.get());
+ p->stream << "Token(line=" << node->line << ", column=" << node->column
+ << ", token_type=" << ToString(node->token_type) << ", data=" << node->data << ")";
+ });
+
+// Registers TokenNode as a node type.
+TVM_REGISTER_NODE_TYPE(TokenNode);
+
+/*! \brief Reference type wrapping a TokenNode. */
+class Token : public ObjectRef {
+ public:
+ // Builds a token at (line, column) of the given kind; `data` defaults to an empty ObjectRef.
+ TVM_DLL explicit Token(int line, int column, TokenType token_type, ObjectRef data = ObjectRef());
+
+ // A token of kind TokenType::Null at position (0, 0).
+ static Token Null();
+ // The attached data downcast to an Integer.
+ int64_t ToNumber() const;
+ // The attached data downcast to a String.
+ std::string ToString() const;
+ TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode);
+};
+
+Token::Token(int line, int column, TokenType token_type, ObjectRef data) {
+  // Fill in a fresh node, then hand ownership to this reference.
+  auto node = make_object<TokenNode>();
+  node->token_type = token_type;
+  node->line = line;
+  node->column = column;
+  node->data = data;
+  data_ = std::move(node);
+}
+
+Token Token::Null() {
+  // Position (0, 0), kind Null, no attached data.
+  Token null_token(0, 0, TokenType::Null);
+  return null_token;
+}
+
+int64_t Token::ToNumber() const {
+  // The attached data is downcast to Integer and converted to int64_t.
+  return Downcast<tvm::Integer>((*this)->data);
+}
+
+std::string Token::ToString() const {
+  // The attached data is downcast to String and converted to std::string.
+  return Downcast<tvm::String>((*this)->data);
+}
+
+} // namespace parser
+} // namespace tvm
+#endif // TVM_PARSER_TOKEN_H_
diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h
new file mode 100644
index 0000000..f6c2734
--- /dev/null
+++ b/src/parser/tokenizer.h
@@ -0,0 +1,459 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tokenizer.h
+ * \brief A tokenizer for TVM IR.
+ */
+#ifndef TVM_PARSER_TOKENIZER_H_
+#define TVM_PARSER_TOKENIZER_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+#include <fstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "./token.h"
+
+namespace tvm {
+namespace parser {
+
+using namespace runtime;
+
+/*!
+ * \brief Whether `c` is an ASCII decimal digit.
+ *
+ * Marked inline because the definition lives in a header.
+ */
+inline bool IsDigit(char c) { return '0' <= c && c <= '9'; }
+
+/*!
+ * \brief Whether `c` is a space, a tab or a line feed.
+ *
+ * A carriage return is deliberately not included here.
+ * Marked inline because the definition lives in a header.
+ */
+inline bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; }
+
+/*!
+ * \brief Whether `c` may appear inside a numeric literal: a digit, `.`,
+ *  an exponent marker (`e`/`E`) or a sign (`+`/`-`).
+ *
+ * None of the accepted characters is whitespace, so no separate whitespace
+ * test is needed. Marked inline because the definition lives in a header.
+ */
+inline bool IsNumeric(char c) {
+  return IsDigit(c) || c == '.' || c == 'e' || c == '-' || c == '+' || c == 'E';
+}
+
+/*!
+ * \brief Whether `c` may start an identifier: `_` or an ASCII letter.
+ *
+ * Marked inline because the definition lives in a header.
+ */
+inline bool IsIdentLetter(char c) {
+  return '_' == c || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z');
+}
+
+/*!
+ * \brief Whether `c` may continue an identifier: an identifier letter or a digit.
+ *
+ * Marked inline because the definition lives in a header.
+ */
+inline bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); }
+
+// Reserved words of the text format and the token kind each one lexes to.
+static std::unordered_map<std::string, TokenType> KEYWORD_TABLE = {
+    {"let", TokenType::Let},
+    {"fn", TokenType::Fn},
+    {"def", TokenType::Defn},
+    {"if", TokenType::If},
+    {"else", TokenType::Else},
+    {"type", TokenType::TypeDef},
+    {"match", TokenType::Match},
+    {"extern", TokenType::Extern},
+};
+
+/*!
+ * \brief A hand-written lexer that turns a source string into a flat list of tokens.
+ *
+ * Construct it over a string that outlives it, call Tokenize(), then read `tokens`.
+ */
+struct Tokenizer {
+  /*! \brief Offset into `source` of the next unread character. */
+  size_t pos;
+  /*! \brief Column of the next unread character; reset to 1 after each line feed. */
+  int col;
+  /*! \brief Line of the next unread character; starts at 1. */
+  int line;
+  /*! \brief Not used by the tokenizer itself; initialized so it is never indeterminate. */
+  char next_char;
+  /*! \brief The text being tokenized. Held by reference, so it must outlive the tokenizer. */
+  const std::string& source;
+  /*! \brief The tokens produced by Tokenize(). */
+  std::vector<Token> tokens;
+
+  /*!
+   * \brief Consume one character and advance the line/column counters.
+   * \return The consumed character.
+   * \note std::string::at throws std::out_of_range when the input is exhausted.
+   */
+  char Next() {
+    char c = this->source.at(this->pos);
+    if (c == '\n') {
+      this->line += 1;
+      this->col = 1;
+    } else {
+      this->col += 1;
+    }
+    this->pos += 1;
+    return c;
+  }
+
+  /*! \brief Whether any unread input remains. */
+  bool More() { return this->pos < this->source.size(); }
+
+  /*! \brief Look at the next character without consuming it; input must remain. */
+  char Peek() {
+    CHECK(pos < this->source.size());
+    return this->source.at(this->pos);
+  }
+
+  /*! \brief Build a token stamped with the tokenizer's current line and column. */
+  Token NewToken(TokenType token_type, ObjectRef data = ObjectRef()) {
+    return Token(this->line, this->col, token_type, data);
+  }
+
+  /*! \brief States of the block-comment scanner in MatchComment. */
+  enum CommentParserState {
+    Proceed,   // scanning ordinary comment text
+    Forward,   // just consumed a `/`, which may open a nested `/*`
+    Backward,  // just consumed a `*`, which may begin a closing `*/`
+  };
+
+  /*!
+   * \brief Map a single-character punctuation mark to its token kind.
+   * \param c The character to classify.
+   * \param token_type Receives the token kind when the function returns true.
+   * \return Whether `c` is one of the single-character punctuation tokens.
+   * \note `-` and `/` are not listed: they need lookahead and are handled
+   *  separately in TokenizeOnce.
+   */
+  static bool MatchPunctuation(char c, TokenType* token_type) {
+    switch (c) {
+      case '.': *token_type = TokenType::Period; return true;
+      case ',': *token_type = TokenType::Comma; return true;
+      case '=': *token_type = TokenType::Equal; return true;
+      case ';': *token_type = TokenType::Semicolon; return true;
+      case ':': *token_type = TokenType::Colon; return true;
+      case '(': *token_type = TokenType::OpenParen; return true;
+      case ')': *token_type = TokenType::CloseParen; return true;
+      case '+': *token_type = TokenType::Plus; return true;
+      case '*': *token_type = TokenType::Star; return true;
+      case '<': *token_type = TokenType::LAngle; return true;
+      case '>': *token_type = TokenType::RAngle; return true;
+      case '{': *token_type = TokenType::LCurly; return true;
+      case '}': *token_type = TokenType::RCurly; return true;
+      case '[': *token_type = TokenType::LSquare; return true;
+      case ']': *token_type = TokenType::RSquare; return true;
+      case '!': *token_type = TokenType::Bang; return true;
+      case '@': *token_type = TokenType::At; return true;
+      case '?': *token_type = TokenType::Question; return true;
+      case '%': *token_type = TokenType::Percent; return true;
+      default: return false;
+    }
+  }
+
+  /*!
+   * \brief Scan the body of a (possibly nested) block comment.
+   *
+   * Must be called right after the opening `/``*` has been consumed. On return
+   * the matching closing delimiter has been consumed as well, and `buffer`
+   * holds the text in between, including the delimiters of nested comments.
+   *
+   * \param buffer Receives the comment text.
+   */
+  void MatchComment(std::string* buffer) {
+    // We enter with one comment already open; we are done when the nesting
+    // depth drops back to zero.
+    CommentParserState state = CommentParserState::Proceed;
+    int nesting = 1;
+
+    while (true) {
+      // Every state peeks at the next character, so running out of input here
+      // means the comment was never closed.
+      CHECK(More()) << "unterminated block comment";
+      switch (state) {
+        case CommentParserState::Proceed: {
+          if (Peek() == '/') {
+            state = CommentParserState::Forward;
+          } else if (Peek() == '*') {
+            state = CommentParserState::Backward;
+          }
+          buffer->operator+=(Next());
+          continue;
+        }
+        case CommentParserState::Forward: {
+          if (Peek() == '*') {
+            nesting += 1;
+            buffer->operator+=(Next());
+          }
+          state = CommentParserState::Proceed;
+          continue;
+        }
+        case CommentParserState::Backward: {
+          if (Peek() == '/') {
+            nesting -= 1;
+            if (nesting == 0) {
+              Next();
+              // Drop the `*` of the closing delimiter that Proceed already stored.
+              buffer->pop_back();
+              return;
+            }
+            buffer->operator+=(Next());
+          }
+          // A `*` that is not followed by `/` is plain comment text. Go back to
+          // Proceed without consuming anything so the current character is
+          // re-examined; staying in this state would never make progress.
+          state = CommentParserState::Proceed;
+          continue;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Convert the text of a numeric literal into an Integer or Float token.
+   * \param is_pos Whether the literal is positive; the value is negated otherwise.
+   * \param is_float Whether the literal must be treated as a float.
+   * \param number The literal text, without sign prefix; must not be empty.
+   * \return An Integer token when the whole text is an integer and `is_float`
+   *  is false, otherwise a Float token holding a 64-bit float.
+   */
+  Token ParseNumber(bool is_pos, bool is_float, std::string number) {
+    CHECK(number.size() > 0) << "an empty string is an invalid number";
+
+    if (!is_float) {
+      try {
+        size_t index = 0;
+        int value = std::stoi(number, &index);
+        // Only a fully consumed string is an integer; anything left over
+        // (e.g. `.5` or `e3`) makes it a float.
+        if (index == number.size()) {
+          auto token = NewToken(TokenType::Integer);
+          token->data = tvm::Integer(is_pos ? value : -value);
+          return token;
+        }
+      } catch (const std::invalid_argument&) {
+        // Not an integer at all; fall through and try it as a float.
+      } catch (const std::out_of_range&) {
+        LOG(FATAL) << "integer literal out of range: " << number;
+      }
+    }
+
+    auto token = NewToken(TokenType::Float);
+
+    if (number.back() == 'f') {
+      number.pop_back();
+    }
+
+    double value = std::stod(number);
+    value = is_pos ? value : -value;
+    token->data = tvm::FloatImm(DataType::Float(64), value);
+    return token;
+  }
+
+  /*!
+   * \brief Lex exactly one token starting at the current position.
+   * \return The token; unrecognized text is returned as a TokenType::Unknown
+   *  token carrying the offending text.
+   */
+  inline Token TokenizeOnce() {
+    auto next = Peek();
+    TokenType punctuation = TokenType::Unknown;
+    if (next == '\n') {
+      auto token = NewToken(TokenType::Newline);
+      Next();
+      return token;
+    } else if (next == '\r') {
+      Next();
+      if (More() && Peek() == '\n') {
+        auto token = NewToken(TokenType::Newline);
+        // Consume the line feed too, so `\r\n` yields a single Newline token.
+        Next();
+        return token;
+      } else {
+        // TODO(@jroesch): have lexer use diagnostic context too.
+        LOG(FATAL) << "lexer error";
+        return Token();
+      }
+    } else if (next == '"') {
+      LOG(FATAL) << "string not working yet";
+      return NewToken(TokenType::Unknown);
+    } else if (IsWhitespace(next)) {
+      auto token = NewToken(TokenType::Whitespace);
+      Next();
+      return token;
+    } else if (IsDigit(next) || next == '-') {
+      int negs = 0;
+      while (More() && Peek() == '-') {
+        Next();
+        negs++;
+      }
+      // A run of `-` that is not followed by a digit is not part of a literal:
+      // emit a Minus for the first one and rewind over the rest so they are
+      // lexed again. This is slow; a multi-token return would be better.
+      if (negs && (!More() || !IsDigit(Peek()))) {
+        // `-` never changes the line, so the column rewinds together with pos.
+        pos = pos - (negs - 1);
+        col = col - (negs - 1);
+        return NewToken(TokenType::Minus);
+      }
+
+      bool is_neg = negs % 2 == 1;
+      std::string number;
+      while (More() && IsNumeric(Peek())) {
+        char c = Peek();
+        // A sign belongs to the literal only directly after an exponent
+        // marker (`1e-3`); anywhere else (`1+1`) it starts a new token.
+        bool after_exponent =
+            !number.empty() && (number.back() == 'e' || number.back() == 'E');
+        if ((c == '+' || c == '-') && !after_exponent) {
+          break;
+        }
+        number.push_back(Next());
+      }
+
+      bool is_float = false;
+      // Remove the trailing floating point suffix.
+      if (More() && Peek() == 'f') {
+        Next();
+        is_float = true;
+      }
+
+      return ParseNumber(!is_neg, is_float, number);
+    } else if (MatchPunctuation(next, &punctuation)) {
+      auto token = NewToken(punctuation);
+      Next();
+      return token;
+    } else if (next == '/') {
+      Next();
+      // The `/` may be the last character of the input, so check before peeking.
+      if (More() && Peek() == '/') {
+        auto token = NewToken(TokenType::LineComment);
+        // Consume the second `/`.
+        Next();
+        std::string comment;
+        while (More() && Peek() != '\n') {
+          comment.push_back(Next());
+        }
+        token->data = tvm::String(comment);
+        return token;
+      } else if (More() && Peek() == '*') {
+        // Eat the opening delimiter before entering the state machine.
+        Next();
+        std::string comment;
+        MatchComment(&comment);
+        return NewToken(TokenType::Comment, tvm::String(comment));
+      } else {
+        return NewToken(TokenType::Division);
+      }
+    } else if (IsIdentLetter(next)) {
+      // Remember where the word starts; the counters will have moved past it
+      // by the time the token is built.
+      int start_line = this->line;
+      int start_col = this->col;
+
+      std::string word;
+      while (More() && IsIdent(Peek())) {
+        word.push_back(Next());
+      }
+
+      TokenType token_type = TokenType::Identifier;
+      auto it = KEYWORD_TABLE.find(word);
+      if (it != KEYWORD_TABLE.end()) {
+        token_type = it->second;
+        // `match?` is a separate keyword; the `?` is not part of the token data.
+        if (token_type == TokenType::Match && More() && Peek() == '?') {
+          Next();
+          token_type = TokenType::PartialMatch;
+        }
+      }
+
+      return Token(start_line, start_col, token_type, tvm::String(word));
+    } else {
+      // Anything else: swallow up to the next whitespace as one Unknown token.
+      std::string text;
+      while (More() && !IsWhitespace(Peek())) {
+        text.push_back(Next());
+      }
+      auto token = NewToken(TokenType::Unknown);
+      token->data = tvm::String(text);
+      return token;
+    }
+  }
+
+  /*! \brief Lex the whole input into `tokens`, ending with an EndOfFile token. */
+  void Tokenize() {
+    while (this->More()) {
+      auto token = TokenizeOnce();
+      CHECK(token.defined());
+      this->tokens.push_back(token);
+    }
+    this->tokens.push_back(NewToken(TokenType::EndOfFile));
+  }
+
+  /*!
+   * \brief Create a tokenizer positioned at line 1, column 1 of `source`.
+   * \param source The text to tokenize; only a reference is kept.
+   */
+  explicit Tokenizer(std::string& source)
+      : pos(0), col(1), line(1), next_char('\0'), source(source), tokens() {}
+};
+
+/*!
+ * \brief Merge sigil/name pairs and rewrite special identifiers.
+ *
+ * - `%` followed by an identifier becomes one Local token; followed by an
+ *   integer it becomes one Graph token.
+ * - `@` followed by an identifier becomes one Global token.
+ * - The identifiers `True` and `False` become Boolean tokens carrying 1 and 0,
+ *   and `_` becomes an Underscore token.
+ *
+ * Merged tokens keep the position of the sigil and the data of the name.
+ * All other tokens are copied through unchanged. A sigil that is the very
+ * last token has nothing to merge with and is copied through as well.
+ * Marked inline because the definition lives in a header.
+ *
+ * \param tokens The raw token stream.
+ * \return The condensed token stream.
+ */
+inline std::vector<Token> Condense(const std::vector<Token>& tokens) {
+  std::vector<Token> out;
+  out.reserve(tokens.size());
+
+  for (size_t i = 0; i < tokens.size(); i++) {
+    const Token& current = tokens[i];
+    // Guard the one-token lookahead so a trailing sigil cannot index past the end.
+    const bool has_next = i + 1 < tokens.size();
+
+    switch (current->token_type) {
+      case TokenType::Percent: {
+        if (has_next && tokens[i + 1]->token_type == TokenType::Identifier) {
+          out.push_back(
+              Token(current->line, current->column, TokenType::Local, tokens[i + 1]->data));
+          i += 1;  // the identifier has been merged into the Local token
+        } else if (has_next && tokens[i + 1]->token_type == TokenType::Integer) {
+          out.push_back(
+              Token(current->line, current->column, TokenType::Graph, tokens[i + 1]->data));
+          i += 1;  // the integer has been merged into the Graph token
+        } else {
+          out.push_back(current);
+        }
+        break;
+      }
+      case TokenType::At: {
+        if (has_next && tokens[i + 1]->token_type == TokenType::Identifier) {
+          out.push_back(
+              Token(current->line, current->column, TokenType::Global, tokens[i + 1]->data));
+          i += 1;  // the identifier has been merged into the Global token
+        } else {
+          out.push_back(current);
+        }
+        break;
+      }
+      case TokenType::Identifier: {
+        std::string str = Downcast<tvm::String>(current->data);
+        if (str == "True") {
+          out.push_back(
+              Token(current->line, current->column, TokenType::Boolean, tvm::Integer(1)));
+        } else if (str == "False") {
+          out.push_back(
+              Token(current->line, current->column, TokenType::Boolean, tvm::Integer(0)));
+        } else if (str == "_") {
+          out.push_back(Token(current->line, current->column, TokenType::Underscore));
+        } else {
+          out.push_back(current);
+        }
+        break;
+      }
+      default: {
+        out.push_back(current);
+        break;
+      }
+    }
+  }
+
+  return out;
+}
+
+/*!
+ * \brief Tokenize `source` and condense the result.
+ *
+ * Marked inline because the definition lives in a header.
+ *
+ * \param source The text to tokenize.
+ * \return The condensed token stream; every token in it is checked to be defined.
+ */
+inline std::vector<Token> Tokenize(std::string source) {
+  // The tokenizer only keeps a reference; `source` lives until we return.
+  Tokenizer tokenizer(source);
+  tokenizer.Tokenize();
+  std::vector<Token> tokens = Condense(tokenizer.tokens);
+  for (const Token& token : tokens) {
+    CHECK(token.defined());
+  }
+  return tokens;
+}
+
+} // namespace parser
+} // namespace tvm
+
+#endif // TVM_PARSER_TOKENIZER_H_
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index 5a71023..fed257f 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -33,8 +33,8 @@ def check_json_roundtrip(node):
def test_span():
span = relay.Span(None, 1, 1)
assert span.source == None
- assert span.lineno == 1
- assert span.col_offset == 1
+ assert span.line == 1
+ assert span.column == 1
assert span.same_as(span)
assert span == span
assert isinstance(span, relay.base.Span)
@@ -44,8 +44,8 @@ def test_span():
# to test the round trip
back = tvm.ir.load_json(tvm.ir.save_json(span))
assert back.source == span.source
- assert back.lineno == span.lineno
- assert back.col_offset == span.col_offset
+ assert back.line == span.line
+ assert back.column == span.column
def test_constant():
diff --git a/tests/python/relay/test_ir_parser2.py b/tests/python/relay/test_ir_parser2.py
new file mode 100644
index 0000000..23ba1fa
--- /dev/null
+++ b/tests/python/relay/test_ir_parser2.py
@@ -0,0 +1,891 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+from tvm import relay
+import pytest
+from numpy import isclose
+from typing import Union
+from functools import wraps
+# Decorator for tests whose body is expected to raise a TVMError.
+raises_parse_error = pytest.mark.xfail(raises=tvm._ffi.base.TVMError)
+
+# NOTE(review): not referenced anywhere else in the visible part of this file.
+SEMVER = "v0.0.4"
+
+# Binary operator spelling -> Relay function that builds the matching call.
+BINARY_OPS = {
+ "*": relay.multiply,
+ "/": relay.divide,
+ "+": relay.add,
+ "-": relay.subtract,
+ "<": relay.less,
+ ">": relay.greater,
+ "<=": relay.less_equal,
+ ">=": relay.greater_equal,
+ "==": relay.equal,
+ "!=": relay.not_equal,
+}
+
+# Type names that test_builtin_types uses as let-binding annotations.
+TYPES = {
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+
+ "float16",
+ "float32",
+ "float64",
+
+ "bool",
+
+ "int8x4",
+ "uint1x4",
+ "float16x4",
+}
+
+# Source text of a generic List ADT, spliced into several module-level tests.
+LIST_DEFN = """
+type List[A] {
+ Cons(A, List[A]),
+ Nil,
+}
+"""
+
+# Assert that lhs and rhs are structurally equal, with free variables mapped.
+def assert_graph_equal(lhs, rhs):
+ tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
+
+# Return the result of the structural comparison of lhs and rhs (free variables
+# mapped) instead of asserting it.
+def graph_equal(lhs, rhs):
+ return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
+
+
+def roundtrip_expr(expr):
+ x = tvm.parser.parse_expr(str(str(expr)))
+ assert_graph_equal(x, expr)
+
+def roundtrip(expr):
+ x = tvm.parser.fromtext(expr.astext())
+ assert_graph_equal(x, expr)
+
+def parse_text(code):
+ expr = tvm.parser.parse_expr(code)
+ roundtrip_expr(expr)
+ return expr
+
+
+def parses_as(code, expr):
+ # type: (str, relay.Expr) -> bool
+ parsed = parse_text(code)
+ result = graph_equal(parsed, expr)
+ return result
+
+def parse_module(code):
+ mod = tvm.parser.parse(code)
+ roundtrip(mod)
+ return mod
+
+
+def assert_parses_as(code, expr):
+ parsed = parse_text(code)
+ assert_graph_equal(parsed, expr)
+
+def assert_parse_module_as(code, mod):
+ parsed = parse_module(code)
+ assert_graph_equal(parsed, mod)
+
+def get_scalar(x):
+ # type: (relay.Constant) -> (Union[float, int, bool])
+ return x.data.asnumpy().item()
+
+# Scalar int32 type shared by the annotated fixtures and the type tests.
+int32 = relay.scalar_type("int32")
+
+# Variables reused as expected values throughout the tests.
+_ = relay.Var("_")
+X = relay.Var("x")
+Y = relay.Var("y")
+# The same names, annotated with int32.
+X_ANNO = relay.Var("x", int32)
+Y_ANNO = relay.Var("y", int32)
+
+# The empty tuple, written `()` in the text format.
+UNIT = relay.Tuple([])
+
+
+def test_comments():
+ assert_parses_as(
+ """
+ // This is a line comment!
+ ()
+ """,
+ UNIT
+ )
+
+ assert_parses_as(
+ """
+ /* This is a block comment!
+ This is still a block comment!
+ */
+ ()
+ """,
+ UNIT
+ )
+
+ assert_parses_as(
+ """
+ /* This is a block comment!
+ /*Block comment is recursive!*/
+ */
+ ()
+ """,
+ UNIT
+ )
+
+
+def test_int_literal():
+ assert isinstance(parse_text("1"), relay.Constant)
+ assert isinstance(parse_text("1").data, tvm.nd.NDArray)
+
+ assert get_scalar(parse_text("1")) == 1
+ assert get_scalar(parse_text("10")) == 10
+ assert get_scalar(parse_text("0")) == 0
+ assert get_scalar(parse_text("-100")) == -100
+ assert get_scalar(parse_text("-05")) == -5
+
+
+def test_float_literal():
+ assert get_scalar(parse_text("1.0f")) == 1.0
+ assert isclose(get_scalar(parse_text("1.56667f")), 1.56667)
+ assert get_scalar(parse_text("0.0f")) == 0.0
+ assert get_scalar(parse_text("-10.0f")) == -10.0
+
+ # scientific notation
+ assert isclose(get_scalar(parse_text("1e-1f")), 1e-1)
+ assert get_scalar(parse_text("1e+1f")) == 1e+1
+ assert isclose(get_scalar(parse_text("1E-1f")), 1E-1)
+ assert get_scalar(parse_text("1E+1f")) == 1E+1
+ assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1)
+ assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1
+ assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1)
+ assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1
+
+
+def test_bool_literal():
+ assert get_scalar(parse_text("True")) == True
+ assert get_scalar(parse_text("False")) == False
+
+
+def test_negative():
+ # need to handle parsing non-literal operations
+ # assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
+ assert get_scalar(parse_text("--10")) == 10
+ assert get_scalar(parse_text("---10")) == -10
+
+
+def test_bin_op():
+ for bin_op in BINARY_OPS.keys():
+ assert_parses_as(
+ "1 {} 1".format(bin_op),
+ BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
+ )
+
+
+def test_parens():
+ assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
+ assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
+
+
+def test_op_assoc():
+ assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
+ assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
+
+
+def test_vars():
+ # var
+ var = parse_text("let %foo = (); %foo")
+ assert isinstance(var.body, relay.Var)
+ assert var.body.name_hint == "foo"
+
+ # global var
+ global_var = parse_text("@foo")
+ assert isinstance(global_var, relay.GlobalVar)
+ assert global_var.name_hint == "foo"
+
+ # operator id
+ op = parse_text("add")
+ assert isinstance(op, tvm.ir.Op)
+ assert op.name == "add"
+
+
+def test_let():
+ assert_parses_as(
+ "let %x = 1; ()",
+ relay.Let(
+ X,
+ relay.const(1),
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ """
+ let %x = 1;
+ let %y = 2;
+ ()
+ """,
+ relay.Let(
+ X,
+ relay.const(1),
+ relay.Let(
+ Y,
+ relay.const(2),
+ UNIT
+ )
+ )
+ )
+
+
+def test_seq():
+ assert_parses_as(
+ "(); ()",
+ relay.Let(
+ _,
+ UNIT,
+ UNIT)
+ )
+
+ assert_parses_as(
+ "let %_ = 1; ()",
+ relay.Let(
+ X,
+ relay.const(1),
+ UNIT
+ )
+ )
+
+
+def test_graph():
+ code = "%0 = (); %1 = 1; (%0, %0, %1)"
+ assert_parses_as(
+ code,
+ relay.Tuple([UNIT, UNIT, relay.const(1)])
+ )
+
+
+@raises_parse_error
+def test_graph_wrong_order():
+ parse_text("%1 = (); %1")
+
+
+@raises_parse_error
+def test_let_global_var():
+ parse_text("let @x = 1; ()")
+
+
+@raises_parse_error
+def test_let_op():
+ parse_text("let x = 1; ()")
+
+
+def test_tuple():
+ assert_parses_as("()", relay.Tuple([]))
+
+ assert_parses_as("(0,)", relay.Tuple([relay.const(0)]))
+
+ assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
+
+ assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
+
+
+def test_func():
+ # 0 args
+ assert_parses_as(
+ "fn () { 0 }",
+ relay.Function(
+ [],
+ relay.const(0),
+ None,
+ []
+ )
+ )
+
+ # 1 arg
+ assert_parses_as(
+ "fn (%x) { %x }",
+ relay.Function(
+ [X],
+ X,
+ None,
+ []
+ )
+ )
+
+ # 2 args
+ assert_parses_as(
+ "fn (%x, %y) { %x + %y }",
+ relay.Function(
+ [X, Y],
+ relay.add(X, Y),
+ None,
+ []
+ )
+ )
+
+ # annotations
+ assert_parses_as(
+ "fn (%x: int32) -> int32 { %x }",
+ relay.Function(
+ [X_ANNO],
+ X_ANNO,
+ int32,
+ []
+ )
+ )
+
+ # Refactor the attribute syntax and printing.
+ #
+ # # attributes
+ # assert_parses_as(
+ # "fn (n=5) { () }",
+ # relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
+ # )
+
+
+# TODO(@jmp): Crashes if %x isn't annnotated.
+def test_defn():
+ id_defn = parse_module(
+ """
+ def @id(%x: int32) -> int32 {
+ %x
+ }
+ """)
+ assert isinstance(id_defn, tvm.IRModule)
+
+
+def test_recursive_call():
+ id_defn = parse_module(
+ """
+ def @id(%x: int32) -> int32 {
+ @id(%x)
+ }
+ """)
+ assert isinstance(id_defn, tvm.IRModule)
+
+
+def test_ifelse():
+ assert_parses_as(
+ """
+ if (True) {
+ 0
+ } else {
+ 1
+ }
+ """,
+ relay.If(
+ relay.const(True),
+ relay.const(0),
+ relay.const(1)
+ )
+ )
+
+
+@raises_parse_error
+def test_ifelse_scope():
+ parse_text(
+ """
+ if (True) {
+ let %x = ();
+ ()
+ } else {
+ %x
+ }
+ """
+ )
+
+
+def test_call():
+ # select right function to call: simple ident case
+ id_func = relay.Var("id")
+ assert_parses_as(
+ """
+ let %id = fn (%x) { %x };
+ 10 * %id(10)
+ """,
+ relay.Let(
+ id_func,
+ relay.Function([X], X, None, []),
+ relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
+ )
+ )
+
+ # 0 args
+ constant = relay.Var("constant")
+ assert_parses_as(
+ """
+ let %constant = fn () { 0 };
+ %constant()
+ """,
+ relay.Let(
+ constant,
+ relay.Function([], relay.const(0), None, []),
+ relay.Call(constant, [], None, None)
+ )
+ )
+
+ # 1 arg
+ id_var = relay.Var("id")
+ assert_parses_as(
+ """
+ let %id = fn (%x) { %x };
+ %id(1)
+ """,
+ relay.Let(
+ id_var,
+ relay.Function([X], X, None, []),
+ relay.Call(id_var, [relay.const(1)], None, None)
+ )
+ )
+
+ # 2 args
+ multiply = relay.Var("multiply")
+ assert_parses_as(
+ """
+ let %multiply = fn (%x, %y) { %x * %y };
+ %multiply(0, 0)
+ """,
+ relay.Let(
+ multiply,
+ relay.Function(
+ [X, Y],
+ relay.multiply(X, Y),
+ None,
+ []
+ ),
+ relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
+ )
+ )
+
+ # anonymous function
+ assert_parses_as(
+ """
+ (fn (%x) { %x })(0)
+ """,
+ relay.Call(
+ relay.Function(
+ [X],
+ X,
+ None,
+ []
+ ),
+ [relay.const(0)],
+ None,
+ None
+ )
+ )
+
+ # curried function
+ curried_mult = relay.Var("curried_mult")
+ assert_parses_as(
+ """
+ let %curried_mult =
+ fn (%x) {
+ fn (%y) {
+ %x * %y
+ }
+ };
+ %curried_mult(0);
+ %curried_mult(0)(0)
+ """,
+ relay.Let(
+ curried_mult,
+ relay.Function(
+ [X],
+ relay.Function(
+ [Y],
+ relay.multiply(X, Y),
+ None,
+ []
+ ),
+ None,
+ []
+ ),
+ relay.Let(
+ _,
+ relay.Call(curried_mult, [relay.const(0)], None, None),
+ relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
+ )
+ )
+ )
+
+ # op
+ assert_parses_as(
+ "abs(1)",
+ relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
+ )
+
+# Types
+
+
+def test_incomplete_type():
+ assert_parses_as(
+ "let %_ : _ = (); ()",
+ relay.Let(
+ _,
+ UNIT,
+ UNIT
+ )
+ )
+
+
+def test_builtin_types():
+ for builtin_type in TYPES:
+ parse_text("let %_ : {} = (); ()".format(builtin_type))
+
+
+def test_tensor_type():
+ assert_parses_as(
+ "let %_ : Tensor[(), float32] = (); ()",
+ relay.Let(
+ relay.Var("_", relay.TensorType((), "float32")),
+ UNIT,
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ "let %_ : Tensor[(1), float32] = (); ()",
+ relay.Let(
+ relay.Var("_", relay.TensorType((1,), "float32")),
+ UNIT,
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ "let %_ : Tensor[(1, 1), float32] = (); ()",
+ relay.Let(
+ relay.Var("_", relay.TensorType((1, 1), "float32")),
+ UNIT,
+ UNIT
+ )
+ )
+
+
+def test_function_type():
+ assert_parses_as(
+ """
+ let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.FuncType([], int32, [], [])),
+ relay.Function([], relay.const(0), int32, []),
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ """
+ let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.FuncType([int32], int32, [], [])),
+ relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ """
+ let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
+ relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
+ UNIT
+ )
+ )
+
+
+def test_tuple_type():
+ assert_parses_as(
+ """
+ let %_: () = (); ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.TupleType([])),
+ UNIT,
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ """
+ let %_: (int32,) = (0,); ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.TupleType([int32])),
+ relay.Tuple([relay.const(0)]),
+ UNIT
+ )
+ )
+
+ assert_parses_as(
+ """
+ let %_: (int32, int32) = (0, 1); ()
+ """,
+ relay.Let(
+ relay.Var("_", relay.TupleType([int32, int32])),
+ relay.Tuple([relay.const(0), relay.const(1)]),
+ UNIT
+ )
+ )
+
+
+def test_adt_defn():
+ mod = tvm.IRModule()
+
+ glob_typ_var = relay.GlobalTypeVar("Ayy")
+ prog = relay.TypeData(
+ glob_typ_var,
+ [],
+ [relay.Constructor("Nil", [], glob_typ_var)])
+ mod[glob_typ_var] = prog
+ assert_parse_module_as(
+ """
+ type Ayy { Nil }
+ """,
+ mod
+ )
+
+
+def test_empty_adt_defn():
+ mod = tvm.IRModule()
+
+ glob_typ_var = relay.GlobalTypeVar("Ayy")
+ prog = relay.TypeData(glob_typ_var, [], [])
+ mod[glob_typ_var] = prog
+ assert_parse_module_as(
+ """
+ type Ayy { }
+ """,
+ mod
+ )
+
+
+def test_multiple_cons_defn():
+ mod = tvm.IRModule()
+
+ list_var = relay.GlobalTypeVar("List")
+ typ_var = relay.TypeVar("A")
+ prog = relay.TypeData(
+ list_var,
+ [typ_var],
+ [
+ relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
+ relay.Constructor("Nil", [], list_var),
+ ])
+ mod[list_var] = prog
+ assert_parse_module_as(LIST_DEFN, mod)
+
+
+def test_multiple_type_param_defn():
+ glob_typ_var = relay.GlobalTypeVar("Either")
+ typ_var_a = relay.TypeVar("A")
+ typ_var_b = relay.TypeVar("B")
+ prog = relay.TypeData(
+ glob_typ_var,
+ [typ_var_a, typ_var_b],
+ [
+ relay.Constructor("Left", [typ_var_a], glob_typ_var),
+ relay.Constructor("Right", [typ_var_b], glob_typ_var),
+ ])
+ mod = tvm.IRModule()
+ mod[glob_typ_var] = prog
+ assert_parse_module_as(
+ """
+ type Either[A, B] {
+ Left(A),
+ Right(B),
+ }
+ """,
+ mod
+ )
+
+
+def test_match():
+ # pair each match keyword with whether it specifies a complete match or not
+ match_keywords = [("match", True), ("match?", False)]
+ for (match_keyword, is_complete) in match_keywords:
+ mod = tvm.IRModule()
+
+ list_var = relay.GlobalTypeVar("List")
+ typ_var = relay.TypeVar("A")
+ cons_constructor = relay.Constructor(
+ "Cons", [typ_var, list_var(typ_var)], list_var)
+ nil_constructor = relay.Constructor("Nil", [], list_var)
+ list_def = relay.TypeData(
+ list_var,
+ [typ_var],
+ [cons_constructor, nil_constructor])
+ mod[list_var] = list_def
+
+ length_var = relay.GlobalVar("length")
+ typ_var = relay.TypeVar("A")
+ input_type = list_var(typ_var)
+ input_var = relay.Var("xs", input_type)
+ rest_var = relay.Var("rest")
+ cons_case = relay.Let(
+ relay.var("", type_annotation=None),
+ UNIT,
+ relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
+ body = relay.Match(input_var,
+ [relay.Clause(
+ relay.PatternConstructor(
+ cons_constructor,
+ [relay.PatternWildcard(), relay.PatternVar(rest_var)]),
+ cons_case),
+ relay.Clause(
+ relay.PatternConstructor(nil_constructor, []),
+ relay.const(0))],
+ complete=is_complete
+ )
+ length_func = relay.Function(
+ [input_var],
+ body,
+ int32,
+ [typ_var]
+ )
+ mod[length_var] = length_func
+
+ assert_parse_module_as(
+ """
+ %s
+
+ def @length[A](%%xs: List[A]) -> int32 {
+ %s (%%xs) {
+ Cons(_, %%rest : List[A]) => {
+ ();
+ 1 + @length(%%rest)
+ },
+ Nil => 0,
+ }
+ }
+ """ % (LIST_DEFN, match_keyword),
+ mod
+ )
+
+
+def test_adt_cons_expr():
+ mod = tvm.IRModule()
+
+ list_var = relay.GlobalTypeVar("List")
+ typ_var = relay.TypeVar("A")
+ cons_constructor = relay.Constructor(
+ "Cons", [typ_var, list_var(typ_var)], list_var)
+ nil_constructor = relay.Constructor("Nil", [], list_var)
+ list_def = relay.TypeData(
+ list_var,
+ [typ_var],
+ [cons_constructor, nil_constructor])
+ mod[list_var] = list_def
+
+ make_singleton_var = relay.GlobalVar("make_singleton")
+ input_var = relay.Var("x", int32)
+ make_singleton_func = relay.Function(
+ [input_var],
+ cons_constructor(input_var, nil_constructor()),
+ list_var(int32)
+ )
+ mod[make_singleton_var] = make_singleton_func
+
+ assert_parse_module_as(
+ """
+ %s
+
+ def @make_singleton(%%x: int32) -> List[int32] {
+ Cons(%%x, Nil)
+ }
+ """ % LIST_DEFN,
+ mod
+ )
+
+
+@raises_parse_error
+def test_duplicate_adt_defn():
+ parse_module(
+ """
+ %s
+
+ type List[A] {
+ Cons(A, List[A]),
+ Nil,
+ }
+ """ % LIST_DEFN
+ )
+
+
+@raises_parse_error
+def test_duplicate_adt_cons():
+ parse_text(
+ """
+ type Ayy { Lmao }
+ type Haha { Lmao }
+ """
+ )
+
+
+@raises_parse_error
+def test_duplicate_adt_cons_defn():
+ parse_text(
+ """
+ type Ayy { Lmao }
+ type Lmao { Ayy }
+ """
+ )
+
+
+@raises_parse_error
+def test_duplicate_global_var():
+ parse_text(
+ """
+ def @id[A](%x: A) -> A { x }
+ def @id[A](%x: A) -> A { x }
+ """
+ )
+
+
+def test_extern_adt_defn():
+ # TODO(weberlo): update this test once extern is implemented
+ mod = tvm.IRModule()
+
+ extern_var = relay.GlobalTypeVar("T")
+ typ_var = relay.TypeVar("A")
+ extern_def = relay.TypeData(extern_var, [typ_var], [])
+ mod[extern_var] = extern_def
+
+ assert_parse_module_as(
+ """
+ extern type T[A]
+ """,
+ mod
+ )
+
+@pytest.mark.skip("not yet tested on parser 2.0")
+def test_import_grad():
+ mod = tvm.IRModule()
+ mod.import_from_std("gradient.rly")
+
+if __name__ == "__main__":
+ import sys
+ pytest.main(sys.argv)