Source code for trojai_rl.modelgen.optimizer_interface

import logging
from abc import ABC, abstractmethod
from typing import Any, Tuple

from .statistics import TestStatistics, TrainingStatistics
from ..datagen.environment_factory import EnvironmentFactory

logger = logging.getLogger(__name__)


class RLOptimizerInterface(ABC):
    """
    Object that performs training and testing of TrojAI RL models.

    This is an abstract interface: every method below is abstract and must be
    implemented by a concrete optimizer before it can be instantiated.
    """

    @abstractmethod
    def train(self, model: Any, env_factory: EnvironmentFactory) -> Tuple[Any, TrainingStatistics]:
        """
        Train the given model using parameters in self.training_params

        :param model: (Any) The untrained model
        :param env_factory: (EnvironmentFactory)
        :return: (Any, TrainingStatistics) trained model and TrainingStatistics object
        """

    @abstractmethod
    def test(self, model: Any, env_factory: EnvironmentFactory) -> TestStatistics:
        """
        Perform whatever tests desired on the model with clean data and triggered data, and return the results.

        :param model: (Any) Trained model
        :param env_factory: (EnvironmentFactory)
        :return: (TestStatistics) a TestStatistics object
        """

    @abstractmethod
    def get_device_type(self) -> str:
        """
        Return a string representation of the type of device used by the optimizer to train the model.

        :return: (str) the device type
        """

    @abstractmethod
    def get_cfg_as_dict(self) -> dict:
        """
        Return a dictionary with key/value pairs that describe the parameters used to train the model.

        :return: (dict) the training parameters
        """

    @abstractmethod
    def __deepcopy__(self, memodict=None):
        """
        Required for training on clusters. Return a deep copy of the optimizer.

        :param memodict: memo dictionary passed through by ``copy.deepcopy``. The default is ``None``
            rather than ``{}`` so that a single dict object is not shared between calls.
        :return: a deep copy of the optimizer
        """

    @abstractmethod
    def __eq__(self, other):
        """
        Required for training on clusters. Define how to check if two optimizers are equal.

        Note: this class defines ``__eq__`` without ``__hash__``, so Python sets ``__hash__`` to ``None``
        and instances are unhashable unless an implementation also defines ``__hash__``.

        :param other: the object to compare against
        :return: whether the two optimizers are considered equal
        """

    @abstractmethod
    def __str__(self):
        """
        Return a string representation of the optimizer.
        """

    @abstractmethod
    def save(self, fname: str) -> None:
        """
        Save the optimizer to a file

        :param fname: the filename to save the optimizer to
        """

    @staticmethod
    @abstractmethod
    def load(fname: str):
        """
        Load an optimizer from disk and return it

        :param fname: the filename where the optimizer is serialized
        :return: The loaded optimizer
        """