Module deepcomp.util.callbacks

from typing import Dict

from ray.rllib import Policy
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.utils.typing import PolicyID


class CustomMetricCallbacks(DefaultCallbacks):
    """
    Callbacks for including custom scalar metrics for monitoring with TensorBoard
    https://docs.ray.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
    """
    @staticmethod
    def get_info(base_env, episode):
        """Return the info dict for the given base_env and episode."""
        # MultiAgentEnv needs different treatment: get the info dict from a specific UE
        if hasattr(base_env, 'envs'):
            # use the info dict of the first UE (it's the same for all UEs)
            ue_id = base_env.envs[0].ue_list[0].id
            info = episode.last_info_for(ue_id)
        else:
            info = episode.last_info_for()
        return info

    def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int, **kwargs):
        """Log each scalar metric per step and accumulate its sum over the episode."""
        info = self.get_info(base_env, episode)
        # add all custom scalar metrics from the info dict
        if info is not None and 'scalar_metrics' in info:
            for metric_name, metric_value in info['scalar_metrics'].items():
                episode.custom_metrics[metric_name] = metric_value

                # increment (or initialize) the sum over all time steps inside the episode
                eps_metric_name = f'eps_{metric_name}'
                if eps_metric_name in episode.user_data:
                    episode.user_data[eps_metric_name] += metric_value
                else:
                    episode.user_data[eps_metric_name] = metric_value

    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: MultiAgentEpisode, env_index: int, **kwargs):
        """Log the accumulated per-episode sums as custom metrics."""
        # log the sum of scalar metrics over an episode as metric
        for key, value in episode.user_data.items():
            episode.custom_metrics[key] = value

Classes

class CustomMetricCallbacks (legacy_callbacks_dict: Dict[str, callable] = None)

Callbacks for including custom scalar metrics for monitoring with TensorBoard: https://docs.ray.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
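
For context, a minimal sketch of how such a callbacks class is typically hooked into RLlib via the "callbacks" config key (the env name and stopping criterion below are placeholder assumptions, not part of this module):

import ray
from ray import tune

from deepcomp.util.callbacks import CustomMetricCallbacks

ray.init()
tune.run(
    "PPO",
    config={
        # placeholder env name; the env must report info['scalar_metrics']
        # for these callbacks to log anything
        "env": "deepcomp-env",
        "callbacks": CustomMetricCallbacks,
    },
    stop={"training_iteration": 5},
)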

Ancestors

  • ray.rllib.agents.callbacks.DefaultCallbacks

Static methods

def get_info(base_env, episode)

Return the info dict for the given base_env and episode
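
These callbacks assume the env reports its custom metrics under a 'scalar_metrics' key in the info dict that get_info retrieves. A toy illustration of the expected shape (the metric names here are made up):

# hypothetical info dict as returned by the env's step() and retrieved
# via get_info(); metric names are invented for illustration
info = {
    'scalar_metrics': {
        'dr_total': 12.3,      # e.g., sum of achieved data rates
        'num_connected': 2.0,  # e.g., number of connected UEs
    },
}
assert 'scalar_metrics' in info  # precondition checked in on_episode_step()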

Methods

def on_episode_end(self, *, worker: RolloutWorker, base_env: ray.rllib.env.base_env.BaseEnv, policies: Dict[str, ray.rllib.policy.policy.Policy], episode: ray.rllib.evaluation.episode.MultiAgentEpisode, env_index: int, **kwargs)

Runs when an episode is done.

Args

worker : RolloutWorker
    Reference to the current rollout worker.
base_env : BaseEnv
    BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_unwrapped().
policies : dict
    Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.
episode : MultiAgentEpisode
    Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.
env_index : int
    The index of the (vectorized) env, which the episode belongs to.
kwargs
    Forward compatibility placeholder.
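
Since on_episode_end writes the episode sums into episode.custom_metrics, they appear in RLlib's training results, where custom metrics are aggregated across episodes with _mean/_min/_max suffixes. A minimal sketch, assuming a metric named dr_total was logged per step (so its episode sum is eps_dr_total) and a placeholder env:

from ray.rllib.agents.ppo import PPOTrainer

from deepcomp.util.callbacks import CustomMetricCallbacks

trainer = PPOTrainer(config={
    "env": "deepcomp-env",  # placeholder env reporting info['scalar_metrics']
    "callbacks": CustomMetricCallbacks,
})
result = trainer.train()
# RLlib aggregates custom metrics across episodes with suffixes:
print(result["custom_metrics"].get("eps_dr_total_mean"))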
def on_episode_step(self, *, worker: RolloutWorker, base_env: ray.rllib.env.base_env.BaseEnv, episode: ray.rllib.evaluation.episode.MultiAgentEpisode, env_index: int, **kwargs)

Runs on each episode step.

Args

worker : RolloutWorker
    Reference to the current rollout worker.
base_env : BaseEnv
    BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_unwrapped().
episode : MultiAgentEpisode
    Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.
env_index : int
    The index of the (vectorized) env, which the episode belongs to.
kwargs
    Forward compatibility placeholder.
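
To make the accumulation concrete, here is a toy rerun of this method's logic with plain dicts standing in for the episode object (per-step values invented):

custom_metrics, user_data = {}, {}
for metric_value in [1.0, 2.0, 3.0]:  # one metric over three steps
    custom_metrics['dr_total'] = metric_value      # current step's value
    if 'eps_dr_total' in user_data:
        user_data['eps_dr_total'] += metric_value  # running episode sum
    else:
        user_data['eps_dr_total'] = metric_value
assert custom_metrics['dr_total'] == 3.0
assert user_data['eps_dr_total'] == 6.0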