# Source code for spinup.utils.logx


"""
Some simple logging functionality, inspired by rllab's logging.

Logs to a tab-separated-values file (path/to/output_directory/progress.txt)
"""

import json
import joblib
import shutil
import numpy as np
import tensorflow as tf
import os.path as osp, time, atexit, os
from spinup.utils.mpi_tools import proc_id, mpi_statistics_scalar
from spinup.utils.serialization_utils import convert_json

# Maps a colour name to its ANSI SGR foreground code.
# NOTE(review): the body of this table was missing from this copy of the file.
# The entries below follow the standard ANSI foreground codes (30-37), plus
# 'gray' for 30 and 'crimson' for 38 -- confirm against the upstream source.
color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38,
)


def colorize(string, color, bold=False, highlight=False):
    """
    Colorize a string.

    This function was originally written by John Schulman.

    Args:
        string: Text to wrap in ANSI escape sequences.

        color (string): A key of ``color2num``.

        bold (bool): If true, add the bold attribute (``1``).

        highlight (bool): If true, add 10 to the colour code, which selects
            the matching background colour instead of the foreground.

    Returns:
        ``string`` prefixed with the escape sequence for the requested
        attributes and followed by the reset sequence ``\\x1b[0m``.

    Raises:
        KeyError: If ``color`` is not a key of ``color2num``.
    """
    attr = []
    num = color2num[color]
    if highlight:
        num += 10
    # The colour code has to be part of the attribute list; without it the
    # escape sequence carries no colour at all.
    attr.append(str(num))
    if bold:
        attr.append('1')
    return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)

def restore_tf_graph(sess, fpath):
    """
    Loads graphs saved by Logger.

    Will output a dictionary whose keys and values are from the 'inputs'
    and 'outputs' dict you specified with logger.setup_tf_saver().

    Args:
        sess: A Tensorflow session.

        fpath: Filepath to save directory.

    Returns:
        A dictionary mapping from keys to tensors in the computation graph
        loaded from ``fpath``.
    """
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        fpath,
    )
    # model_info.pkl holds the key -> tensor-name tables that were written
    # next to the saved model.
    saved_info = joblib.load(osp.join(fpath, 'model_info.pkl'))
    default_graph = tf.get_default_graph()

    tensors = {}
    # Inputs first, then outputs, so an output key wins over an input key of
    # the same name.
    for section in ('inputs', 'outputs'):
        for key, tensor_name in saved_info[section].items():
            tensors[key] = default_graph.get_tensor_by_name(tensor_name)
    return tensors
class Logger:
    """
    A general-purpose logger.

    Makes it easy to save diagnostics, hyperparameter configurations, the
    state of a training run, and the trained model.

    File output and printing are only performed by the process for which
    ``proc_id()`` returns 0; on every other process ``output_dir`` and
    ``output_file`` are ``None``.
    """

    def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None):
        """
        Initialize a Logger.

        Args:
            output_dir (string): A directory for saving results to. If
                ``None``, defaults to a temp directory of the form
                ``/tmp/experiments/somerandomnumber``.

            output_fname (string): Name for the tab-separated-value file
                containing metrics logged throughout a training run.
                Defaults to ``progress.txt``.

            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them. (Use case: if you run the same
                hyperparameter configuration with multiple random seeds, you
                should give them all the same ``exp_name``.)
        """
        if proc_id() == 0:
            self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
            if osp.exists(self.output_dir):
                print("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir)
            else:
                os.makedirs(self.output_dir)
            # The file stays open for the lifetime of the logger and is
            # closed at interpreter exit.
            self.output_file = open(osp.join(self.output_dir, output_fname), 'w')
            atexit.register(self.output_file.close)
            print(colorize("Logging data to %s" % self.output_file.name, 'green', bold=True))
        else:
            self.output_dir = None
            self.output_file = None
        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name

    def log(self, msg, color='green'):
        """Print a colorized message to stdout."""
        if proc_id() == 0:
            print(colorize(msg, color, bold=True))

    def log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).

        Raises:
            AssertionError: If ``key`` was already set this iteration, or if
                ``key`` is new after the first iteration has been dumped.
        """
        # Reject duplicates before touching the header list, so a failed call
        # on the first row cannot leave a duplicated column behind.
        assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        if self.first_row:
            self.log_headers.append(key)
        else:
            assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key
        self.log_current_row[key] = val

    def save_config(self, config):
        """
        Log an experiment configuration.

        Call this once at the top of your experiment, passing in all
        important config vars as a dict. This will serialize the config to
        JSON, while handling anything which can't be serialized in a
        graceful way (writing as informative a string as possible).

        Example use:

        .. code-block:: python

            logger = EpochLogger(**logger_kwargs)
            logger.save_config(locals())
        """
        config_json = convert_json(config)
        if self.exp_name is not None:
            config_json['exp_name'] = self.exp_name
        if proc_id() == 0:
            output = json.dumps(config_json, separators=(',', ':\t'), indent=4, sort_keys=True)
            print(colorize('Saving config:\n', color='cyan', bold=True))
            print(output)
            with open(osp.join(self.output_dir, "config.json"), 'w') as out:
                out.write(output)

    def save_state(self, state_dict, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent parameters for the model you
        previously set up saving for with ``setup_tf_saver``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states
        you save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.

            itr: An int, or None. Current iteration of training.
        """
        if proc_id() == 0:
            fname = 'vars.pkl' if itr is None else 'vars%d.pkl' % itr
            try:
                joblib.dump(state_dict, osp.join(self.output_dir, fname))
            except Exception:
                # Best effort: a state that cannot be pickled should not
                # abort training. KeyboardInterrupt / SystemExit still
                # propagate.
                self.log('Warning: could not pickle state_dict.', color='red')
            if hasattr(self, 'tf_saver_elements'):
                self._tf_simple_save(itr)

    def setup_tf_saver(self, sess, inputs, outputs):
        """
        Set up easy model saving for tensorflow.

        Call once, after defining your computation graph but before
        training.

        Args:
            sess: The Tensorflow session in which you train your computation
                graph.

            inputs (dict): A dictionary that maps from keys of your choice
                to the tensorflow placeholders that serve as inputs to the
                computation graph. Make sure that *all* of the placeholders
                needed for your outputs are included!

            outputs (dict): A dictionary that maps from keys of your choice
                to the outputs from your computation graph.
        """
        self.tf_saver_elements = dict(session=sess, inputs=inputs, outputs=outputs)
        # Only the tensor names are kept here; they are what gets written to
        # model_info.pkl alongside the saved model.
        self.tf_saver_info = {'inputs': {k: v.name for k, v in inputs.items()},
                              'outputs': {k: v.name for k, v in outputs.items()}}

    def _tf_simple_save(self, itr=None):
        """
        Uses simple_save to save a trained model, plus info to make it easy
        to associate tensors to variables after restore.
        """
        if proc_id() == 0:
            assert hasattr(self, 'tf_saver_elements'), \
                "First have to setup saving with self.setup_tf_saver"
            fpath = 'simple_save' + ('%d' % itr if itr is not None else '')
            fpath = osp.join(self.output_dir, fpath)
            if osp.exists(fpath):
                # simple_save refuses to be useful if fpath already exists,
                # so just delete fpath if it's there.
                shutil.rmtree(fpath)
            tf.saved_model.simple_save(export_dir=fpath, **self.tf_saver_elements)
            joblib.dump(self.tf_saver_info, osp.join(fpath, 'model_info.pkl'))

    def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes both to stdout, and to the output file.
        """
        if proc_id() == 0:
            vals = []
            key_lens = [len(key) for key in self.log_headers]
            max_key_len = max(15, max(key_lens))
            keystr = '%' + '%d' % max_key_len
            fmt = "| " + keystr + "s | %15s |"
            n_slashes = 22 + max_key_len
            print("-" * n_slashes)
            for key in self.log_headers:
                # A header with no value this iteration is written as an
                # empty cell.
                val = self.log_current_row.get(key, "")
                valstr = "%8.3g" % val if hasattr(val, "__float__") else val
                print(fmt % (key, valstr))
                vals.append(val)
            print("-" * n_slashes)
            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write("\t".join(self.log_headers) + "\n")
                self.output_file.write("\t".join(map(str, vals)) + "\n")
                self.output_file.flush()
        # Every process resets its row state, not only process 0.
        self.log_current_row.clear()
        self.first_row = False
class EpochLogger(Logger):
    """
    A variant of Logger tailored for tracking average values over epochs.

    Typical use case: there is some quantity which is calculated many times
    throughout an epoch, and at the end of the epoch, you would like to
    report the average / std / min / max value of that quantity.

    With an EpochLogger, each time the quantity is calculated, you would
    use

    .. code-block:: python

        epoch_logger.store(NameOfQuantity=quantity_value)

    to load it into the EpochLogger's state. Then at the end of the epoch,
    you would use

    .. code-block:: python

        epoch_logger.log_tabular(NameOfQuantity, **options)

    to record the desired values.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Logger`` and start with no stored values."""
        super().__init__(*args, **kwargs)
        self.epoch_dict = {}

    def store(self, **kwargs):
        """
        Save something into the epoch_logger's current state.

        Provide an arbitrary number of keyword arguments with numerical
        values.
        """
        for name, value in kwargs.items():
            self.epoch_dict.setdefault(name, []).append(value)

    def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.

            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.

            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.

            average_only (bool): If true, do not log the standard deviation
                of the diagnostic over the epoch.
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            stats = mpi_statistics_scalar(
                self._epoch_values(key),
                with_min_and_max=with_min_and_max,
            )
            # stats is indexed as [0] average, [1] std, [2] min, [3] max.
            mean_label = key if average_only else 'Average' + key
            super().log_tabular(mean_label, stats[0])
            if not average_only:
                super().log_tabular('Std' + key, stats[1])
            if with_min_and_max:
                super().log_tabular('Max' + key, stats[3])
                super().log_tabular('Min' + key, stats[2])
        # Start the next epoch with an empty buffer for this key.
        self.epoch_dict[key] = []

    def get_stats(self, key):
        """
        Lets an algorithm ask the logger for mean/std/min/max of a
        diagnostic.
        """
        return mpi_statistics_scalar(self._epoch_values(key))

    def _epoch_values(self, key):
        """Return the values stored for ``key``, concatenated when they are non-scalar arrays."""
        stored = self.epoch_dict[key]
        first = stored[0]
        if isinstance(first, np.ndarray) and len(first.shape) > 0:
            return np.concatenate(stored)
        return stored