POMDPPlanners.planners.mcts_planners.constrained_zero package
ConstrainedZero: Neural MCTS for Chance-Constrained POMDPs.
This package implements the ConstrainedZero algorithm (Moss et al., IJCAI 2024), which extends BetaZero to solve CC-POMDPs by adding a failure probability head, safety-constrained PUCT, and adaptive failure threshold calibration.
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Classes:
ConstrainedZero: Main planner extending BetaZero for CC-POMDPs.
ConstrainedZeroNetwork: Three-head network with policy, value, and failure heads.
ConstrainedTrainingBuffer: Replay buffer with failure targets.
ConstrainedTrainingExample: Training datum with failure target.
- class POMDPPlanners.planners.mcts_planners.constrained_zero.ConstrainedTrainingBuffer(n_buffer=1)[source]
Bases: TrainingBuffer

Iteration-slot replay buffer for ConstrainedZero training.

Extends TrainingBuffer to store and sample ConstrainedTrainingExample instances, returning a 4-tuple from sample_batch() that includes failure targets.

- Parameters:
n_buffer (int) – Number of policy-iteration slots to retain. With the default n_buffer=1, only the current iteration’s data is used for training (on-policy).
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer import (
...     ConstrainedTrainingBuffer, ConstrainedTrainingExample,
... )
>>> buf = ConstrainedTrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> buf.add(ConstrainedTrainingExample(np.zeros(4), np.array([0.5, 0.5]), 1.0, 0.0))
>>> len(buf)
1
>>> batch = buf.sample_batch(1)
>>> len(batch)
4
>>> batch[0].shape
(1, 4)
- add(example)[source]
Append a constrained example to the current iteration slot.
- Return type:
- Parameters:
example (ConstrainedTrainingExample)
- sample_batch(batch_size)[source]
Sample a random mini-batch including failure targets.
- Parameters:
batch_size (int) – Number of examples to sample (with replacement if batch_size > len(buffer)).
- Return type:
- Returns:
Tuple of (belief_features, policy_targets, value_targets, failure_targets).
- belief_features: shape (batch_size, belief_dim)
- policy_targets: shape (batch_size, policy_dim)
- value_targets: shape (batch_size,)
- failure_targets: shape (batch_size,)
- class POMDPPlanners.planners.mcts_planners.constrained_zero.ConstrainedTrainingExample(belief_features, policy_target, value_target, failure_target)[source]
Bases: object

Single training datum for ConstrainedZero network training.
- Parameters:
- belief_features
Belief feature vector phi(b), shape
(belief_dim,).
- policy_target
Q-weighted policy target pi_qw, shape
(n_actions,)(discrete).
- value_target
Discounted return g_t.
- failure_target
Binary episode-level failure indicator (1.0 if failure occurred).
- class POMDPPlanners.planners.mcts_planners.constrained_zero.ConstrainedZero(environment, discount_factor, depth, name, action_sampler, failure_fn, delta_0=0.01, eta=1e-05, delta_compounding=1.0, k_a=1.0, alpha_a=0.5, k_o=1.0, alpha_o=0.5, exploration_constant=1.0, time_out_in_seconds=None, n_simulations=None, min_visit_count_per_action=1, network=None, belief_representation=None, state_dim=None, z_q=1.0, z_n=1.0, temperature=1.0, n_buffer=1, training_batch_size=256, training_epochs=10, learning_rate=0.001, weight_decay=0.0001, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2, track_gradients=False, normalize_inputs=True, normalize_values=True, log_path=None, debug=False, use_queue_logger=False)[source]
Bases: BetaZero

ConstrainedZero: Neural MCTS for Chance-Constrained POMDPs.

Extends BetaZero with:

3-head network: Adds a failure probability head alongside policy and value.
SPUCT selection: Safety-constrained PUCT that masks unsafe actions.
Adaptive Delta (conformal inference): Calibrates the failure threshold during tree search using online conformal inference.
Failure propagation: Tracks failure probability per action node using
p = p_immediate + delta_compounding * (1 - p_immediate) * p_next.

Constrained policy targets: Applies safety mask during target computation.
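The propagation rule above can be sketched as a standalone helper (a hypothetical function name; in the planner this quantity is tracked per action node):

```python
def propagate_failure(p_immediate, p_next, delta_compounding=1.0):
    """Combine an action node's immediate failure probability with the
    failure probability of the subtree below it, per
    p = p_immediate + delta_compounding * (1 - p_immediate) * p_next."""
    return p_immediate + delta_compounding * (1.0 - p_immediate) * p_next
```

With delta_compounding=1.0 this reduces to the complement-product rule 1 - (1 - p_immediate) * (1 - p_next); for example, propagate_failure(0.1, 0.2) gives 0.28.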
- Parameters:
environment (Environment)
discount_factor (float)
depth (int)
name (str)
action_sampler (ActionSampler)
delta_0 (float)
eta (float)
delta_compounding (float)
k_a (float)
alpha_a (float)
k_o (float)
alpha_o (float)
exploration_constant (float)
time_out_in_seconds (int | None)
n_simulations (int | None)
min_visit_count_per_action (int)
network (ConstrainedZeroNetwork)
belief_representation (BeliefRepresentation | None)
state_dim (int | None)
z_q (float)
z_n (float)
temperature (float)
n_buffer (int)
training_batch_size (int)
training_epochs (int)
learning_rate (float)
weight_decay (float)
use_dropout (bool)
p_dropout (float)
track_gradients (bool)
normalize_inputs (bool)
normalize_values (bool)
log_path (Path | None)
debug (bool)
use_queue_logger (bool)
- failure_fn
User-provided function state -> bool defining failure.
- delta_0
Nominal failure probability threshold.
- eta
Learning rate for adaptive Delta calibration.
- delta_compounding
Discount factor for failure propagation.
Example
>>> import numpy as np
>>> np.random.seed(42)
>>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP
>>> from POMDPPlanners.core.belief import get_initial_belief
>>> from POMDPPlanners.utils.action_samplers import DiscreteActionSampler
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero import ConstrainedZero
>>>
>>> env = TigerPOMDP(discount_factor=0.95)
>>> sampler = DiscreteActionSampler(env.get_actions())
>>> planner = ConstrainedZero(
...     environment=env,
...     discount_factor=0.95,
...     depth=3,
...     name="CZ_Tiger",
...     action_sampler=sampler,
...     n_simulations=20,
...     state_dim=1,
...     failure_fn=lambda s: False,
... )
>>> belief = get_initial_belief(env, n_particles=10)
>>> actions, run_data = planner.action(belief)
>>> actions[0] in env.get_actions()
True
- network: ConstrainedZeroNetwork
- class POMDPPlanners.planners.mcts_planners.constrained_zero.ConstrainedZeroNetwork(belief_dim, action_space_type, n_actions=None, action_dim=None, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2)[source]
Bases: BetaZeroNetwork

Three-head neural network for ConstrainedZero.
- Architecture:
Shared trunk: Linear(belief_dim, h) -> ReLU -> [Dropout] -> ... -> Linear(h, h) -> ReLU -> [Dropout]
Policy head: inherited from BetaZeroNetwork
Value head: inherited from BetaZeroNetwork
Failure head: Linear(h, h//2) -> ReLU -> Linear(h//2, 1)
The failure head outputs a raw logit. During predict(), sigmoid is applied to produce a failure probability in [0, 1].

- Parameters:
belief_dim (int) – Dimensionality of the belief feature vector phi(b).
action_space_type (str) – "discrete" or "continuous".
n_actions (Optional[int]) – Number of discrete actions (required when action_space_type="discrete").
action_dim (Optional[int]) – Dimensionality of continuous actions (required when action_space_type="continuous").
hidden_sizes (Sequence[int]) – Tuple of hidden layer widths for the shared trunk.
use_dropout (bool) – If True, apply dropout after each ReLU in the shared trunk (default True).
p_dropout (float) – Dropout probability for trunk layers (default 0.2).
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network import ConstrainedZeroNetwork
>>> net = ConstrainedZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=3, use_dropout=False)
>>> policy, value, failure_prob = net.predict(np.zeros(4, dtype=np.float32))
>>> policy.shape
(3,)
>>> isinstance(value, float)
True
>>> 0.0 <= failure_prob <= 1.0
True
- forward(belief_features)[source]
Forward pass returning raw policy, value, and failure logit.
- Parameters:
belief_features (Tensor) – Tensor of shape (batch, belief_dim) or (belief_dim,).
- Return type:
Tuple[Tensor, Tensor, Tensor]
- Returns:
Tuple of (policy_output, value, failure_logit) tensors.
- predict(belief_features)[source]
Single-sample inference returning numpy policy, value, and failure probability.
Switches to eval mode before inference to disable dropout, then restores the original training mode.
- Parameters:
belief_features (ndarray) – 1-D array of shape (belief_dim,).
- Return type:
- Returns:
Tuple of (policy, value, failure_prob).
- Discrete: policy is a probability vector summing to 1.
- Continuous: policy is [mean, log_std].
- value is a Python float.
- failure_prob is a Python float in [0, 1].
- predict_batch(belief_features_batch)[source]
Batched inference returning numpy policy, value, and failure probability arrays.
Switches to eval mode before inference to disable dropout, then restores the original training mode.
- Parameters:
belief_features_batch (ndarray) – 2-D array of shape (N, belief_dim).
- Return type:
- Returns:
Tuple of (policies, values, failure_probs).
- Discrete: policies is a (N, n_actions) probability matrix.
- Continuous: policies is (N, 2*action_dim) with [mean, log_std].
- values is a (N,) array of floats.
- failure_probs is a (N,) array of floats in [0, 1].
Submodules
POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct module
Safety-constrained PUCT (SPUCT) action selection for ConstrainedZero.
Implements the SPUCT selection rule that masks unsafe actions based on their estimated failure probability relative to an adaptive threshold Delta’.
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Functions:
spuct_selection: Select among existing children using safety-masked PUCT.
spuct_action_progressive_widening: Progressive widening with SPUCT selection.
- POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct.spuct_action_progressive_widening(belief_node, alpha_a, action_sampler, exploration_constant, k_a, failure_dict, delta_prime, action_priors=None, min_visit_count_per_action=1)[source]
Progressive widening with SPUCT selection instead of PUCT.
- Parameters:
belief_node (BeliefNode) – Current belief node.
alpha_a (float) – Progressive widening exponent (0 < alpha_a <= 1).
action_sampler (ActionSampler) – Sampler for generating new candidate actions.
exploration_constant (float) – PUCT exploration constant c.
k_a (float) – Progressive widening coefficient.
failure_dict (Dict[int, float]) – Maps id(action_node) to estimated failure probability.
delta_prime (float) – Adaptive failure threshold.
action_priors (Optional[ndarray]) – Prior probabilities for existing children.
min_visit_count_per_action (int) – At the root, ensure every child has been visited at least this many times before selecting via SPUCT.
- Return type:
- Returns:
Selected or newly created action node.
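For reference, the standard progressive-widening criterion that the k_a and alpha_a parameters suggest can be sketched as follows (an assumption about the exact condition, which is not spelled out above):

```python
def should_widen(n_children, n_visits, k_a, alpha_a):
    # Standard progressive-widening test: keep adding new candidate actions
    # while the number of children stays below k_a * N(b) ** alpha_a.
    return n_children < k_a * max(n_visits, 1) ** alpha_a
```

When the test is false, the node instead selects among its existing children via SPUCT.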
- POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct.spuct_selection(belief_node, exploration_constant, failure_dict, delta_prime, action_priors=None)[source]
Select an action child using the safety-constrained PUCT criterion.
The selection rule is:
a* = argmax_a subject_to(a) * [Q_norm(b,a) + c * P(a|b) * sqrt(N(b)) / (1 + N(b,a))]

where subject_to(a) = I(f(a) <= Delta') masks unsafe actions. If all actions are unsafe, falls back to unconstrained selection.

- Parameters:
belief_node (BeliefNode) – Current belief node with at least one action child.
exploration_constant (float) – Exploration constant c.
failure_dict (Dict[int, float]) – Maps id(action_node) to estimated failure probability.
delta_prime (float) – Adaptive failure threshold.
action_priors (Optional[ndarray]) – Prior probabilities P(a|b) aligned with belief_node.children. If None, uniform priors are used.
- Return type:
- Returns:
The action node with the highest safety-masked PUCT score.
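Assuming the per-child statistics are gathered into aligned numpy arrays (hypothetical names; the real function reads them from belief_node.children and failure_dict), the selection rule can be sketched as:

```python
import numpy as np

def spuct_select(q_norm, priors, visits, failure_probs, c, delta_prime):
    """Index of the child maximizing the safety-masked PUCT score."""
    # PUCT score: Q_norm(b,a) + c * P(a|b) * sqrt(N(b)) / (1 + N(b,a))
    scores = q_norm + c * priors * np.sqrt(visits.sum()) / (1.0 + visits)
    safe = failure_probs <= delta_prime          # subject_to(a) mask
    if not safe.any():                           # all unsafe: fall back
        return int(np.argmax(scores))            # to unconstrained PUCT
    return int(np.argmax(np.where(safe, scores, -np.inf)))
```

For instance, with q_norm=[0.5, 0.9], uniform priors, visits=[3, 1], and failure_probs=[0.0, 0.5], a threshold delta_prime=0.1 masks the second child and selects index 0 even though its unconstrained score is lower.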
POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training module
Training utilities for ConstrainedZero network.
This module provides the loss function and training loop for the 3-head ConstrainedZero network. It extends the BetaZero training with an additional binary cross-entropy loss for the failure head.
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Functions:
compute_constrained_zero_loss: Combined value + policy + failure loss.
train_constrained_network: Multi-epoch training on a replay buffer.
- POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training.compute_constrained_zero_loss(network, belief_features, policy_targets, value_targets, failure_targets)[source]
Compute the ConstrainedZero combined loss.
L = MSE(v, g_t) + CrossEntropy(p, pi_t) + BCE(failure_logit, f_t)
- Parameters:
network (ConstrainedZeroNetwork) – The ConstrainedZero 3-head network.
belief_features (Tensor) – Batch of belief feature vectors, shape (B, belief_dim).
policy_targets (Tensor) – Batch of policy targets, shape (B, policy_dim).
value_targets (Tensor) – Batch of scalar value targets, shape (B,).
failure_targets (Tensor) – Batch of binary failure targets, shape (B,).
- Return type:
- Returns:
Tuple of (total_loss, component_dict) where component_dict contains "value_loss", "policy_loss", and "failure_loss" as Python floats.
- POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training.train_constrained_network(network, buffer, n_epochs=10, batch_size=256, learning_rate=0.001, weight_decay=0.0001, track_gradients=False, input_mean=None, input_std=None, value_mean=None, value_std=None)[source]
Train the 3-head network for multiple epochs on buffered data.
- Parameters:
network (ConstrainedZeroNetwork) – Network to train (modified in-place).
buffer (ConstrainedTrainingBuffer) – Replay buffer with constrained training examples.
n_epochs (int) – Number of full passes over the buffer.
batch_size (int) – Mini-batch size.
learning_rate (float) – Adam learning rate.
weight_decay (float) – L2 regularisation coefficient.
track_gradients (bool) – When True, gradient and weight norms are computed per batch/epoch and included in the returned metrics.
input_mean (Optional[ndarray]) – Per-feature mean for input normalisation (None = disabled).
input_std (Optional[ndarray]) – Per-feature std for input normalisation (None = disabled).
value_mean (Optional[float]) – Scalar mean for value normalisation (None = disabled).
value_std (Optional[float]) – Scalar std for value normalisation (None = disabled).
- Returns:
Training metrics containing "total_loss", "value_loss", "policy_loss", and "failure_loss". When track_gradients is True, also includes "grad_norm/global", "grad_norm/trunk", "grad_norm/policy_head", "grad_norm/value_head", "grad_norm/failure_head", and "weight_norm/global".
- Return type:
POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer module
Iteration-slot replay buffer for ConstrainedZero training examples.
This module extends the BetaZero training buffer with an additional failure target, used for training the 3-head ConstrainedZero network.
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Classes:
ConstrainedTrainingExample: Training datum with failure target.
ConstrainedTrainingBuffer: Buffer returning 4-tuple batches.
- class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer.ConstrainedTrainingBuffer(n_buffer=1)[source]
Bases: TrainingBuffer

Iteration-slot replay buffer for ConstrainedZero training.

Extends TrainingBuffer to store and sample ConstrainedTrainingExample instances, returning a 4-tuple from sample_batch() that includes failure targets.

- Parameters:
n_buffer (int) – Number of policy-iteration slots to retain. With the default n_buffer=1, only the current iteration’s data is used for training (on-policy).
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer import (
...     ConstrainedTrainingBuffer, ConstrainedTrainingExample,
... )
>>> buf = ConstrainedTrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> buf.add(ConstrainedTrainingExample(np.zeros(4), np.array([0.5, 0.5]), 1.0, 0.0))
>>> len(buf)
1
>>> batch = buf.sample_batch(1)
>>> len(batch)
4
>>> batch[0].shape
(1, 4)
- add(example)[source]
Append a constrained example to the current iteration slot.
- Return type:
- Parameters:
example (ConstrainedTrainingExample)
- sample_batch(batch_size)[source]
Sample a random mini-batch including failure targets.
- Parameters:
batch_size (int) – Number of examples to sample (with replacement if batch_size > len(buffer)).
- Return type:
- Returns:
Tuple of (belief_features, policy_targets, value_targets, failure_targets).
- belief_features: shape (batch_size, belief_dim)
- policy_targets: shape (batch_size, policy_dim)
- value_targets: shape (batch_size,)
- failure_targets: shape (batch_size,)
- class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer.ConstrainedTrainingExample(belief_features, policy_target, value_target, failure_target)[source]
Bases: object

Single training datum for ConstrainedZero network training.
- Parameters:
- belief_features
Belief feature vector phi(b), shape
(belief_dim,).
- policy_target
Q-weighted policy target pi_qw, shape
(n_actions,)(discrete).
- value_target
Discounted return g_t.
- failure_target
Binary episode-level failure indicator (1.0 if failure occurred).
POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero module
ConstrainedZero planner: neural MCTS for Chance-Constrained POMDPs.
This module implements the ConstrainedZero algorithm, which extends BetaZero to solve CC-POMDPs. It adds a 3-head network with a failure probability head, safety-constrained PUCT (SPUCT), adaptive failure threshold calibration via conformal inference, and constrained policy targets for training.
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Classes:
ConstrainedZero: Main planner extending BetaZero.
- class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero.ConstrainedZero(environment, discount_factor, depth, name, action_sampler, failure_fn, delta_0=0.01, eta=1e-05, delta_compounding=1.0, k_a=1.0, alpha_a=0.5, k_o=1.0, alpha_o=0.5, exploration_constant=1.0, time_out_in_seconds=None, n_simulations=None, min_visit_count_per_action=1, network=None, belief_representation=None, state_dim=None, z_q=1.0, z_n=1.0, temperature=1.0, n_buffer=1, training_batch_size=256, training_epochs=10, learning_rate=0.001, weight_decay=0.0001, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2, track_gradients=False, normalize_inputs=True, normalize_values=True, log_path=None, debug=False, use_queue_logger=False)[source]
Bases: BetaZero

ConstrainedZero: Neural MCTS for Chance-Constrained POMDPs.

Extends BetaZero with:

3-head network: Adds a failure probability head alongside policy and value.
SPUCT selection: Safety-constrained PUCT that masks unsafe actions.
Adaptive Delta (conformal inference): Calibrates the failure threshold during tree search using online conformal inference.
Failure propagation: Tracks failure probability per action node using
p = p_immediate + delta_compounding * (1 - p_immediate) * p_next.

Constrained policy targets: Applies safety mask during target computation.
- Parameters:
environment (Environment)
discount_factor (float)
depth (int)
name (str)
action_sampler (ActionSampler)
delta_0 (float)
eta (float)
delta_compounding (float)
k_a (float)
alpha_a (float)
k_o (float)
alpha_o (float)
exploration_constant (float)
time_out_in_seconds (int | None)
n_simulations (int | None)
min_visit_count_per_action (int)
network (ConstrainedZeroNetwork)
belief_representation (BeliefRepresentation | None)
state_dim (int | None)
z_q (float)
z_n (float)
temperature (float)
n_buffer (int)
training_batch_size (int)
training_epochs (int)
learning_rate (float)
weight_decay (float)
use_dropout (bool)
p_dropout (float)
track_gradients (bool)
normalize_inputs (bool)
normalize_values (bool)
log_path (Path | None)
debug (bool)
use_queue_logger (bool)
- failure_fn
User-provided function state -> bool defining failure.
- delta_0
Nominal failure probability threshold.
- eta
Learning rate for adaptive Delta calibration.
- delta_compounding
Discount factor for failure propagation.
Example
>>> import numpy as np
>>> np.random.seed(42)
>>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP
>>> from POMDPPlanners.core.belief import get_initial_belief
>>> from POMDPPlanners.utils.action_samplers import DiscreteActionSampler
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero import ConstrainedZero
>>>
>>> env = TigerPOMDP(discount_factor=0.95)
>>> sampler = DiscreteActionSampler(env.get_actions())
>>> planner = ConstrainedZero(
...     environment=env,
...     discount_factor=0.95,
...     depth=3,
...     name="CZ_Tiger",
...     action_sampler=sampler,
...     n_simulations=20,
...     state_dim=1,
...     failure_fn=lambda s: False,
... )
>>> belief = get_initial_belief(env, n_particles=10)
>>> actions, run_data = planner.action(belief)
>>> actions[0] in env.get_actions()
True
- network: ConstrainedZeroNetwork
POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network module
Three-head neural network for ConstrainedZero.
This module extends the BetaZero dual-head network with an additional failure probability head. The failure head outputs a raw logit; sigmoid is applied during prediction to produce a probability in [0, 1].
- Reference:
Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746
- Classes:
ConstrainedZeroNetwork: Shared-trunk network with policy, value, and failure heads.
- class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network.ConstrainedZeroNetwork(belief_dim, action_space_type, n_actions=None, action_dim=None, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2)[source]
Bases: BetaZeroNetwork

Three-head neural network for ConstrainedZero.
- Architecture:
Shared trunk: Linear(belief_dim, h) -> ReLU -> [Dropout] -> ... -> Linear(h, h) -> ReLU -> [Dropout]
Policy head: inherited from BetaZeroNetwork
Value head: inherited from BetaZeroNetwork
Failure head: Linear(h, h//2) -> ReLU -> Linear(h//2, 1)
The failure head outputs a raw logit. During predict(), sigmoid is applied to produce a failure probability in [0, 1].

- Parameters:
belief_dim (int) – Dimensionality of the belief feature vector phi(b).
action_space_type (str) – "discrete" or "continuous".
n_actions (Optional[int]) – Number of discrete actions (required when action_space_type="discrete").
action_dim (Optional[int]) – Dimensionality of continuous actions (required when action_space_type="continuous").
hidden_sizes (Sequence[int]) – Tuple of hidden layer widths for the shared trunk.
use_dropout (bool) – If True, apply dropout after each ReLU in the shared trunk (default True).
p_dropout (float) – Dropout probability for trunk layers (default 0.2).
Example
>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network import ConstrainedZeroNetwork
>>> net = ConstrainedZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=3, use_dropout=False)
>>> policy, value, failure_prob = net.predict(np.zeros(4, dtype=np.float32))
>>> policy.shape
(3,)
>>> isinstance(value, float)
True
>>> 0.0 <= failure_prob <= 1.0
True
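To make the trunk/head layout concrete, here is a torch-free numpy mock of the forward pass with random weights (hypothetical helper names; it mirrors the documented shapes of the architecture, not the trained network):

```python
import numpy as np

rng = np.random.default_rng(0)
belief_dim, h, n_actions = 4, 8, 3

def layer(n_in, n_out):
    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)

W1, b1 = layer(belief_dim, h)    # shared trunk (single hidden layer here)
Wp, bp = layer(h, n_actions)     # policy head
Wv, bv = layer(h, 1)             # value head
Wf1, bf1 = layer(h, h // 2)      # failure head: Linear(h, h//2)
Wf2, bf2 = layer(h // 2, 1)      # failure head: Linear(h//2, 1)

def mock_predict(phi):
    z = np.maximum(phi @ W1 + b1, 0.0)             # ReLU trunk
    logits = z @ Wp + bp
    policy = np.exp(logits - logits.max())
    policy /= policy.sum()                          # softmax over actions
    value = (z @ Wv + bv)[0]
    f_logit = (np.maximum(z @ Wf1 + bf1, 0.0) @ Wf2 + bf2)[0]
    failure_prob = 1.0 / (1.0 + np.exp(-f_logit))   # sigmoid at predict()
    return policy, float(value), float(failure_prob)
```

The raw failure logit only becomes a probability after the sigmoid, matching the forward()/predict() split documented above.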
- forward(belief_features)[source]
Forward pass returning raw policy, value, and failure logit.
- Parameters:
belief_features (Tensor) – Tensor of shape (batch, belief_dim) or (belief_dim,).
- Return type:
Tuple[Tensor, Tensor, Tensor]
- Returns:
Tuple of (policy_output, value, failure_logit) tensors.
- predict(belief_features)[source]
Single-sample inference returning numpy policy, value, and failure probability.
Switches to eval mode before inference to disable dropout, then restores the original training mode.
- Parameters:
belief_features (ndarray) – 1-D array of shape (belief_dim,).
- Return type:
- Returns:
Tuple of (policy, value, failure_prob).
- Discrete: policy is a probability vector summing to 1.
- Continuous: policy is [mean, log_std].
- value is a Python float.
- failure_prob is a Python float in [0, 1].
- predict_batch(belief_features_batch)[source]
Batched inference returning numpy policy, value, and failure probability arrays.
Switches to eval mode before inference to disable dropout, then restores the original training mode.
- Parameters:
belief_features_batch (ndarray) – 2-D array of shape (N, belief_dim).
- Return type:
- Returns:
Tuple of (policies, values, failure_probs).
- Discrete: policies is a (N, n_actions) probability matrix.
- Continuous: policies is (N, 2*action_dim) with [mean, log_std].
- values is a (N,) array of floats.
- failure_probs is a (N,) array of floats in [0, 1].