You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/26 21:01:28 UTC

[GitHub] [tvm] sunggg commented on a diff in pull request #12520: [MetaSchedule][UX] Make `Database` with-able

sunggg commented on code in PR #12520:
URL: https://github.com/apache/tvm/pull/12520#discussion_r956439486


##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -38,7 +36,7 @@ def extract_task_from_relay(
     opt_level: int = 3,
     pass_config: Optional[Dict[str, Any]] = None,
     disabled_pass: Optional[List[str]] = None,
-    te_filter_func: Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None,
+    tir_converter: str = "default",

Review Comment:
   Do we pass an `ffi_key` here? If so, it might be good to clarify this in the docstring. 



##########
python/tvm/meta_schedule/database/database.py:
##########
@@ -234,6 +234,71 @@ def __len__(self) -> int:
         """
         return _ffi_api.DatabaseSize(self)  # type: ignore # pylint: disable=no-member
 
+    def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]:

Review Comment:
   Nit: do we also have APIs to fetch records other than the best one? Since the database stores every record (we keep records even when they are not the best), it may be good to offer an API to access them. This might be useful in certain cases, such as cost model research.



##########
src/relay/backend/utils.cc:
##########
@@ -368,6 +371,71 @@ void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params) {
   BindParamsInModule(mod, params_tmp);
 }
 
+Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,

Review Comment:
   Please add some comments describing what this function does. 



##########
src/relay/backend/te_compiler_cache.cc:
##########
@@ -359,32 +350,43 @@ class ScheduleBuilder : public ExprVisitor {
           schedule = Downcast<te::Schedule>(obj);
         }
       }
-      if (meta_schedule_ctx_) {
+      if (database_) {
+        using tvm::meta_schedule::TuningRecord;
+        using tvm::tir::IndexMap;
+        using tvm::tir::Instruction;
+        using tvm::tir::InstructionKind;
+        using tvm::tir::PrimFunc;
+        using tvm::tir::Schedule;
+        backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
         Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
         Array<runtime::NDArray> constants;
         for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
           te_args.push_back(te_tensor);
           constants.push_back(const_node->data);
         }
-
-        if (Optional<tir::PrimFunc> tir_func =
-                meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) {
-          IRModule relay_mod({{prim_fn_var, relay_func}});
-          IRModule tir_mod({{prim_fn_var, tir_func.value()}});
-          if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
-                  /*task_name=*/prim_fn_var->name_hint,     //
-                  /*mod=*/relay_mod,                        //
-                  /*target=*/target_,                       //
-                  /*dispatched=*/Array<IRModule>{tir_mod},  //
-                  /*f_take_tuning_record=*/ExtractTransformLayout)) {
-            IRModule scheduled_mod =
-                tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
-            ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
-            prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
+        if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
+          if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(

Review Comment:
   What is the fallback strategy for meta_schedule if a schedule is not found in the DB? 



##########
src/relay/backend/te_compiler_cache.cc:
##########
@@ -359,32 +350,43 @@ class ScheduleBuilder : public ExprVisitor {
           schedule = Downcast<te::Schedule>(obj);
         }
       }
-      if (meta_schedule_ctx_) {
+      if (database_) {
+        using tvm::meta_schedule::TuningRecord;
+        using tvm::tir::IndexMap;
+        using tvm::tir::Instruction;
+        using tvm::tir::InstructionKind;
+        using tvm::tir::PrimFunc;
+        using tvm::tir::Schedule;
+        backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
         Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
         Array<runtime::NDArray> constants;
         for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
           te_args.push_back(te_tensor);
           constants.push_back(const_node->data);
         }
-
-        if (Optional<tir::PrimFunc> tir_func =
-                meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) {
-          IRModule relay_mod({{prim_fn_var, relay_func}});
-          IRModule tir_mod({{prim_fn_var, tir_func.value()}});
-          if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
-                  /*task_name=*/prim_fn_var->name_hint,     //
-                  /*mod=*/relay_mod,                        //
-                  /*target=*/target_,                       //
-                  /*dispatched=*/Array<IRModule>{tir_mod},  //
-                  /*f_take_tuning_record=*/ExtractTransformLayout)) {
-            IRModule scheduled_mod =
-                tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
-            ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
-            prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
+        if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
+          if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(
+                  /*mod=*/backend::PrimFuncToIRModule(f.value()),
+                  /*target=*/target_)) {
+            static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
+            TuningRecord record = opt_record.value();
+            for (const Instruction& inst : record->trace->insts) {
+              if (inst->kind.same_as(kind_transform_layout)) {
+                ICHECK_EQ(inst->attrs.size(), 3);
+                MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<IndexMap>(inst->attrs[2]));

Review Comment:
   Could you explain what this does?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org