POMDPPlanners.planners.mcts_planners.beta_zero package
BetaZero: Neural MCTS for POMDPs.
This package implements the BetaZero algorithm (Moss et al., 2024), which adapts AlphaZero to POMDPs by planning in belief space with learned neural network priors.
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
BetaZero: Main planner combining MCTS with neural network value/policy estimates
AbstractBetaZeroNetwork: Abstract base class for BetaZero policy and value networks
BetaZeroNetwork: Dual-head neural network for policy and value prediction
BetaZeroActionSampler: Network-guided action sampling for progressive widening
BeliefRepresentation: Abstract belief-to-feature mapping
ParticleMeanStdRepresentation: Default belief representation using particle statistics
TrainingBuffer: Iteration-slot replay buffer for training examples
TrainingExample: Single training datum (belief features, policy target, value target)
- class POMDPPlanners.planners.mcts_planners.beta_zero.AbstractBetaZeroNetwork[source]
Bases: ABC
Abstract base class for BetaZero policy and value networks.
Defines the inference and training interface required by the BetaZero planner. Concrete subclasses provide the underlying model architecture.
Note
This is an abstract base class and cannot be instantiated directly.
- abstractmethod fit(buffer, n_epochs, batch_size, learning_rate, weight_decay, track_gradients)[source]
Train on replay buffer and return per-epoch loss metrics.
- class POMDPPlanners.planners.mcts_planners.beta_zero.BeliefRepresentation[source]
Bases: ABC
Abstract base class for mapping beliefs to fixed-size feature vectors.
Subclasses define how a POMDP belief state is compressed into a numerical vector φ(b) ∈ ℝ^d that can be fed to a neural network.
Note
This is an abstract base class and cannot be instantiated directly.
- class POMDPPlanners.planners.mcts_planners.beta_zero.BetaZero(environment, discount_factor, depth, name, action_sampler, k_a=1.0, alpha_a=0.5, k_o=1.0, alpha_o=0.5, exploration_constant=1.0, time_out_in_seconds=None, n_simulations=None, min_visit_count_per_action=1, network=None, belief_representation=None, state_dim=None, z_q=1.0, z_n=1.0, temperature=1.0, n_buffer=1, training_batch_size=256, training_epochs=10, learning_rate=0.001, weight_decay=0.0001, hidden_sizes=(128, 128), track_gradients=False, normalize_inputs=True, normalize_values=True, log_path=None, debug=False, use_queue_logger=False)[source]
Bases: DoubleProgressiveWideningMCTSPolicy, TrainablePolicy
BetaZero: Neural MCTS for POMDPs.
Extends DoubleProgressiveWideningMCTSPolicy with three key innovations from the BetaZero paper:
- PUCT selection: replaces UCB1 with a selection criterion guided by learned policy priors.
- Neural value estimation: replaces random rollouts at leaf nodes.
- Policy iteration via fit(): collects episodes, computes Q-weighted policy targets, and trains the network.
The planner has two modes:
- Online planning via action(belief): builds an MCTS tree with PUCT and network value estimates.
- Offline training via fit(): alternates data collection and network training.
- Parameters:
environment (Environment)
discount_factor (float)
depth (int)
name (str)
action_sampler (ActionSampler)
k_a (float)
alpha_a (float)
k_o (float)
alpha_o (float)
exploration_constant (float)
time_out_in_seconds (int | None)
n_simulations (int | None)
min_visit_count_per_action (int)
network (AbstractBetaZeroNetwork | None)
belief_representation (BeliefRepresentation | None)
state_dim (int | None)
z_q (float)
z_n (float)
temperature (float)
n_buffer (int)
training_batch_size (int)
training_epochs (int)
learning_rate (float)
weight_decay (float)
track_gradients (bool)
normalize_inputs (bool)
normalize_values (bool)
log_path (Path | None)
debug (bool)
use_queue_logger (bool)
- network
Dual-head neural network for policy and value prediction.
- belief_representation
Belief → feature-vector mapping φ(b).
- z_q
Exponent for Q-value term in policy target.
- z_n
Exponent for visit-count term in policy target.
- temperature
Temperature τ for sharpening/smoothing policy target.
Example
>>> import numpy as np
>>> np.random.seed(42)
>>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP
>>> from POMDPPlanners.core.belief import get_initial_belief
>>> from POMDPPlanners.utils.action_samplers import DiscreteActionSampler
>>> from POMDPPlanners.planners.mcts_planners.beta_zero.beta_zero import BetaZero
>>>
>>> env = TigerPOMDP(discount_factor=0.95)
>>> sampler = DiscreteActionSampler(env.get_actions())
>>> planner = BetaZero(
...     environment=env,
...     discount_factor=0.95,
...     depth=3,
...     name="BetaZero_Tiger",
...     action_sampler=sampler,
...     n_simulations=20,
...     state_dim=1,
... )
>>> belief = get_initial_belief(env, n_particles=10)
>>> actions, run_data = planner.action(belief)
>>> actions[0] in env.get_actions()
True
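The z_q, z_n, and temperature attributes shape the Q-weighted policy target used during training. As a rough illustration only (the function name and the exact combination rule below are assumptions for this sketch, not the library's implementation), one plausible form combines normalized Q-values and visit-count frequencies:

```python
import numpy as np

def q_weighted_policy_target(q_values, visit_counts, z_q=1.0, z_n=1.0, temperature=1.0):
    """Hypothetical sketch of a Q-weighted policy target.

    Combines min-max-normalized Q-values (exponent z_q) with visit-count
    frequencies (exponent z_n), then sharpens/smooths with 1/temperature.
    The exact weighting used by BetaZero may differ.
    """
    q = np.asarray(q_values, dtype=float)
    n = np.asarray(visit_counts, dtype=float)
    q_norm = (q - q.min()) / (q.max() - q.min() + 1e-12)  # map Q to [0, 1]
    n_freq = n / n.sum()                                  # visit frequencies
    target = (q_norm ** z_q) * (n_freq ** z_n)
    target = target ** (1.0 / temperature)                # temperature tau
    return target / target.sum()                          # renormalize
```

With z_q=0 the target degenerates to pure visit counts (AlphaZero-style); with z_n=0 it depends only on Q-values.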
- action(belief)[source]
Select an action via MCTS with PUCT and network value estimates.
If data collection is active (during fit()), also stores a pending training example from the tree root.
- Return type:
- Parameters:
belief (Belief)
- collect_episodes_batched(initial_belief_fn, n_episodes, episode_length)[source]
Collect training data using fast batched (network-only) rollouts.
- get_metric_keys()[source]
Return the loss-metric key names produced by train_step().
- get_network()[source]
Return the underlying trainable network, or None if not applicable.
Override in concrete policies to enable weight-histogram logging in TensorBoardCallback.
- Return type:
- classmethod get_space_info()[source]
Get information about action and observation spaces.
Default implementation returns MIXED space types, which is appropriate for most progressive widening MCTS planners that support both discrete and continuous action spaces through the action sampler interface.
Subclasses can override this method to specify different space requirements (e.g., PFT_DPW specifies CONTINUOUS action space).
- Return type:
- Returns:
PolicySpaceInfo with MIXED space types for both actions and observations
- load_normalization_stats(filepath)[source]
Restore normalization statistics from a saved directory.
Should be called after network.load_weights() when loading a checkpoint that was saved with normalization enabled.
- class POMDPPlanners.planners.mcts_planners.beta_zero.BetaZeroActionSampler(fallback_sampler=None, actions=None, noise_scale=0.1)[source]
Bases: ActionSampler
Action sampler that draws from the BetaZero policy network.
For discrete action spaces, samples categorically from the network’s softmax distribution. For continuous action spaces, samples from the predicted Gaussian and adds exploration noise.
When no belief node is available (e.g. during random rollout) or the network has not been set, the sampler delegates to a fallback_sampler.
- Parameters:
fallback_sampler (Optional[ActionSampler]) – Sampler used when the network is not available. Optional only for pickle deserialization compatibility; should always be provided during normal construction.
actions (Optional[List[Any]]) – List of discrete actions (required for discrete spaces).
noise_scale (float) – Standard deviation of exploration noise added to continuous action samples.
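The two sampling modes described above can be sketched in plain numpy. This is a hypothetical illustration of the behaviour, not the sampler's actual code; sample_discrete and sample_continuous are invented names:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_discrete(action_probs, actions):
    """Categorical draw from the policy head's softmax probabilities."""
    idx = rng.choice(len(actions), p=action_probs)
    return actions[idx]

def sample_continuous(mean, log_std, noise_scale=0.1):
    """Draw from the predicted Gaussian, then add exploration noise."""
    a = rng.normal(mean, np.exp(log_std))
    return a + rng.normal(0.0, noise_scale, size=np.shape(a))
```

When neither path applies (no belief node, or no network attached), the real sampler instead delegates to its fallback_sampler.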
- sample(belief_node=None)[source]
Sample a new action for progressive widening.
- Parameters:
belief_node (Optional[BeliefNode]) – Optional current belief node for informed sampling.
- Return type:
- Returns:
A sampled action.
- set_network_and_representation(network, belief_representation, normalize_fn=None)[source]
Attach the network, belief representation, and optional normalizer.
- Parameters:
network (AbstractBetaZeroNetwork) – AbstractBetaZeroNetwork instance.
belief_representation (BeliefRepresentation) – BeliefRepresentation instance.
normalize_fn (Optional[Callable[[ndarray], ndarray]]) – Optional callable that normalizes raw belief features before passing them to the network. When None, features are passed through unchanged. Pass planner._get_normalized_features to keep the sampler in sync with BetaZero's normalization stats.
- Return type:
- class POMDPPlanners.planners.mcts_planners.beta_zero.BetaZeroNetwork(belief_dim, action_space_type, n_actions=None, action_dim=None, hidden_sizes=(128, 128))[source]
Bases: AbstractBetaZeroNetwork, Module
Dual-head neural network for BetaZero.
- Architecture:
Shared trunk: Linear(belief_dim, h) → ReLU → Linear(h, h) → ReLU
Policy head (discrete): Linear(h, h//2) → ReLU → Linear(h//2, n_actions) → LogSoftmax
Policy head (continuous): Linear(h, h//2) → ReLU → Linear(h//2, 2*action_dim) (mean, log_std)
Value head: Linear(h, h//2) → ReLU → Linear(h//2, 1)
- Parameters:
belief_dim (int) – Dimensionality of the belief feature vector φ(b).
action_space_type (str) – "discrete" or "continuous".
n_actions (Optional[int]) – Number of discrete actions (required when action_space_type="discrete").
action_dim (Optional[int]) – Dimensionality of continuous actions (required when action_space_type="continuous").
hidden_sizes (Sequence[int]) – Tuple of hidden layer widths for the shared trunk.
- Raises:
ValueError – If required parameters for the chosen action space type are missing.
Example
>>> import torch, numpy as np
>>> from POMDPPlanners.planners.mcts_planners.beta_zero.beta_zero_network import BetaZeroNetwork
>>> net = BetaZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=3)
>>> policy, value = net.predict(np.zeros(4, dtype=np.float32))
>>> policy.shape
(3,)
>>> isinstance(value, float)
True
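To make the shared-trunk/dual-head shape concrete, here is a numpy-only sketch of the discrete forward pass with randomly initialized placeholder weights. It illustrates the layer shapes only; it is not the PyTorch module itself, and the weights are untrained:

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(dim_in, dim_out):
    """Placeholder weight/bias pair for one Linear layer."""
    return rng.normal(0, 0.1, (dim_in, dim_out)), np.zeros(dim_out)

def forward(x, belief_dim=4, h=128, n_actions=3):
    """Numpy sketch of the dual-head forward pass (discrete case)."""
    relu = lambda z: np.maximum(z, 0.0)
    # Shared trunk: belief_dim -> h -> h
    (w1, b1), (w2, b2) = linear(belief_dim, h), linear(h, h)
    trunk = relu(relu(x @ w1 + b1) @ w2 + b2)
    # Policy head: h -> h//2 -> n_actions, then log-softmax
    (wp1, bp1), (wp2, bp2) = linear(h, h // 2), linear(h // 2, n_actions)
    logits = relu(trunk @ wp1 + bp1) @ wp2 + bp2
    m = logits.max()
    log_policy = logits - (np.log(np.exp(logits - m).sum()) + m)
    # Value head: h -> h//2 -> 1
    (wv1, bv1), (wv2, bv2) = linear(h, h // 2), linear(h // 2, 1)
    value = relu(trunk @ wv1 + bv1) @ wv2 + bv2
    return log_policy, value
```

The continuous policy head differs only in its output width (2*action_dim for mean and log_std) and in omitting the log-softmax.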
- fit(buffer, n_epochs, batch_size, learning_rate, weight_decay, track_gradients)[source]
Train the network on a replay buffer.
- Parameters:
buffer (TrainingBuffer) – Replay buffer with training examples.
n_epochs (int) – Number of full passes over the buffer.
batch_size (int) – Mini-batch size.
learning_rate (float) – Adam learning rate.
weight_decay (float) – L2 regularization coefficient.
track_gradients (bool) – When True, gradient and weight norms are included in the returned metrics.
- Return type:
- Returns:
Dictionary with per-epoch loss lists.
- forward(belief_features)[source]
Forward pass returning raw policy and value outputs.
- Parameters:
belief_features (Tensor) – Tensor of shape (batch, belief_dim) or (belief_dim,).
- Return type:
Tuple[Tensor, Tensor]
- Returns:
Tuple of (policy_output, value) tensors.
- Discrete policy: log-probabilities of shape (batch, n_actions).
- Continuous policy: [mean, log_std] of shape (batch, 2*action_dim).
- Value: shape (batch, 1).
- predict(belief_features)[source]
Single-sample inference returning numpy arrays.
Runs in torch.no_grad() mode. For discrete action spaces the output policy is exponentiated to give probabilities. Supports both CPU and CUDA: input is moved to the model device and outputs are moved to CPU for numpy conversion.
- predict_batch(belief_features_batch)[source]
Batched inference returning numpy arrays.
Runs in torch.no_grad() mode. Processes multiple belief feature vectors in a single forward pass for efficiency. Supports both CPU and CUDA; outputs are moved to CPU for numpy conversion when on CUDA.
- Parameters:
belief_features_batch (ndarray) – 2-D array of shape (N, belief_dim).
- Return type:
- Returns:
Tuple of (policies, values).
- Discrete: policies is an (N, n_actions) probability matrix.
- Continuous: policies is (N, 2*action_dim) with [mean, log_std].
- values is an (N,) array of floats.
- class POMDPPlanners.planners.mcts_planners.beta_zero.ParticleMeanStdRepresentation(state_dim)[source]
Bases: BeliefRepresentation
Default belief representation: φ(b) = [mean(particles), std(particles)].
For a state space of dimension d, the output is a vector of length 2·d formed by concatenating the (weighted) mean and standard deviation of the belief’s particle set.
Supported belief types:
- WeightedParticleBelief / WeightedParticleBeliefStateUpdate: uses normalized weights for statistics.
- GaussianBelief: extracts mean and diagonal of covariance.
- Any other Belief subclass: falls back to sampling 100 particles.
- Parameters:
state_dim (int) – Dimensionality of the state space.
Example
>>> import numpy as np
>>> np.random.seed(42)
>>> from POMDPPlanners.core.belief import WeightedParticleBelief
>>> from POMDPPlanners.planners.mcts_planners.beta_zero.belief_representation import (
...     ParticleMeanStdRepresentation,
... )
>>>
>>> particles = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
>>> log_weights = np.log([0.2, 0.5, 0.3])
>>> belief = WeightedParticleBelief(particles, log_weights)
>>> rep = ParticleMeanStdRepresentation(state_dim=2)
>>> features = rep(belief)
>>> features.shape
(4,)
>>> rep.feature_dim
4
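For the weighted-particle case, the feature computation reduces to a weighted mean and standard deviation over the particle set. A minimal numpy sketch (assuming a biased weighted variance; the library's exact estimator may differ):

```python
import numpy as np

def particle_mean_std_features(particles, weights):
    """Sketch of phi(b) = [weighted mean, weighted std] for a particle belief."""
    p = np.asarray(particles, dtype=float)   # (n_particles, state_dim)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()                          # normalize weights
    mean = w @ p                             # per-dimension weighted mean
    var = w @ (p - mean) ** 2                # per-dimension weighted variance
    return np.concatenate([mean, np.sqrt(var)])
```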
- class POMDPPlanners.planners.mcts_planners.beta_zero.TrainingBuffer(n_buffer=1)[source]
Bases: object
Iteration-slot replay buffer for BetaZero training.
Examples are grouped into iteration slots. Calling begin_iteration() commits the current slot to a fixed-length history deque (capacity n_buffer - 1 past slots) and opens a fresh slot for the new iteration. Training samples uniformly from all examples across all retained slots plus the current slot.
With n_buffer=1 (the default) the history deque has capacity 0, so only the current iteration's data is ever visible to training, matching the on-policy behaviour of the Julia reference implementation.
- Parameters:
n_buffer (int) – Number of iteration slots to retain (including the current slot). Must be >= 1.
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.beta_zero.training_buffer import (
...     TrainingBuffer, TrainingExample,
... )
>>> buf = TrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> buf.add(TrainingExample(np.zeros(4), np.array([0.5, 0.5]), 1.0))
>>> len(buf)
1
>>> batch = buf.sample_batch(1)
>>> batch[0].shape
(1, 4)
- add(example)[source]
Append an example to the current iteration slot.
- Return type:
- Parameters:
example (TrainingExample)
- begin_iteration()[source]
Commit the current slot and open a fresh one for the new iteration.
If the current slot contains examples it is pushed into the history deque (oldest slot evicted automatically when the deque is full), then the current slot is reset to empty.
- Return type:
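The slot mechanics map naturally onto collections.deque. A minimal sketch of the retention behaviour (SlotBuffer is an invented name, not the library class):

```python
from collections import deque

class SlotBuffer:
    """Minimal sketch of iteration-slot retention."""

    def __init__(self, n_buffer=1):
        self.history = deque(maxlen=n_buffer - 1)  # past slots; maxlen 0 keeps none
        self.current = []

    def add(self, example):
        self.current.append(example)

    def begin_iteration(self):
        if self.current:
            self.history.append(self.current)  # oldest slot auto-evicted when full
        self.current = []

    def all_examples(self):
        return [ex for slot in self.history for ex in slot] + self.current
```

With n_buffer=1 the deque has maxlen 0, so committed slots are discarded immediately and only the current iteration's examples remain visible.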
- get_all_examples()[source]
Return all examples across all buffer slots (historical + current).
- Return type:
- sample_batch(batch_size)[source]
Sample a random mini-batch uniformly from all retained examples.
- Parameters:
batch_size (int) – Number of examples to sample (with replacement if batch_size > len(buffer)).
- Return type:
- Returns:
Tuple of (belief_features, policy_targets, value_targets) arrays.
- belief_features: shape (batch_size, belief_dim)
- policy_targets: shape (batch_size, policy_dim)
- value_targets: shape (batch_size,)
- class POMDPPlanners.planners.mcts_planners.beta_zero.TrainingExample(belief_features, policy_target, value_target)[source]
Bases: object
Single training datum for BetaZero network training.
- belief_features
Belief feature vector φ(b), shape (belief_dim,).
- policy_target
Q-weighted policy target π_qw, shape (n_actions,) (discrete).
- value_target
Discounted return g_t.
Submodules
POMDPPlanners.planners.mcts_planners.beta_zero.belief_representation module
Belief-to-feature-vector representations for BetaZero.
This module provides abstractions for converting POMDP belief states into fixed-size feature vectors suitable for neural network input. The default implementation extracts particle statistics (mean and standard deviation).
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
BeliefRepresentation: Abstract base for belief feature extraction
ParticleMeanStdRepresentation: Default φ(b) = [mean(particles), std(particles)]
POMDPPlanners.planners.mcts_planners.beta_zero.beta_zero module
BetaZero planner: neural MCTS for POMDPs.
This module implements the BetaZero algorithm, which adapts AlphaZero to POMDPs by
planning in belief space. It combines online MCTS with PUCT and neural network priors
for both action selection and leaf value estimation. Offline policy-iteration training
is orchestrated via PolicyTrainer.
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
BetaZero: Main planner extending DoubleProgressiveWideningMCTSPolicy.
POMDPPlanners.planners.mcts_planners.beta_zero.beta_zero_action_sampler module
Network-guided action sampler for BetaZero progressive widening.
This module provides an ActionSampler subclass that draws new candidate
actions from the policy network’s output distribution, enabling the progressive
widening mechanism to propose actions guided by learned priors rather than
uniform random sampling.
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
BetaZeroActionSampler: Samples actions from the BetaZero policy network.
POMDPPlanners.planners.mcts_planners.beta_zero.beta_zero_network module
Dual-head neural network for BetaZero policy and value prediction.
This module provides the PyTorch neural network used by BetaZero. The network has a shared trunk with two output heads: a policy head that produces action probabilities (discrete) or Gaussian parameters (continuous), and a value head that estimates the state value V(φ(b)).
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
AbstractBetaZeroNetwork: Abstract base class for BetaZero policy and value networks.
BetaZeroNetwork: Shared-trunk network with policy and value heads.
POMDPPlanners.planners.mcts_planners.beta_zero.puct module
PUCT action selection and progressive widening for BetaZero.
Implements the Predictor Upper Confidence Trees (PUCT) selection rule used in BetaZero, replacing the standard UCB1 criterion. PUCT biases exploration towards actions favoured by the policy network.
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Functions:
puct_selection: Select among existing children using PUCT.
puct_action_progressive_widening: Progressive widening with PUCT selection.
- POMDPPlanners.planners.mcts_planners.beta_zero.puct.puct_action_progressive_widening(belief_node, alpha_a, action_sampler, exploration_constant, k_a, action_priors=None, min_visit_count_per_action=1)[source]
Progressive widening with PUCT selection instead of UCB1.
Follows the same widening logic as the standard action_progressive_widening but selects among existing actions using puct_selection() with neural network priors.
- Parameters:
belief_node (BeliefNode) – Current belief node.
alpha_a (float) – Progressive widening exponent (0 < α_a ≤ 1).
action_sampler (ActionSampler) – Sampler for generating new candidate actions.
exploration_constant (float) – PUCT exploration constant c.
k_a (float) – Progressive widening coefficient.
action_priors (Optional[ndarray]) – Prior probabilities for existing children. If None, uniform priors are used.
min_visit_count_per_action (int) – At the root, ensure every child has been visited at least this many times before selecting via PUCT.
- Return type:
- Returns:
Selected or newly created action node.
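The widening test itself follows the standard double-progressive-widening rule: a new action child is proposed while the number of children is below k_a · N(b)^α_a. A minimal sketch (assuming this standard criterion, which the function above shares with action_progressive_widening):

```python
def should_widen(n_children, node_visits, k_a=1.0, alpha_a=0.5):
    """Double progressive widening test: add a new action child when the
    number of children is below k_a * N(b)**alpha_a."""
    return n_children < k_a * node_visits ** alpha_a
```

When the test fails, the planner instead selects among existing children, here via PUCT rather than UCB1.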
- POMDPPlanners.planners.mcts_planners.beta_zero.puct.puct_selection(belief_node, exploration_constant, action_priors=None)[source]
Select an action child using the PUCT criterion.
The selection rule is:
a* = argmax Q̄(b,a) + c · P(a|b) · √N(b) / (1 + N(b,a))
where Q-values are normalised to [0, 1] for problem-independent exploration.
- Parameters:
belief_node (BeliefNode) – Current belief node with at least one action child.
exploration_constant (float) – Exploration constant c.
action_priors (Optional[ndarray]) – Prior probabilities P(a|b) aligned with belief_node.children. If None, uniform priors are used.
- Return type:
- Returns:
The action node with the highest PUCT score.
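The scoring rule, including the min-max normalisation of Q-values mentioned above, can be illustrated with a small NumPy sketch (a hypothetical helper, not the package's internal implementation):

```python
import numpy as np

def puct_scores(q_values, child_visits, priors, c):
    """Illustrative PUCT scores for the children of one belief node
    (hypothetical helper; the package's internals may differ)."""
    q = np.asarray(q_values, dtype=float)
    n = np.asarray(child_visits, dtype=float)
    p = np.asarray(priors, dtype=float)
    # Min-max normalise Q-values to [0, 1]; all-equal Q falls back to zeros.
    q_range = q.max() - q.min()
    q_norm = (q - q.min()) / q_range if q_range > 0 else np.zeros_like(q)
    # N(b) is approximated here by the total visits over all children.
    return q_norm + c * p * np.sqrt(n.sum()) / (1.0 + n)

# The selected child is the argmax of these scores.
```

Note how the prior term dominates for rarely visited children, while the normalised Q-term takes over as visit counts grow.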
POMDPPlanners.planners.mcts_planners.beta_zero.training module
Training utilities for BetaZero network.
This module provides the loss function and training loop used during BetaZero policy iteration to update the dual-head network.
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Functions:
compute_beta_zero_loss: Combined value + policy + L2 loss (Eq. 7 in paper).
train_network: Run multiple epochs of training on a replay buffer.
- POMDPPlanners.planners.mcts_planners.beta_zero.training.compute_beta_zero_loss(network, belief_features, policy_targets, value_targets)[source]
Compute the BetaZero combined loss.
L = MSE(v, g_t) + CrossEntropy(p, π_t) (Eq. 7)
For discrete action spaces the policy loss is cross-entropy between the network’s log-softmax output and the target distribution. For continuous action spaces a Gaussian negative-log-likelihood is used.
- Parameters:
network (BetaZeroNetwork) – The BetaZero network.
belief_features (Tensor) – Batch of belief feature vectors, shape (B, belief_dim).
policy_targets (Tensor) – Batch of policy targets, shape (B, policy_dim).
value_targets (Tensor) – Batch of scalar value targets, shape (B,).
- Return type:
- Returns:
Tuple of (total_loss, component_dict) where component_dict contains "value_loss" and "policy_loss" as Python floats.
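For intuition, a NumPy rendering of the discrete-action case might look like the following (the names and signature are illustrative; the actual function takes the network and torch tensors, whereas here the head outputs are passed in directly):

```python
import numpy as np

def beta_zero_loss(value_pred, policy_logits, value_targets, policy_targets):
    """Hypothetical NumPy sketch of Eq. 7 for discrete actions:
    MSE on the value head plus cross-entropy between the log-softmax
    policy output and the soft target distribution pi_t."""
    value_loss = np.mean((value_pred - value_targets) ** 2)
    # Numerically stable log-softmax of the policy head.
    shifted = policy_logits - policy_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Cross-entropy against the soft target distribution.
    policy_loss = -np.mean((policy_targets * log_probs).sum(axis=-1))
    total = value_loss + policy_loss
    return total, {"value_loss": float(value_loss),
                   "policy_loss": float(policy_loss)}
```

The L2 term of Eq. 7 appears to be applied through the optimiser's weight_decay (λ) rather than inside the loss itself, as the train_network signature below suggests.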
- POMDPPlanners.planners.mcts_planners.beta_zero.training.train_network(network, buffer, n_epochs=10, batch_size=256, learning_rate=0.001, weight_decay=0.0001, track_gradients=False, input_mean=None, input_std=None, value_mean=None, value_std=None)[source]
Train the network for multiple epochs on buffered data.
- Parameters:
network (BetaZeroNetwork) – Network to train (modified in-place).
buffer (TrainingBuffer) – Replay buffer with training examples.
n_epochs (int) – Number of full passes over the buffer.
batch_size (int) – Mini-batch size.
learning_rate (float) – Adam learning rate.
weight_decay (float) – L2 regularisation coefficient (λ in Eq. 7).
input_mean (Optional[ndarray]) – Per-feature mean for input normalisation (None = disabled).
input_std (Optional[ndarray]) – Per-feature std for input normalisation (None = disabled).
value_mean (Optional[float]) – Scalar mean for value normalisation (None = disabled).
value_std (Optional[float]) – Scalar std for value normalisation (None = disabled).
track_gradients (bool)
- Returns:
Loss metrics including "total_loss", "value_loss", and "policy_loss". When track_gradients is True, also includes "grad_norm/global", "grad_norm/trunk", "grad_norm/policy_head", "grad_norm/value_head", and "weight_norm/global".
- Return type:
POMDPPlanners.planners.mcts_planners.beta_zero.training_buffer module
Iteration-slot replay buffer for BetaZero training examples.
This module provides a buffer that stores training tuples (φ(b), π_qw, g_t)
collected during BetaZero policy iteration. The buffer is partitioned into
iteration slots: each call to TrainingBuffer.begin_iteration() commits
the current slot to history and opens a fresh one. At most n_buffer
iteration slots are retained; older slots are evicted automatically, which
mirrors the CircularBuffer design used in the reference Julia implementation
(BetaZero.jl, n_buffer parameter).
With the default n_buffer=1 only the current iteration’s data is ever in
the buffer, keeping training fully on-policy. Set n_buffer > 1 to retain
a rolling window of recent iterations.
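The slot mechanics can be mimicked in a few lines of standalone Python (a simplified sketch, not the package's TrainingBuffer):

```python
from collections import deque

class IterationSlotBuffer:
    """Simplified sketch of the iteration-slot scheme: a history deque
    keeps at most n_buffer - 1 committed slots plus the open current slot."""

    def __init__(self, n_buffer=1):
        self.history = deque(maxlen=max(n_buffer - 1, 0))
        self.current = []

    def begin_iteration(self):
        # Commit a non-empty current slot; the deque evicts the oldest
        # slot automatically once its maxlen is reached.
        if self.current:
            self.history.append(self.current)
        self.current = []

    def add(self, example):
        self.current.append(example)

    def __len__(self):
        return len(self.current) + sum(len(s) for s in self.history)
```

With n_buffer=1 the history deque has maxlen 0, so committing a slot discards it immediately and only the current iteration's examples are ever visible, which is the on-policy default described above.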
- Reference:
Moss, R. J., Corso, A., Caers, J., & Kochenderfer, M. J. (2024). BetaZero: Belief-State Planning for Long-Horizon POMDPs using Learned Approximations. Reinforcement Learning Conference (RLC).
- Classes:
TrainingExample: A single training datum. TrainingBuffer: Iteration-slot buffer with uniform batch sampling.
- class POMDPPlanners.planners.mcts_planners.beta_zero.training_buffer.TrainingBuffer(n_buffer=1)[source]
Bases: object
Iteration-slot replay buffer for BetaZero training.
Examples are grouped into iteration slots. Calling begin_iteration() commits the current slot to a fixed-length history deque (capacity n_buffer - 1 past slots) and opens a fresh slot for the new iteration. Training samples uniformly from all examples across all retained slots plus the current slot.
With n_buffer=1 (the default) the history deque has capacity 0, so only the current iteration's data is ever visible to training, matching the on-policy behaviour of the Julia reference implementation.
- Parameters:
n_buffer (int) – Number of iteration slots to retain (including the current slot). Must be >= 1.
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.beta_zero.training_buffer import (
...     TrainingBuffer, TrainingExample,
... )
>>> buf = TrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> buf.add(TrainingExample(np.zeros(4), np.array([0.5, 0.5]), 1.0))
>>> len(buf)
1
>>> batch = buf.sample_batch(1)
>>> batch[0].shape
(1, 4)
- add(example)[source]
Append an example to the current iteration slot.
- Return type:
- Parameters:
example (TrainingExample)
- begin_iteration()[source]
Commit the current slot and open a fresh one for the new iteration.
If the current slot contains examples it is pushed into the history deque (oldest slot evicted automatically when the deque is full), then the current slot is reset to empty.
- Return type:
- get_all_examples()[source]
Return all examples across all buffer slots (historical + current).
- Return type:
- sample_batch(batch_size)[source]
Sample a random mini-batch uniformly from all retained examples.
- Parameters:
batch_size (int) – Number of examples to sample (with replacement if batch_size > len(buffer)).
- Return type:
- Returns:
Tuple of (belief_features, policy_targets, value_targets) arrays.
- belief_features: shape (batch_size, belief_dim)
- policy_targets: shape (batch_size, policy_dim)
- value_targets: shape (batch_size,)
- class POMDPPlanners.planners.mcts_planners.beta_zero.training_buffer.TrainingExample(belief_features, policy_target, value_target)[source]
Bases: object
Single training datum for BetaZero network training.
- belief_features
Belief feature vector φ(b), shape (belief_dim,).
- policy_target
Q-weighted policy target π_qw, shape (n_actions,) (discrete).
- value_target
Discounted return g_t.