Source code for trojai_rl.modelgen.config

import logging
import os
import re
import uuid
from typing import Any

import torch.nn as nn

from .optimizer_interface import RLOptimizerInterface
from ..datagen.environment_factory import EnvironmentFactory
from .utils import is_jsonable

# Module-wide logger; named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
# Training algorithms accepted by RLOptimizerConfig.validate().
SUPPORTED_TRAINING_ALGOS = ["ppo"]


class RLOptimizerConfig:
    """
    Defines configuration parameters for RL training.

    All values are checked by ``validate`` at the end of ``__init__``; an invalid value raises ValueError.
    """

    def __init__(self, algorithm: str = 'ppo', num_frames: int = int(8e6), max_num_frames_rollout: int = 128,
                 num_epochs: int = 1000, device: str = 'cuda', num_frames_per_test: int = int(5e5),
                 learning_rate: float = 1e-3):
        """
        Initializes and validates the RLOptimizerConfig object.

        :param algorithm: (str) training algorithm; must be one of SUPPORTED_TRAINING_ALGOS
        :param num_frames: (int) number of frames; must be an int >= 1
        :param max_num_frames_rollout: (int) maximum number of frames per rollout; must be an int >= 1
        :param num_epochs: (int) number of epochs; must be an int >= 1
        :param device: (str) 'cpu', 'cuda', or 'cuda:#' where # is a non-negative integer
        :param num_frames_per_test: (int or None) if not None, must be an int >= 1
        :param learning_rate: (float) must be a float > 0
        :raises ValueError: if any argument is invalid
        """
        self.algorithm = algorithm
        self.num_frames = num_frames
        self.max_num_frames_rollout = max_num_frames_rollout
        self.num_epochs = num_epochs
        self.device = device
        self.num_frames_per_test = num_frames_per_test
        self.learning_rate = learning_rate
        self.validate()

    @staticmethod
    def _is_valid_device(device) -> bool:
        """Return True if ``device`` is exactly 'cpu', 'cuda', or 'cuda:<non-negative integer>'."""
        if not isinstance(device, str):
            return False
        # fullmatch (not match) so trailing junk such as 'cuda:0abc' is rejected; [0-9] avoids
        # accepting non-ASCII Unicode digits that \d would match.
        return device in ('cpu', 'cuda') or re.fullmatch(r'cuda:[0-9]+', device) is not None

    def validate(self):
        """
        Check every configuration value, logging and raising on the first invalid one.

        :raises ValueError: if any value has the wrong type or is out of range
        """
        if not isinstance(self.algorithm, str) or self.algorithm not in SUPPORTED_TRAINING_ALGOS:
            msg = "algorithm input must be a string, and one of:" + str(SUPPORTED_TRAINING_ALGOS)
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(self.num_frames, int) or self.num_frames < 1:
            msg = "num_frames must be at least 1!"
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(self.max_num_frames_rollout, int) or self.max_num_frames_rollout < 1:
            msg = "max_num_frames_rollout must be an integer > 0"
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(self.num_epochs, int) or self.num_epochs < 1:
            msg = "num_epochs must be an integer > 0"
            logger.error(msg)
            raise ValueError(msg)
        if not self._is_valid_device(self.device):
            msg = "device specification must be a string: either cpu, cuda, cuda:#, where # is an integer >= 0"
            logger.error(msg)
            raise ValueError(msg)
        if self.num_frames_per_test is not None and \
                (not isinstance(self.num_frames_per_test, int) or self.num_frames_per_test < 1):
            msg = "num_frames_per_test must be an integer > 0"
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(self.learning_rate, float) or self.learning_rate <= 0:
            msg = "learning_rate must be a float > 0"
            logger.error(msg)
            raise ValueError(msg)
class RunnerConfig:
    """
    Defines a runner configuration object, required to configure a Runner to train RL models.

    Arguments are checked by ``validate`` during ``__init__``; the model and stats directories are created if they
    do not already exist.
    """

    def __init__(self, train_env_factory: EnvironmentFactory, test_env_factory: EnvironmentFactory,
                 trainable_model: nn.Module, optimizer: RLOptimizerInterface, parallel: bool = False,
                 model_save_dir: str = "/tmp/models", stats_save_dir: str = "/tmp/model_stats", run_id: Any = None,
                 filename: str = None, save_with_hash: bool = False, save_info: dict = None):
        """
        Initializes the RunnerConfig object

        :param train_env_factory: (EnvironmentFactory) environment factory for producing training environments
        :param test_env_factory: (EnvironmentFactory) similar to train_env_factory, but for the test environments
        :param trainable_model: (nn.Module) model to be trained
        :param optimizer: (RLOptimizerInterface) RLOptimizerInterface object that will be used to train and test
            the model
        :param parallel: (bool) Whether to run training in parallel. Note: while currently unused by current
            optimizers, we expect this to provide additional instruction about parallelization to the optimizer if
            implemented
        :param model_save_dir: (str) path to folder where models should be saved; created if missing
        :param stats_save_dir: (str or os.PathLike) path to folder where train/test stats should be saved; created
            if missing
        :param run_id: (int) optional id to use to identify this run
        :param filename: (str) filename under which to save the model and stats; required in practice, since
            ``validate`` rejects anything that is not a string (including the default of None)
        :param save_with_hash: (bool) if True, '.' plus a uuid1 hex string is appended to ``filename`` after
            validation
        :param save_info: (dict) optional dictionary of json serializable information to save with train and test
            stats; None is replaced by an empty dict
        :raises TypeError: if an argument has the wrong type, or ``save_info`` is not json serializable
        :raises OSError: if a save directory does not exist and cannot be created
        """
        self.train_env_factory = train_env_factory
        self.test_env_factory = test_env_factory
        self.trainable_model = trainable_model
        self.optimizer = optimizer
        self.parallel = parallel
        self.model_save_dir = model_save_dir
        self.stats_save_dir = stats_save_dir
        self.run_id = run_id
        self.filename = filename
        self.save_with_hash = save_with_hash
        # Without this, the documented default (None) would always fail the dict check in validate().
        self.save_info = {} if save_info is None else save_info
        self.validate()

        # NOTE: quick, hack-y way to do this, may need to be updated if doesn't work later; should maybe go in
        # runner? validate() has already guaranteed that filename is a str here.
        if self.save_with_hash:
            self.filename += '.' + str(uuid.uuid1().hex)

    @staticmethod
    def _ensure_dir(path, arg_name):
        """Create directory ``path`` (with parents) unless it already exists; log and raise OSError on failure."""
        try:
            # exist_ok avoids the race between an isdir() check and makedirs().
            os.makedirs(path, exist_ok=True)
        except OSError as e:
            msg = "'{}' was not found and could not be created...\n{}".format(arg_name, e)
            logger.error(msg)
            raise OSError(msg) from e

    def validate(self):
        """
        Check argument types, then create the save directories if needed.

        All type checks run before any directory is created, so an invalid configuration leaves the filesystem
        untouched.

        :raises TypeError: if an argument has the wrong type, or ``save_info`` is not json serializable
        :raises OSError: if a save directory does not exist and cannot be created
        """
        if not isinstance(self.model_save_dir, str):
            msg = "Expected type 'string' for argument 'model_save_dir', instead got type: " \
                  "{}".format(type(self.model_save_dir))
            logger.error(msg)
            raise TypeError(msg)
        # os.PathLike is accepted because the previous isdir/makedirs path already handled it.
        if not isinstance(self.stats_save_dir, (str, os.PathLike)):
            msg = "Expected type 'string' or os.PathLike for argument 'stats_save_dir', instead got type: " \
                  "{}".format(type(self.stats_save_dir))
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.filename, str):
            msg = "Expected a string for argument 'filename', instead got " \
                  "type {}".format(type(self.filename))
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.save_with_hash, bool):
            msg = "Expected boolean for argument save_with_hash"
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.save_info, dict):
            msg = "Expected type 'dict' for argument 'save_info', instead got type {}".format(type(self.save_info))
            logger.error(msg)
            raise TypeError(msg)
        if not is_jsonable(self.save_info):
            msg = "Argument 'save_info', must be json serializable."
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.parallel, bool):
            msg = "Expected boolean for argument 'parallel', instead got type {}".format(type(self.parallel))
            logger.error(msg)
            raise TypeError(msg)

        self._ensure_dir(self.model_save_dir, 'model_save_dir')
        self._ensure_dir(self.stats_save_dir, 'stats_save_dir')
class TestConfig:
    """
    Test configuration specification for a single run of an agent through an environment.

    Note: ``__init__`` only stores its arguments; call ``validate`` explicitly to check them.
    """

    def __init__(self, environment_cfg: Any, count: int = 100, test_description: dict = None,
                 agent_argmax_action: bool = False):
        """
        Test configuration specification for a single run of an agent through an environment.

        :param environment_cfg: (Any) This is whatever should be passed to the environment factory to instantiate an
            environment.
        :param count: (int) Number of times to run the agent through the environment.
        :param test_description: (dict) A dictionary of key, value pairs providing information about the test;
            currently the only required key is 'poison', whose value should be a string describing the poison
            strategy or 'clean' for no poison. This is also used to save data, and should be mutable to include any
            information desired to be used or saved with the test results.
        :param agent_argmax_action: (bool) Have the agent choose the argmax of its policy distribution. torch_ac has
            the model return a distribution from which it samples an action; set this to True to instead choose the
            agent's highest confidence action.
        """
        self.env_cfg = environment_cfg
        self.count = count
        self.desc = test_description
        self.argmax_action = agent_argmax_action

    def get_environment_cfg(self):
        """Return the environment configuration passed as ``environment_cfg``."""
        return self.env_cfg

    def get_count(self):
        """Return the number of runs (``count``)."""
        return self.count

    def get_description(self):
        """Return the test description dictionary (the same object, not a copy)."""
        return self.desc

    def get_argmax_action(self):
        """Return whether the agent should take the argmax action."""
        return self.argmax_action

    def validate(self):
        """
        Check that the configuration is well-formed.

        :raises ValueError: if count is not an int >= 1 (bools are rejected), test_description is not a dict
            (or dict subclass) containing the key 'poison', or argmax_action is not a bool
        """
        # bool is a subclass of int; reject it explicitly to keep the previous exact-int behavior.
        if isinstance(self.count, bool) or not isinstance(self.count, int) or self.count < 1:
            msg = "count must be an integer greater than 0, got {}".format(self.count)
            logger.error(msg)
            raise ValueError(msg)
        # isinstance lets dict subclasses (OrderedDict, defaultdict, ...) through.
        if not isinstance(self.desc, dict) or 'poison' not in self.desc:
            msg = "test_description must be a dictionary at least containing the key 'poison'"
            logger.error(msg)
            raise ValueError(msg)
        if not isinstance(self.argmax_action, bool):
            msg = "argmax_action must be a bool!"
            logger.error(msg)
            raise ValueError(msg)