You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/07 16:15:28 UTC

[GitHub] [tvm] jwfromm commented on a change in pull request #8885: [PROFILING] Profiling over RPC

jwfromm commented on a change in pull request #8885:
URL: https://github.com/apache/tvm/pull/8885#discussion_r703648969



##########
File path: python/tvm/runtime/profiler_vm.py
##########
@@ -35,10 +37,17 @@ class VirtualMachineProfiler(vm.VirtualMachine):
 
     def __init__(self, exe, device, memory_cfg=None):
         super(VirtualMachineProfiler, self).__init__(exe, device, memory_cfg)
-        self.module = _ffi_api._VirtualMachineDebug(exe.module)
+
+        # Make sure the constructor of the VM module is on the proper device
+        if device.device_type >= rpc_base.RPC_SESS_MASK:

Review comment:
       Checking the device type with `>=` is kind of confusing. If there isn't a better way to do this check, can you explain in a comment why the `>=` works?

##########
File path: src/runtime/profiling.cc
##########
@@ -540,6 +543,72 @@ Report::Report(Array<Map<String, ObjectRef>> calls,
   data_ = std::move(node);
 }
 
+Map<String, ObjectRef> parse_metrics(dmlc::JSONReader* reader) {
+  reader->BeginObject();
+  std::string metric_name, metric_value_name;
+  Map<String, ObjectRef> metrics;
+  while (reader->NextObjectItem(&metric_name)) {
+    ObjectRef o;
+    reader->BeginObject();
+    reader->NextObjectItem(&metric_value_name);
+    if (metric_value_name == "microseconds") {
+      double microseconds;
+      reader->Read(&microseconds);
+      o = ObjectRef(make_object<DurationNode>(microseconds));
+    } else if (metric_value_name == "percent") {
+      double percent;
+      reader->Read(&percent);
+      o = ObjectRef(make_object<PercentNode>(percent));
+    } else if (metric_value_name == "count") {
+      int64_t count;
+      reader->Read(&count);
+      o = ObjectRef(make_object<CountNode>(count));
+    } else if (metric_value_name == "string") {
+      std::string s;
+      reader->Read(&s);
+      o = String(s);
+    } else {
+      LOG(FATAL) << "Cannot parse metric of type " << metric_value_name
+                 << " valid types are microseconds, percent, count.";
+    }
+    metrics.Set(metric_name, o);
+    // Necessary to make sure that the parser hits the end of the object.
+    ICHECK(!reader->NextObjectItem(&metric_value_name));
+    // EndObject does not exist, leaving this here for clarity
+    // reader.EndObject();
+  }
+  // reader.EndObject();
+  return metrics;
+}
+
+Report Report::FromJSON(String json) {
+  std::stringstream input(json.operator std::string());
+  dmlc::JSONReader reader(&input);
+  std::string key;
+  Array<Map<String, ObjectRef>> calls;
+  Map<String, Map<String, ObjectRef>> device_metrics;
+
+  reader.BeginObject();
+  while (reader.NextObjectItem(&key)) {
+    if (key == "calls") {
+      reader.BeginArray();
+      while (reader.NextArrayItem()) {
+        calls.push_back(parse_metrics(&reader));
+      }
+      // reader.EndArray();

Review comment:
       Should these commented-out lines be removed?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org