You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/22 10:37:25 UTC

[GitHub] [incubator-tvm] mbarrett97 commented on a change in pull request #4151: [AutoTVM] Added an autotuning Config Library to store autotune results

mbarrett97 commented on a change in pull request #4151: [AutoTVM] Added an autotuning Config Library to store autotune results
URL: https://github.com/apache/incubator-tvm/pull/4151#discussion_r349532311
 
 

 ##########
 File path: python/tvm/autotvm/config_library.py
 ##########
 @@ -0,0 +1,279 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=arguments-differ, method-hidden, inconsistent-return-statements
+"""Config Library to store tuning configs"""
+import json
+import os
+from _ctypes import PyObj_FromPtr
+from shutil import copyfile
+from pathlib import Path
+
+import numpy as np
+
+from . import record
+from .task import ApplyHistoryBest
+
+
+class ConfigLibraryException(Exception):
+    pass
+
+
+class ConfigLibrary:
+    """A library to store auto-tuning results for any number of targets/workloads.
+
+    Parameters
+    ----------
+    library_dir: str
+        Path to the config library directory. If the library does not already
+        exist, a new library will be initialised in the directory. This will
+        create an 'index.json' file in the directory which contains the location
+        of all other files used to store configs in the library.
+
+    """
+
+    LIBRARY_INDEX_FILE_NAME = "index.json"
+    JOBS_INDEX_FILE_NAME = "jobs.json"
+    JOBS_DIRECTORY_NAME = "jobs"
+
+    def __init__(self, library_dir):
+        # Handle if the directory doesn't exist
+        if not os.path.isdir(library_dir):
+            os.mkdir(library_dir)
+
+        index_file = os.path.join(library_dir, self.LIBRARY_INDEX_FILE_NAME)
+        if not os.path.isfile(index_file):
+            with open(index_file, "w") as f:
+                full_index_path = os.path.abspath(f.name)
+                index = {
+                    "root": os.path.dirname(full_index_path),
+                    "targets": {},
+                    "jobs_index": self.JOBS_INDEX_FILE_NAME,
+                    "jobs_dir": self.JOBS_DIRECTORY_NAME,
+                }
+                json.dump(index, f, indent=4)
+
+            with open(os.path.join(library_dir, self.JOBS_INDEX_FILE_NAME), "w") as f:
+                json.dump({}, f)
+
+        self.library_dir = library_dir
+        self.library_index = index_file
+        self.jobs_dir = os.path.join(self.library_dir, self.JOBS_DIRECTORY_NAME)
+        self.jobs_index = os.path.join(self.library_dir, self.JOBS_INDEX_FILE_NAME)
+        if not os.path.isdir(self.jobs_dir):
+            os.makedirs(self.jobs_dir)
+
+    def load(self, target):
+        """Load the configs for a given TVM target string.
+
+        Returns a DispatchContext with the appropriate configs loaded."""
+        target_configs = self._load_target_configs(target)
+        return ApplyHistoryBest(target_configs)
+
+    def _load_target_configs(self, target):
+        """Yield the configs in the library for a given target."""
+        target_file = self.get_config_file(target, create=False)
+        if target_file:
+            with open(target_file) as f:
+                configs = json.load(f)
+                for config in configs.values():
+                    row = json.dumps(config)
+                    yield record.decode(row)
+
+        else:
+            yield from []
+
+    def save_job(self, job, save_history=True):
+        """Save the results of an auto-tuning job to the library.
+
+        Parameters
+        ----------
+        job: TuningJob
+            The auto-tuning job to save.
+        save_history: bool
+            Whether to save the history log of the job.
+
+        """
+        with open(self.jobs_index, 'r+') as f:
+            job_index = json.load(f)
+            highest_job_id = 0
+            for job_id in job_index:
+                highest_job_id = max(highest_job_id, int(job_index[job_id]["id"]))
+
+            job_id = str(highest_job_id + 1)
+            job_log = None
+            if save_history:
+                job_log = self._create_job_log(job_id, job.target)
+                copyfile(job.log, job_log)
+
+            job_entry = {
+                "id": job_id,
+                "log": job_log,
+                "target": job.target,
+                "platform": job.platform,
+                "content": job.content,
+                "start_time": job.start_time,
+                "end_time": job.end_time,
+                "tasks": [],
+            }
+
+            for workload in job.results_by_workload:
+                inp, best_result, tuner_name, trials = job.results_by_workload[workload]
+                config_entry_str = record.encode(inp, best_result, 'json')
+                config_entry = json.loads(config_entry_str)
+                config_entry["t"] = [job_id, tuner_name, trials]
+                task_entry = json.dumps(config_entry)
+                self.save_config(config_entry)
+                job_entry["tasks"].append(task_entry)
+
+            job_index[job_id] = job_entry
+            f.seek(0)
+            json.dump(job_index, f, indent=4)
+
+
+    def _create_job_log(self, job_id, target):
+        """Returns a path to a job log file.
+
+        This will delete any log that exists with the same name."""
+        log_name = job_id + "_" + (
+            target.replace(" ", "").replace("-device=", "_").replace("-model=", "_")
+        )
+        job_log = os.path.join(self.jobs_dir, log_name + ".log")
+        if os.path.isfile(job_log):
+            os.remove(job_log)
+
+        Path(job_log).touch()
+        return job_log
+
+    def save_config(self, new_config):
+        """Save a config to the library if it's better than existing entries."""
+        target = new_config["i"][0]
+        workload = str(new_config["i"][4])
+        new_config_key = workload
+        config_file = self.get_config_file(target)
+        with open(config_file, 'r+') as f:
+            existing_configs = json.load(f)
+            if new_config_key in existing_configs:
+                existing_config = existing_configs[new_config_key]
+                if np.mean(new_config["r"][0]) < np.mean(existing_config["r"][0]):
+                    existing_configs[new_config_key] = new_config
+            else:
+                existing_configs[new_config_key] = new_config
+
+            for config_key in existing_configs:
+                existing_configs[config_key] = NoIndent(existing_configs[config_key])
+
+            configs_str = json.dumps(existing_configs, indent=4, cls=NoIndentEncoder)
+            # Delete the current file, then write the new configs
+            # TODO @mbarrett97 We should make a copy of the existing file here in
+            # case something bad happens before the write and the data is lost
+            f.truncate(0)
+            f.seek(0)
+            f.write(configs_str)
+
+    def get_config(self, target, workload):
+        """Get a config for a given target/workload from the library.
+
+        Parameters
+        ----------
+        target: str
+            The target string of the config.
+        workload: list
+            The workload of the config.
+
+        Returns
+        -------
+        config: Union[dict, None]
+            The config for the specified task. Returns None if no config was
+            found.
+
+        """
+        config_file = self.get_config_file(target, create=False)
+        if config_file:
+            with open(config_file) as f:
+                configs = json.load(f)
+                workload_key = str(workload)
+                if workload_key in configs:
+                    return configs[workload_key]
+
+        return None
+
+    def get_config_file(self, target, create=True):
+        """Return the config file path associated with a given target"""
+        with open(self.library_index, "r+") as f:
+            config_index = json.load(f)
+
+        target_file_name = self._get_target_file_name(target)
+        root = config_index["root"]
+        config_files = config_index["targets"]
+        if target_file_name in config_files:
+            return config_files[target_file_name]
+        elif create:  # Create the file if it's not already in the index
+            with open(self.library_index, "w") as f:
+                config_file_name = target_file_name + ".configs"
+                config_file = os.path.join(root, config_file_name)
+                with open(config_file, "w") as g:
+                    json.dump({}, g)
+
+                config_index["targets"][target_file_name] = config_file
+                json.dump(config_index, f, indent=4)
+                return config_file
+
+        return None
+
+    @staticmethod
+    def _get_target_file_name(target):
+        """Create a file name from a TVM target string."""
+        options = target.split(" ")
+        sorted_options = [options[0]] + sorted(options[1:])
+        return "-".join(sorted_options)
+
+
+class NoIndent(object):
+    def __init__(self, value):
+        self.value = value
+
+
+# TODO @mbarrett97 Find a more efficient way to pretty print the JSON
+class NoIndentEncoder(json.JSONEncoder):
+    """JSON pretty printing class to print configs on one line."""
+    marker = "~@{}@~"
+
+    def default(self, obj):
+        if isinstance(obj, NoIndent):
+            return self.marker.format(id(obj))
+        else:
+            super(NoIndentEncoder, self).default(obj)
+
+    def encode(self, obj):
+        json_obj = super(NoIndentEncoder, self).encode(obj)
+        offset = 0
+        while offset != -1:
+            no_indent_index_start = json_obj.find("~@", offset)
+            no_indent_index_stop = json_obj.find("@~", offset)
+            if no_indent_index_start == -1:
+                break
+
+            no_indent_id = int(json_obj[no_indent_index_start+2:no_indent_index_stop])
+            no_indent_obj = PyObj_FromPtr(no_indent_id)
+            no_indent_json = json.dumps(no_indent_obj.value)
+            json_obj = json_obj.replace(
+                '"{}"'.format(self.marker.format(no_indent_id)),
+                no_indent_json,
+            )
+            offset = no_indent_index_stop + 2
+
+        return json_obj
 
 Review comment:
   I currently store the configs as
   { 
     [workload_str]: config,
     [workload_str]: config,
   }
   to make it simpler to query by workload. The issue I found was that, when using the JSON indent option, the configs were massively expanded because they are themselves heavily nested objects. This encoder makes the configs ignore the indent value so that each one is printed on a single line. I'm very open to an alternative method here, as this seems quite messy!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services