POMDPPlanners.training package

Submodules

POMDPPlanners.training.callbacks module

Trainer callbacks for policy training loops.

This module provides a callback interface and concrete implementations for monitoring and controlling the PolicyTrainer training loop.

Classes:

TrainerCallback: Abstract base class with default no-op hooks.

EarlyStopping: Stops training after patience iterations without improvement.

ModelCheckpoint: Saves the policy on metric improvement (or every iteration).

OptunaPruning: Reports metrics to an Optuna trial and prunes when appropriate.

TensorBoardCallback: Logs training metrics and weight histograms to TensorBoard.
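The hooks below fire in a fixed order around each training iteration. The driver loop in this sketch is purely illustrative (the real PolicyTrainer is not shown here); only the hook names and the True-to-stop contract come from this module's documentation:

```python
# Illustrative sketch of the hook order a trainer would follow. "run_loop"
# and "RecordingCallback" are hypothetical; they exist only to show when
# each documented hook fires.

class RecordingCallback:
    """Records which hooks fire, mirroring the TrainerCallback interface."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, trainer):
        self.calls.append("train_begin")

    def on_collection_begin(self, trainer, iteration):
        self.calls.append(f"collect_begin:{iteration}")

    def on_collection_end(self, trainer, iteration):
        self.calls.append(f"collect_end:{iteration}")

    def on_iteration_begin(self, trainer, iteration):
        self.calls.append(f"iter_begin:{iteration}")

    def on_iteration_end(self, trainer, iteration, metrics):
        self.calls.append(f"iter_end:{iteration}")
        return None  # return True to request early stopping

    def on_train_end(self, trainer, all_metrics):
        self.calls.append("train_end")


def run_loop(callbacks, n_iterations):
    """Hypothetical driver showing the documented hook order."""
    for cb in callbacks:
        cb.on_train_begin(None)
    for i in range(n_iterations):
        stop = False
        for cb in callbacks:
            cb.on_iteration_begin(None, i)
            cb.on_collection_begin(None, i)
            cb.on_collection_end(None, i)
            if cb.on_iteration_end(None, i, {"total_loss": [0.5]}):
                stop = True
        if stop:
            break
    for cb in callbacks:
        cb.on_train_end(None, {})


cb = RecordingCallback()
run_loop([cb], n_iterations=2)
print(cb.calls[0], cb.calls[-1])  # train_begin train_end
```

Note that any callback returning True from on_iteration_end stops the whole loop, which is how EarlyStopping and OptunaPruning take effect.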

class POMDPPlanners.training.callbacks.EarlyStopping(monitor, patience=5, mode='min')[source]

Bases: TrainerCallback

Stop training when a monitored metric stops improving.

Parameters:
monitor

Metric key to watch (e.g. "total_loss").

patience

Number of iterations with no improvement before stopping.

mode

"min" to stop when metric stops decreasing, "max" for increasing.

on_iteration_end(trainer, iteration, metrics)[source]

Called at the end of every iteration.

Parameters:
  • trainer (PolicyTrainer) – The running trainer instance.

  • iteration (int) – Zero-based iteration index.

  • metrics (Dict[str, List[float]]) – Loss metrics returned by the current train_step.

Return type:

Optional[bool]

Returns:

True to request early stopping; None or False to continue.
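The patience behavior described above can be sketched as a standalone illustration; this mirrors the documented semantics (best-so-far tracking, patience counter, min/max modes) and is not the class's actual source:

```python
# Standalone sketch of EarlyStopping's documented behavior: request a stop
# after `patience` iterations without improvement on the monitored metric.

class EarlyStoppingSketch:
    def __init__(self, monitor, patience=5, mode="min"):
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.best = None   # best metric value seen so far
        self.wait = 0      # iterations since the last improvement

    def on_iteration_end(self, trainer, iteration, metrics):
        # Metrics arrive as Dict[str, List[float]]; use the latest value.
        current = metrics[self.monitor][-1]
        improved = (
            self.best is None
            or (self.mode == "min" and current < self.best)
            or (self.mode == "max" and current > self.best)
        )
        if improved:
            self.best = current
            self.wait = 0
            return None
        self.wait += 1
        # Returning True requests early stopping, per the callback contract.
        return self.wait >= self.patience


cb = EarlyStoppingSketch("total_loss", patience=2)
losses = [1.0, 0.8, 0.9, 0.85]  # improvement stalls after iteration 1
stops = [cb.on_iteration_end(None, i, {"total_loss": [l]})
         for i, l in enumerate(losses)]
print(stops)  # [None, None, False, True]
```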

class POMDPPlanners.training.callbacks.ModelCheckpoint(filepath, monitor='total_loss', mode='min', save_every=False)[source]

Bases: TrainerCallback

Save the policy whenever a monitored metric improves.

Parameters:
filepath

Directory where checkpoints are saved.

monitor

Metric key to watch.

mode

"min" or "max".

save_every

When True, save after every iteration regardless of improvement.

on_iteration_end(trainer, iteration, metrics)[source]

Called at the end of every iteration.

Parameters:
  • trainer (PolicyTrainer) – The running trainer instance.

  • iteration (int) – Zero-based iteration index.

  • metrics (Dict[str, List[float]]) – Loss metrics returned by the current train_step.

Return type:

Optional[bool]

Returns:

True to request early stopping; None or False to continue.
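The improvement check behind ModelCheckpoint can be sketched as follows; the actual disk write is stubbed out with a list append, and this is an illustration of the documented logic rather than the real implementation:

```python
# Standalone sketch of ModelCheckpoint's documented logic: track the best
# monitored value and "save" only on improvement, or on every iteration
# when save_every=True.

class ModelCheckpointSketch:
    def __init__(self, filepath, monitor="total_loss", mode="min",
                 save_every=False):
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_every = save_every
        self.best = None
        self.saved_at = []  # iterations at which a checkpoint was "written"

    def on_iteration_end(self, trainer, iteration, metrics):
        current = metrics[self.monitor][-1]
        improved = (
            self.best is None
            or (self.mode == "min" and current < self.best)
            or (self.mode == "max" and current > self.best)
        )
        if improved:
            self.best = current
        if self.save_every or improved:
            # The real callback would save the trainer's policy under
            # self.filepath here.
            self.saved_at.append(iteration)
        return None  # checkpointing never requests early stopping


cb = ModelCheckpointSketch("/tmp/ckpts")
for i, loss in enumerate([1.0, 0.7, 0.9, 0.6]):
    cb.on_iteration_end(None, i, {"total_loss": [loss]})
print(cb.saved_at)  # [0, 1, 3]
```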

class POMDPPlanners.training.callbacks.OptunaPruning(trial, monitor='total_loss')[source]

Bases: TrainerCallback

Report metrics to an Optuna trial and prune when appropriate.

The optuna package is imported lazily so that it is not a hard dependency of the training module.

Parameters:
trial

An active Optuna Trial object.

monitor

Metric key to report.

on_iteration_end(trainer, iteration, metrics)[source]

Called at the end of every iteration.

Parameters:
  • trainer (PolicyTrainer) – The running trainer instance.

  • iteration (int) – Zero-based iteration index.

  • metrics (Dict[str, List[float]]) – Loss metrics returned by the current train_step.

Return type:

Optional[bool]

Returns:

True to request early stopping; None or False to continue.
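Because optuna is imported lazily and may not be installed, the report-then-maybe-prune flow can be demonstrated with a mocked trial, in the same mock-based style as the TensorBoardCallback example below. Trial.report(value, step=...) and Trial.should_prune() are real Optuna Trial methods; the PruningSketch class and its use of RuntimeError are stand-ins (real pruning code raises optuna.TrialPruned):

```python
from unittest.mock import MagicMock

# Sketch of the report/prune flow OptunaPruning performs, with a mocked
# trial so the example runs without optuna installed.

class PruningSketch:
    def __init__(self, trial, monitor="total_loss"):
        self.trial = trial
        self.monitor = monitor

    def on_iteration_end(self, trainer, iteration, metrics):
        value = metrics[self.monitor][-1]
        # Report the latest monitored value at this step.
        self.trial.report(value, step=iteration)
        if self.trial.should_prune():
            # Real Optuna integrations raise optuna.TrialPruned here;
            # RuntimeError is a stand-in for this self-contained sketch.
            raise RuntimeError("trial pruned")


trial = MagicMock()
trial.should_prune.return_value = False
cb = PruningSketch(trial)
cb.on_iteration_end(None, 0, {"total_loss": [0.42]})
trial.report.assert_called_once_with(0.42, step=0)
```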

class POMDPPlanners.training.callbacks.TensorBoardCallback(log_dir=None, comment='', flush_secs=120, log_histograms=False)[source]

Bases: TrainerCallback

Log training metrics to TensorBoard.

The torch.utils.tensorboard package is imported lazily so it is not a hard startup dependency of the training module.

Parameters:
  • log_dir (str | None)

  • comment (str)

  • flush_secs (int)

  • log_histograms (bool)

log_dir

Directory for TensorBoard event files.

comment

Suffix appended to the auto-generated run directory name.

flush_secs

How often the writer flushes to disk (seconds).

log_histograms

When True and the policy exposes get_network(), logs per-parameter weight histograms each iteration.

Example

>>> from unittest.mock import MagicMock, patch
>>> with patch("torch.utils.tensorboard.SummaryWriter"):
...     cb = TensorBoardCallback(log_dir="/tmp/tb_test")
...     trainer = MagicMock()
...     cb.on_train_begin(trainer)
...     cb.on_train_end(trainer, {})
on_collection_end(trainer, iteration)[source]

Called after episode collection finishes.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • iteration (int)

on_iteration_end(trainer, iteration, metrics)[source]

Called at the end of every iteration.

Parameters:
  • trainer (PolicyTrainer) – The running trainer instance.

  • iteration (int) – Zero-based iteration index.

  • metrics (Dict[str, List[float]]) – Loss metrics returned by the current train_step.

Return type:

Optional[bool]

Returns:

True to request early stopping; None or False to continue.

on_train_begin(trainer)[source]

Called once before the first training iteration.

Return type:

None

Parameters:

trainer (PolicyTrainer)

on_train_end(trainer, all_metrics)[source]

Called once after the last training iteration.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • all_metrics (Dict[str, List[float]])

class POMDPPlanners.training.callbacks.TrainerCallback[source]

Bases: ABC

Abstract base class for training-loop callbacks.

All hooks have default no-op implementations so that subclasses only need to override the methods they care about.

Note

This is an abstract base class. Instantiate one of the concrete subclasses (EarlyStopping, ModelCheckpoint, OptunaPruning, or TensorBoardCallback) or write your own.
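Writing your own callback means subclassing and overriding only the hooks you need. To keep this sketch self-contained, it defines a stand-in base with the documented no-op defaults; a real subclass would inherit from POMDPPlanners.training.callbacks.TrainerCallback instead:

```python
from abc import ABC

# Stand-in base mirroring TrainerCallback's documented interface: all hooks
# default to no-ops so subclasses override only what they care about.

class TrainerCallbackSketch(ABC):
    def on_train_begin(self, trainer): pass
    def on_collection_begin(self, trainer, iteration): pass
    def on_collection_end(self, trainer, iteration): pass
    def on_iteration_begin(self, trainer, iteration): pass
    def on_iteration_end(self, trainer, iteration, metrics): return None
    def on_train_end(self, trainer, all_metrics): pass


class LossPrinter(TrainerCallbackSketch):
    """Hypothetical custom callback: print the latest loss each iteration."""

    def on_iteration_end(self, trainer, iteration, metrics):
        latest = metrics["total_loss"][-1]
        print(f"iter {iteration}: total_loss={latest:.3f}")
        return None  # continue training


cb = LossPrinter()
cb.on_train_begin(None)  # inherited no-op
result = cb.on_iteration_end(None, 0, {"total_loss": [0.125]})
```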

on_collection_begin(trainer, iteration)[source]

Called before episode collection starts.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • iteration (int)

on_collection_end(trainer, iteration)[source]

Called after episode collection finishes.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • iteration (int)

on_iteration_begin(trainer, iteration)[source]

Called at the start of every iteration.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • iteration (int)

on_iteration_end(trainer, iteration, metrics)[source]

Called at the end of every iteration.

Parameters:
  • trainer (PolicyTrainer) – The running trainer instance.

  • iteration (int) – Zero-based iteration index.

  • metrics (Dict[str, List[float]]) – Loss metrics returned by the current train_step.

Return type:

Optional[bool]

Returns:

True to request early stopping; None or False to continue.

on_train_begin(trainer)[source]

Called once before the first training iteration.

Return type:

None

Parameters:

trainer (PolicyTrainer)

on_train_end(trainer, all_metrics)[source]

Called once after the last training iteration.

Return type:

None

Parameters:
  • trainer (PolicyTrainer)

  • all_metrics (Dict[str, List[float]])

POMDPPlanners.training.policy_trainer module