Module deepcomp.util.callbacks
Expand source code
from typing import Dict
from ray.rllib import Policy
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.utils.typing import PolicyID
class CustomMetricCallbacks(DefaultCallbacks):
    """
    Callbacks for including custom scalar metrics for monitoring with tensorboard
    https://docs.ray.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
    """
    @staticmethod
    def get_info(base_env, episode):
        """Return the info dict for the given base_env and episode"""
        # MultiAgentEnv case: the info dict has to be fetched for one concrete UE
        if hasattr(base_env, 'envs'):
            # every UE carries the same info dict, so the first UE suffices
            first_ue = base_env.envs[0].ue_list[0]
            return episode.last_info_for(first_ue.id)
        # single-agent env: the default agent's last info dict
        return episode.last_info_for()

    def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int, **kwargs):
        """Log each scalar metric per step and accumulate a per-episode sum in user_data."""
        info = self.get_info(base_env, episode)
        # nothing to record without a 'scalar_metrics' entry in the info dict
        if info is None or 'scalar_metrics' not in info:
            return
        for name, value in info['scalar_metrics'].items():
            # per-step value goes straight into the episode's custom metrics
            episode.custom_metrics[name] = value
            # accumulate the sum over all steps of this episode under an 'eps_' prefix
            eps_key = f'eps_{name}'
            try:
                episode.user_data[eps_key] += value
            except KeyError:
                # first step of the episode: initialize the sum
                episode.user_data[eps_key] = value

    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: MultiAgentEpisode, env_index: int, **kwargs):
        """Publish the per-episode sums collected in user_data as custom metrics."""
        episode.custom_metrics.update(episode.user_data)
Classes
class CustomMetricCallbacks (legacy_callbacks_dict: Dict[str, callable] = None)
Callbacks for including custom scalar metrics for monitoring with tensorboard https://docs.ray.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
Expand source code
class CustomMetricCallbacks(DefaultCallbacks): """ Callbacks for including custom scalar metrics for monitoring with tensorboard https://docs.ray.io/en/latest/rllib-training.html#callbacks-and-custom-metrics """ @staticmethod def get_info(base_env, episode): """Return the info dict for the given base_env and episode""" # different treatment for MultiAgentEnv where we need to get the info dict from a specific UE if hasattr(base_env, 'envs'): # get the info dict for the first UE (it's the same for all) ue_id = base_env.envs[0].ue_list[0].id info = episode.last_info_for(ue_id) else: info = episode.last_info_for() return info def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs): info = self.get_info(base_env, episode) # add all custom scalar metrics in the info dict if info is not None and 'scalar_metrics' in info: for metric_name, metric_value in info['scalar_metrics'].items(): episode.custom_metrics[metric_name] = metric_value # increment (or init) the sum over all time steps inside the episode eps_metric_name = f'eps_{metric_name}' if eps_metric_name in episode.user_data: episode.user_data[eps_metric_name] += metric_value else: episode.user_data[eps_metric_name] = metric_value def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, env_index: int, **kwargs): # log the sum of scalar metrics over an episode as metric for key, value in episode.user_data.items(): episode.custom_metrics[key] = value
Ancestors
- ray.rllib.agents.callbacks.DefaultCallbacks
Static methods
def get_info(base_env, episode)
-
Return the info dict for the given base_env and episode
Expand source code
@staticmethod def get_info(base_env, episode): """Return the info dict for the given base_env and episode""" # different treatment for MultiAgentEnv where we need to get the info dict from a specific UE if hasattr(base_env, 'envs'): # get the info dict for the first UE (it's the same for all) ue_id = base_env.envs[0].ue_list[0].id info = episode.last_info_for(ue_id) else: info = episode.last_info_for() return info
Methods
def on_episode_end(self, *, worker: RolloutWorker, base_env: ray.rllib.env.base_env.BaseEnv, policies: Dict[str, ray.rllib.policy.policy.Policy], episode: ray.rllib.evaluation.episode.MultiAgentEpisode, env_index: int, **kwargs)
-
Runs when an episode is done.
Args
worker
:RolloutWorker
- Reference to the current rollout worker.
base_env
:BaseEnv
- BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_unwrapped().
policies
:dict
- Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.
episode
:MultiAgentEpisode
- Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.
env_index
:int
- The index of the (vectorized) env, which the episode belongs to.
kwargs
- Forward compatibility placeholder.
Expand source code
def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, env_index: int, **kwargs): # log the sum of scalar metrics over an episode as metric for key, value in episode.user_data.items(): episode.custom_metrics[key] = value
def on_episode_step(self, *, worker: RolloutWorker, base_env: ray.rllib.env.base_env.BaseEnv, episode: ray.rllib.evaluation.episode.MultiAgentEpisode, env_index: int, **kwargs)
-
Runs on each episode step.
Args
worker
:RolloutWorker
- Reference to the current rollout worker.
base_env
:BaseEnv
- BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_unwrapped().
episode
:MultiAgentEpisode
- Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.
env_index
:int
- The index of the (vectorized) env, which the episode belongs to.
kwargs
- Forward compatibility placeholder.
Expand source code
def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs): info = self.get_info(base_env, episode) # add all custom scalar metrics in the info dict if info is not None and 'scalar_metrics' in info: for metric_name, metric_value in info['scalar_metrics'].items(): episode.custom_metrics[metric_name] = metric_value # increment (or init) the sum over all time steps inside the episode eps_metric_name = f'eps_{metric_name}' if eps_metric_name in episode.user_data: episode.user_data[eps_metric_name] += metric_value else: episode.user_data[eps_metric_name] = metric_value