You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/06 02:11:29 UTC
[incubator-tvm] 05/21: Borrow code from Egg
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit b2b59c229e9b8c2002d8c8cd520748df6b38e074
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 15:26:54 2020 -0700
Borrow code from Egg
---
rust/compiler-ext/src/lib.rs | 344 ++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 337 insertions(+), 7 deletions(-)
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 31e1bb2..58bdd0c 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -1,7 +1,337 @@
-#[cfg(test)]
-mod tests {
- #[test]
- fn it_works() {
- assert_eq!(2 + 2, 4);
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ use std::os::raw::c_int;
+ use tvm::initialize;
+ use tvm::ir::{tir, PrimExpr};
+ use tvm::runtime::function::register_override;
+ use tvm::runtime::map::Map;
+ use tvm::runtime::object::{IsObject, IsObjectRef};
+
+ use ordered_float::NotNan;
+
+ mod interval;
+ mod math;
+
+ use math::{BoundsMap, Expr, RecExpr};
+ use tvm::ir::arith::ConstIntBound;
+ use tvm_rt::{ObjectRef, array::Array};
+
+ /// Tries to downcast `$id` to each listed type in order and evaluates the arm of
+ /// the first type whose `downcast_clone` succeeds, with `$id` rebound to that
+ /// downcast value inside the arm. If no type matches, evaluates the `else` arm
+ /// with the original `$id`.
+ ///
+ /// The expansion is a single `if let ... else if let ... else { $default }`
+ /// chain, so every arm and the default must produce the same type.
+ macro_rules! downcast_match {
+ // Arms may optionally be followed by a comma; a comma is required before `else`.
+ ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
+ // Each repetition ends in a bare `else`, which chains onto the next `if let`
+ // or, after the last one, onto the `{ $default }` block below.
+ $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
+ { $default }
+ }
+ }
+
+ /// Side tables that let `to_egg` / `from_egg` move TVM values in and out of egg.
+ ///
+ /// TIR variables are interned as `egg::Symbol`s built from their `name_hint`;
+ /// every other object is kept in `objs` and referred to by its index.
+ #[derive(Default)]
+ struct VarMap {
+     // Each recorded variable together with the symbol made from its name.
+     vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>,
+     // Opaque objects, addressed by their position in this vector.
+     objs: Vec<ObjectRef>,
+ }
+
+ impl VarMap {
+     /// Returns the symbol for `var`, derived from its `name_hint`.
+     ///
+     /// FIXME this should eventually do the right thing for TVM variables;
+     /// right now it depends on them having unique names. When the name is
+     /// already known, the previously recorded variable is kept and `var` is
+     /// not stored.
+     fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
+         let name = var.name_hint.as_str().unwrap();
+         let symbol = egg::Symbol::from(name);
+         if !self.vars.iter().any(|(_, known)| *known == symbol) {
+             self.vars.push((var, symbol));
+         }
+         symbol
+     }
+
+     /// Looks up the variable recorded for `sym`.
+     ///
+     /// Panics if `sym` was never produced by `make_symbol` on this map.
+     fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
+         match self.vars.iter().find(|(_, candidate)| *candidate == sym) {
+             Some((var, _)) => var.clone(),
+             None => panic!("Should have found a var"),
+         }
+     }
+
+     /// Stores `obj` (upcast to `ObjectRef`) and returns its index.
+     fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
+         self.objs.push(obj.upcast());
+         self.objs.len() - 1
+     }
+
+     /// Returns the object stored at index `i`, downcast to `T`.
+     ///
+     /// Panics if `i` is out of range or the object is not a `T`.
+     fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
+         let stored = self.objs[i].clone();
+         stored.downcast().expect("bad downcast")
+     }
+ }
+
+ /// Converts a TVM `PrimExpr` into an egg `RecExpr` over the `math::Expr` language.
+ ///
+ /// Structural nodes carry the TIR dtype as an `Expr::DataType` first child.
+ /// Variables are interned as symbols in `vars` (by name, via
+ /// `VarMap::make_symbol`). Anything that cannot be represented structurally is
+ /// stored in `vars` and referenced through an opaque `Expr::Object` index, which
+ /// `from_egg` resolves back to the same object. That covers unsupported node
+ /// types and NaN float immediates, which `NotNan` cannot hold.
+ fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
+     fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
+         // Recursively converts a child expression and yields its id.
+         macro_rules! r {
+             ($e:expr) => {
+                 build(vars, &$e, recexpr)
+             };
+         }
+
+         let dt = recexpr.add(Expr::DataType(p.datatype));
+         let e = downcast_match!(p; {
+             tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
+             tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
+             tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
+
+             tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
+             tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
+             tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
+             tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
+
+             tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
+             tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
+
+             tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
+             tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
+
+             tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
+             tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
+             tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
+             tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
+             tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
+             tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
+
+             tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
+             tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
+             tir::Not => Expr::Not([dt, r!(p.value)]),
+
+             tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
+
+             tir::Let => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+                 Expr::Let([dt, sym, r!(p.value), r!(p.body)])
+             }
+             tir::Var => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
+                 Expr::Var([dt, sym])
+             }
+             tir::IntImm => {
+                 let int = recexpr.add(Expr::Int(p.value));
+                 Expr::IntImm([dt, int])
+             }
+             tir::FloatImm => {
+                 // `NotNan` rejects NaN. Keep such constants as opaque objects
+                 // instead of panicking inside the simplifier callback.
+                 match NotNan::new(p.value) {
+                     Ok(value) => {
+                         let float = recexpr.add(Expr::Float(value));
+                         Expr::FloatImm([dt, float])
+                     }
+                     Err(_) => Expr::Object(vars.push_obj(p.clone())),
+                 }
+             }
+             tir::Cast => Expr::Cast([dt, r!(p.value)]),
+
+             tir::Call => {
+                 // The callee is opaque to egg; only the arguments are converted.
+                 let op = vars.push_obj(p.op.clone());
+                 let mut arg_ids = vec![dt];
+                 for i in 0..p.args.len() {
+                     let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
+                     arg_ids.push(r!(arg));
+                 }
+                 Expr::Call(op, arg_ids)
+             },
+             tir::Load => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+                 Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
+             },
+             else => {
+                 // The text is purely diagnostic: a failure to print must not
+                 // abort the fallback conversion below.
+                 let text = tvm::runtime::debug_print(p.clone().upcast())
+                     .ok()
+                     .and_then(|s| s.to_str().ok().map(|t| t.to_owned()))
+                     .unwrap_or_else(|| String::from("<unprintable>"));
+                 log::warn!("Failed to downcast type '{}': {}", p.type_key(), text);
+                 Expr::Object(vars.push_obj(p.clone()))
+             }
+         });
+
+         recexpr.add(e)
+     }
+
+     let mut recexpr = Default::default();
+     build(vars, prim, &mut recexpr);
+     recexpr
+ }
+
+ /// Converts an egg `RecExpr` (such as the one extracted in `run`) back into a
+ /// TVM `PrimExpr`.
+ ///
+ /// `vars` must be the `VarMap` filled by `to_egg` for the original input,
+ /// because symbols and `Expr::Object` indices are resolved through it.
+ ///
+ /// # Panics
+ ///
+ /// Panics on malformed input:
+ /// - a symbol slot that is not an `Expr::Symbol`;
+ /// - a symbol or object index unknown to `vars`;
+ /// - a lane count that is not an int or does not fit in `i32`;
+ /// - a node kind with no TIR counterpart.
+ ///
+ /// Every rebuilt node is also asserted to have non-zero dtype bits and lanes.
+ fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
+     fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
+         use std::convert::TryFrom;
+
+         // `nodes.last()` is the node being rebuilt; a child is rebuilt from the
+         // prefix of `nodes` that ends at the child's id.
+         let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]);
+         let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
+         // Resolves a child slot that must hold an `Expr::Symbol`.
+         let get_var = |i: &egg::Id| -> tir::Var {
+             match &nodes[usize::from(*i)] {
+                 Expr::Symbol(sym) => vars.get_symbol(*sym),
+                 n => panic!("Expected a symbol, got {:?}", n),
+             }
+         };
+         // Reads a lane count stored as an int node. The TIR constructors take
+         // `i32`, so reject values that an `as` cast would silently truncate.
+         let get_lanes = |kind: &str, i: &egg::Id| -> i32 {
+             let node = &nodes[usize::from(*i)];
+             let lanes = node
+                 .to_int()
+                 .unwrap_or_else(|| panic!("{} lanes must be an int, got {:?}", kind, node));
+             i32::try_from(lanes)
+                 .unwrap_or_else(|_| panic!("{} lanes do not fit in i32: {}", kind, lanes))
+         };
+
+         let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
+             Expr::Var([_dt, s]) => get_var(s).upcast(),
+             Expr::IntImm([dt, v]) => {
+                 let value = nodes[usize::from(*v)].to_int().unwrap();
+                 tir::IntImm::new(get_dt(dt), value).upcast()
+             }
+             Expr::FloatImm([dt, v]) => {
+                 let value = nodes[usize::from(*v)].to_float().unwrap();
+                 tir::FloatImm::new(get_dt(dt), value).upcast()
+             }
+             Expr::Let([dt, s, value, body]) => {
+                 tir::Let::new(get_dt(dt), get_var(s).upcast(), go(value), go(body)).upcast()
+             }
+             // Children mirror `to_egg`: [dtype, buffer var, index, predicate].
+             Expr::Load([dt, s, index, predicate]) => {
+                 tir::Load::new(get_dt(dt), get_var(s).upcast(), go(index), go(predicate)).upcast()
+             }
+
+             Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
+
+             Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
+
+             Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
+
+             Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
+
+             Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
+
+             Expr::Ramp([dt, base, stride, lanes]) => {
+                 let lanes = get_lanes("Ramp", lanes);
+                 tir::Ramp::new(get_dt(dt), go(base), go(stride), lanes).upcast()
+             }
+             Expr::Broadcast([dt, value, lanes]) => {
+                 let lanes = get_lanes("Broadcast", lanes);
+                 tir::Broadcast::new(get_dt(dt), go(value), lanes).upcast()
+             }
+
+             Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
+             Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
+             // args[0] is the dtype; the callee was stored opaquely by `to_egg`.
+             Expr::Call(op, args) => {
+                 let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
+                 let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
+                 tir::Call::new(get_dt(&args[0]), vars.get_obj(*op), arg_exprs).upcast()
+             }
+
+             Expr::Object(i) => vars.get_obj(*i),
+             node => panic!("I don't know how to extract {:?}", node),
+         };
+         assert_ne!(prim.datatype.bits(), 0);
+         assert_ne!(prim.datatype.lanes(), 0);
+         prim
+     }
+     build(vars, recexpr.as_ref())
+ }
+
+ /// Simplifies `input` by equality saturation with `math::rules()` and converts
+ /// the cheapest equivalent expression (under `math::CostFn`) back to a `PrimExpr`.
+ ///
+ /// `map` supplies constant integer bounds. Only entries keyed by a `tir::Var`
+ /// are used, matched to e-graph symbols by `name_hint`; other keys are logged
+ /// and ignored.
+ ///
+ /// When `expected` is given, it is added to the saturated e-graph and must land
+ /// in the same e-class as the input. If its cost differs from the extracted
+ /// expression's cost, a cheaper extraction is only logged.
+ ///
+ /// # Errors
+ ///
+ /// Returns a human-readable message if `expected` is not in the input's
+ /// e-class. Also returns one if the costs differ and the extracted cost is not
+ /// lower than `expected`'s.
+ fn run(
+     input: PrimExpr,
+     expected: Option<PrimExpr>,
+     map: Map<PrimExpr, ConstIntBound>,
+ ) -> Result<PrimExpr, String> {
+     use egg::{CostFunction, Extractor};
+
+     let mut bounds = BoundsMap::default();
+     for (k, v) in map {
+         if let Ok(var) = k.downcast_clone::<tir::Var>() {
+             let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
+             bounds.insert(sym, (v.min_value, v.max_value));
+         } else {
+             log::warn!("Non var in bounds map: {}", tvm::ir::as_text(k));
+         }
+     }
+
+     let mut vars = VarMap::default();
+     let expr = to_egg(&mut vars, &input);
+     let mut runner = math::default_runner();
+     runner.egraph.analysis.bounds = bounds;
+
+     let mut runner = runner.with_expr(&expr).run(&math::rules());
+     // The extractor borrows the e-graph, so it must be finished before the
+     // e-graph is mutated by `add_expr` below.
+     let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
+     let root = runner.egraph.find(runner.roots[0]);
+     let (cost, best) = extractor.find_best(root);
+
+     if let Some(expected) = expected {
+         let mut expected_vars = VarMap::default();
+         let expected_expr = to_egg(&mut expected_vars, &expected);
+         let expected_root = runner.egraph.add_expr(&expected_expr);
+         if expected_root != root {
+             return Err(format!(
+                 "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
+                 expected_expr.pretty(40),
+                 best.pretty(40)
+             ));
+         }
+         let expected_cost = math::CostFn.cost_rec(&expected_expr);
+         if expected_cost != cost {
+             let msg = format!(
+                 "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
+                 expected_cost,
+                 expected_expr.pretty(40),
+                 cost,
+                 best.pretty(40)
+             );
+             if cost < expected_cost {
+                 log::info!("egg wins: {}", msg);
+             } else {
+                 return Err(msg);
+             }
+         }
+     }
+     log::info!(" returning... {}", best.pretty(60));
+     Ok(from_egg(&vars, &best))
+ }
+
+ /// Entry point registered as "egg.simplify": simplifies `prim` under the
+ /// variable bounds in `map` without checking against an expected result.
+ ///
+ /// Any error message from `run` is surfaced as `tvm::Error::CallFailed`.
+ fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
+     log::debug!("map: {:?}", map);
+     match run(prim, None, map) {
+         Ok(simplified) => Ok(simplified),
+         Err(message) => Err(tvm::Error::CallFailed(message)),
+     }
+ }
+
+ /// Entry point registered as "egg.simplify_and_check": simplifies `prim` under
+ /// the bounds in `map` and verifies the result against `check` (see `run`).
+ ///
+ /// Any error message from `run` is surfaced as `tvm::Error::CallFailed`.
+ fn simplify_and_check(
+     prim: PrimExpr,
+     check: PrimExpr,
+     map: Map<PrimExpr, ConstIntBound>,
+ ) -> Result<PrimExpr, tvm::Error> {
+     log::debug!("check map: {:?}", map);
+     let expected = Some(check);
+     match run(prim, expected, map) {
+         Ok(simplified) => Ok(simplified),
+         Err(message) => Err(tvm::Error::CallFailed(message)),
+     }
+ }
+
+ // Extension initialization hook. NOTE(review): when `tvm::initialize!` runs this
+ // block is defined by that macro, which is not visible here — presumably when TVM
+ // loads this library; confirm.
+ //
+ // Registers `simplify` and `simplify_and_check` under the global names
+ // "egg.simplify" and "egg.simplify_and_check"; a failed registration panics.
+ // The `true` argument presumably allows replacing an existing registration
+ // (going by the `register_override` name) — TODO confirm.
+ initialize!({
+ // The result is deliberately ignored: failing to install the logger (e.g. one
+ // is already set) should not prevent registration.
+ let _ = env_logger::try_init();
+ // NOTE this print prevents a segfault (on Linux) for now...
+ println!("Initializing simplifier... ");
+ register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
+ register_override(simplify_and_check, "egg.simplify_and_check", true)
+ .expect("failed to initialize simplifier");
+ log::debug!("done!");
+ });
+
\ No newline at end of file