POMDPPlanners.training package
Submodules
POMDPPlanners.training.callbacks module
Trainer callbacks for policy training loops.
This module provides a callback interface and concrete implementations for monitoring and controlling the PolicyTrainer training loop.
- Classes:
TrainerCallback: Abstract base with default no-op hooks.
EarlyStopping: Stops training after patience iterations without improvement.
ModelCheckpoint: Saves the policy on metric improvement (or every iteration).
OptunaPruning: Reports metrics to an Optuna trial and prunes when appropriate.
TensorBoardCallback: Logs training metrics and weight histograms to TensorBoard.
- class POMDPPlanners.training.callbacks.EarlyStopping(monitor, patience=5, mode='min')[source]
Bases: TrainerCallback
Stop training when a monitored metric stops improving.
- monitor
Metric key to watch (e.g. "total_loss").
- patience
Number of iterations with no improvement before stopping.
- mode
"min"to stop when metric stops decreasing,"max"for increasing.
- class POMDPPlanners.training.callbacks.ModelCheckpoint(filepath, monitor='total_loss', mode='min', save_every=False)[source]
Bases: TrainerCallback
Save the policy whenever a monitored metric improves.
- filepath
Directory where checkpoints are saved.
- monitor
Metric key to watch.
- mode
"min"or"max".
- save_every
When True, save after every iteration regardless of improvement.
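The save-on-improvement logic can be sketched as follows. The hook name `on_iteration_end`, the checkpoint filename pattern, and passing the policy as raw bytes are assumptions made purely for illustration; the real class serializes the actual policy object.

```python
import os
import tempfile

# Hypothetical sketch of ModelCheckpoint's documented behavior.
class ModelCheckpointSketch:
    def __init__(self, filepath, monitor="total_loss", mode="min",
                 save_every=False):
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_every = save_every
        self.best = None

    def _improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def on_iteration_end(self, iteration, metrics, policy_bytes):
        value = metrics[self.monitor]
        if self._improved(value):
            self.best = value
        elif not self.save_every:
            return  # no improvement and not saving every iteration
        path = os.path.join(self.filepath, f"policy_{iteration}.bin")
        with open(path, "wb") as f:  # stand-in for policy serialization
            f.write(policy_bytes)

with tempfile.TemporaryDirectory() as d:
    cb = ModelCheckpointSketch(d, monitor="total_loss")
    for i, loss in enumerate([1.0, 0.7, 0.9]):  # 0.9 is not an improvement
        cb.on_iteration_end(i, {"total_loss": loss}, b"weights")
    saved = sorted(os.listdir(d))
print(saved)  # → ['policy_0.bin', 'policy_1.bin']
```

Only the improving iterations (0 and 1) produce checkpoints; setting `save_every=True` would also write one for iteration 2.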
- class POMDPPlanners.training.callbacks.OptunaPruning(trial, monitor='total_loss')[source]
Bases: TrainerCallback
Report metrics to an Optuna trial and prune when appropriate.
The optuna package is imported lazily so that it is not a hard dependency of the training module.
- trial
An active Optuna Trial object.
- monitor
Metric key to report.
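The lazy-import idiom the docstring describes (deferring the `optuna` import so it is not a hard dependency) can be sketched with a small helper. This helper and its error message are illustrative, not the module's actual code:

```python
import importlib

def load_optional(name):
    """Import a module only when first needed, with a clear error if
    the optional dependency is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"optional dependency {name!r} is not installed"
        ) from exc

# A callback would call load_optional("optuna") inside __init__ rather
# than importing optuna at module load time, so merely importing the
# callbacks module never requires optuna to be present.
math_mod = load_optional("math")
print(math_mod.sqrt(9.0))  # → 3.0
```

Because the import happens at construction time, users who never instantiate OptunaPruning pay no import cost and need no Optuna installation.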
- class POMDPPlanners.training.callbacks.TensorBoardCallback(log_dir=None, comment='', flush_secs=120, log_histograms=False)[source]
Bases: TrainerCallback
Log training metrics to TensorBoard.
The torch.utils.tensorboard package is imported lazily so it is not a hard startup dependency of the training module.
- log_dir
Directory for TensorBoard event files.
- comment
Suffix appended to the auto-generated run directory name.
- flush_secs
How often the writer flushes to disk (seconds).
- log_histograms
When True and the policy exposes get_network(), logs per-parameter weight histograms each iteration.
Example
>>> from unittest.mock import MagicMock, patch
>>> with patch("torch.utils.tensorboard.SummaryWriter"):
...     cb = TensorBoardCallback(log_dir="/tmp/tb_test")
...     trainer = MagicMock()
...     cb.on_train_begin(trainer)
...     cb.on_train_end(trainer, {})
- class POMDPPlanners.training.callbacks.TrainerCallback[source]
Bases: ABC
Abstract base class for training-loop callbacks.
All hooks have default no-op implementations so that subclasses only need to override the methods they care about.
Note
This is an abstract base class. Instantiate one of the concrete subclasses (EarlyStopping, ModelCheckpoint, OptunaPruning, or TensorBoardCallback) or write your own.
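The "override only what you care about" design can be sketched with a minimal base class. The hook names `on_train_begin` and `on_train_end` come from the TensorBoardCallback example above; `on_iteration_end` and the subclass below are assumptions for illustration.

```python
from abc import ABC

# Hypothetical sketch of a callback base whose hooks are all no-ops.
class TrainerCallbackSketch(ABC):
    def on_train_begin(self, trainer):
        pass

    def on_iteration_end(self, trainer, metrics):
        pass

    def on_train_end(self, trainer, metrics):
        pass

class MetricLogger(TrainerCallbackSketch):
    """Records one metric per iteration; every other hook stays a no-op."""
    def __init__(self):
        self.history = []

    def on_iteration_end(self, trainer, metrics):
        self.history.append(metrics["total_loss"])

cb = MetricLogger()
cb.on_train_begin(trainer=None)  # inherited no-op, safe to call
for loss in [1.0, 0.5]:
    cb.on_iteration_end(None, {"total_loss": loss})
print(cb.history)  # → [1.0, 0.5]
```

Because every hook defaults to a no-op, MetricLogger only implements the single hook it needs, which is the pattern the base class is designed to support.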