diff --git a/src/garage/experiment/experiment.py b/src/garage/experiment/experiment.py index f697db6f82..101540a0cb 100644 --- a/src/garage/experiment/experiment.py +++ b/src/garage/experiment/experiment.py @@ -10,10 +10,12 @@ import pathlib import subprocess import warnings +import weakref import dateutil.tz import dowel from dowel import logger +import numpy as np import __main__ as main @@ -456,7 +458,15 @@ def dump_json(filename, data): """ pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True) with open(filename, 'w') as f: - json.dump(data, f, indent=2, sort_keys=True, cls=LogEncoder) + # We do our own circular reference handling. + # Sometimes sort_keys fails because the keys don't get made into + # strings early enough. + json.dump(data, + f, + indent=2, + sort_keys=False, + cls=LogEncoder, + check_circular=False) def get_metadata(): @@ -550,7 +560,24 @@ def make_launcher_archive(*, git_root_path, log_dir): class LogEncoder(json.JSONEncoder): - """Encoder to be used as cls in json.dump.""" + """Encoder to be used as cls in json.dump. + + Args: + args (object): Passed to super class. + kwargs (dict): Passed to super class. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._markers = {} + + # Modules whose contents cannot be meaningfully or safelly jsonified. + BLOCKED_MODULES = { + 'tensorflow', + 'ray', + 'itertools', + } def default(self, o): """Perform JSON encoding. @@ -558,20 +585,111 @@ def default(self, o): Args: o (object): Object to encode. + Raises: + TypeError: If `o` cannot be turned into JSON even using `repr(o)`. + Returns: - str: Object encoded in JSON. + dict or str or float or bool: Object encoded in JSON. """ # Why is this method hidden? What does that mean? # pylint: disable=method-hidden + # pylint: disable=too-many-branches + # pylint: disable=too-many-return-statements + # This circular reference checking code was copied from the standard + # library json implementation, but it outputs a repr'd string instead + # of ValueError on a circular reference. + if isinstance(o, (int, bool, float, str)): + return o + else: + markerid = id(o) + if markerid in self._markers: + return 'circular ' + repr(o) + else: + self._markers[markerid] = o + try: + return self._default_inner(o) + finally: + del self._markers[markerid] + + def _default_inner(self, o): + """Perform JSON encoding. + + Args: + o (object): Object to encode. + + Raises: + TypeError: If `o` cannot be turned into JSON even using `repr(o)`. + ValueError: If raised by calling repr on an object. - if isinstance(o, type): - return {'$class': o.__module__ + '.' + o.__name__} - elif isinstance(o, enum.Enum): - return { - '$enum': - o.__module__ + '.' + o.__class__.__name__ + '.' + o.name - } - elif callable(o): - return {'$function': o.__module__ + '.' + o.__name__} - return json.JSONEncoder.default(self, o) + Returns: + dict or str or float or bool: Object encoded in JSON. + + """ + # Why is this method hidden? What does that mean? + # pylint: disable=method-hidden + # pylint: disable=too-many-branches + # pylint: disable=too-many-return-statements + # This circular reference checking code was copied from the standard + # library json implementation, but it outputs a repr'd string instead + # of ValueError on a circular reference. + try: + return json.JSONEncoder.default(self, o) + except TypeError as err: + if isinstance(o, dict): + data = {} + for (k, v) in o.items(): + if isinstance(k, str): + data[k] = self.default(v) + else: + data[repr(k)] = self.default(v) + return data + elif isinstance(o, weakref.ref): + return repr(o) + elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES: + return repr(o) + elif isinstance(o, type): + return {'$typename': o.__module__ + '.' + o.__name__} + elif isinstance(o, np.number): + # For some reason these aren't natively considered + # serializable. + # JSON doesn't actually have ints, so always use a float. + return float(o) + elif isinstance(o, np.bool8): + return bool(o) + elif isinstance(o, enum.Enum): + return { + '$enum': + o.__module__ + '.' + o.__class__.__name__ + '.' + o.name + } + elif isinstance(o, np.ndarray): + return repr(o) + elif hasattr(o, '__dict__') or hasattr(o, '__slots__'): + obj_dict = getattr(o, '__dict__', None) + if obj_dict is not None: + data = {k: self.default(v) for (k, v) in obj_dict.items()} + else: + data = { + s: self.default(getattr(o, s)) + for s in o.__slots__ + } + t = type(o) + data['$type'] = t.__module__ + '.' + t.__name__ + return data + elif callable(o) and hasattr(o, '__name__'): + if getattr(o, '__module__', None) is not None: + return {'$function': o.__module__ + '.' + o.__name__} + else: + return repr(o) + else: + try: + # This case handles many built-in datatypes like deques + return [self.default(v) for v in list(o)] + except TypeError: + pass + try: + # This case handles most other weird objects. + return repr(o) + except TypeError: + pass + raise err diff --git a/src/garage/trainer.py b/src/garage/trainer.py index e7b0daf94c..4e4a558ece 100644 --- a/src/garage/trainer.py +++ b/src/garage/trainer.py @@ -9,6 +9,7 @@ # This is avoiding a circular import from garage.experiment.deterministic import get_seed, set_seed +from garage.experiment.experiment import dump_json from garage.experiment.snapshotter import Snapshotter from garage.sampler.default_worker import DefaultWorker from garage.sampler.worker_factory import WorkerFactory @@ -518,6 +519,10 @@ def train(self, self._plot = plot self._start_worker() + log_dir = self._snapshotter.snapshot_dir + summary_file = os.path.join(log_dir, 'experiment.json') + dump_json(summary_file, self) + average_return = self._algo.train(self) self._shutdown_worker()