POMDPPlanners.planners.mcts_planners.constrained_zero package

ConstrainedZero: Neural MCTS for Chance-Constrained POMDPs.

This package implements the ConstrainedZero algorithm (Moss et al., IJCAI 2024), which extends BetaZero to solve CC-POMDPs by adding a failure probability head, safety-constrained PUCT, and adaptive failure threshold calibration.

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Classes:

ConstrainedZero: Main planner extending BetaZero for CC-POMDPs.
ConstrainedZeroNetwork: Three-head network with policy, value, and failure heads.
ConstrainedTrainingBuffer: Replay buffer with failure targets.
ConstrainedTrainingExample: Training datum with failure target.


Submodules

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct module

Safety-constrained PUCT (SPUCT) action selection for ConstrainedZero.

Implements the SPUCT selection rule that masks unsafe actions based on their estimated failure probability relative to an adaptive threshold Delta'.

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Functions:

spuct_selection: Select among existing children using safety-masked PUCT.
spuct_action_progressive_widening: Progressive widening with SPUCT selection.

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct.spuct_action_progressive_widening(belief_node, alpha_a, action_sampler, exploration_constant, k_a, failure_dict, delta_prime, action_priors=None, min_visit_count_per_action=1)[source]

Progressive widening with SPUCT selection instead of PUCT.

Parameters:
  • belief_node (BeliefNode) – Current belief node.

  • alpha_a (float) – Progressive widening exponent (0 < alpha_a <= 1).

  • action_sampler (ActionSampler) – Sampler for generating new candidate actions.

  • exploration_constant (float) – PUCT exploration constant c.

  • k_a (float) – Progressive widening coefficient.

  • failure_dict (Dict[int, float]) – Maps id(action_node) to estimated failure probability.

  • delta_prime (float) – Adaptive failure threshold.

  • action_priors (Optional[ndarray]) – Prior probabilities for existing children.

  • min_visit_count_per_action (int) – At the root, ensure every child has been visited at least this many times before selecting via SPUCT.

Return type:

ActionNode

Returns:

Selected or newly created action node.
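
A minimal sketch of the widening criterion this function applies before falling through to SPUCT selection; should_widen, n_visits, and n_children are illustrative stand-ins for the node's actual bookkeeping:

def should_widen(n_visits, n_children, k_a, alpha_a):
    """Sample a new candidate action while the child count is below
    the progressive-widening budget k_a * N(b)**alpha_a."""
    return n_children < k_a * n_visits ** alpha_a

# With k_a=1.0 and alpha_a=0.5, a node visited 20 times widens while it
# has fewer than sqrt(20) ~ 4.47 children:
should_widen(20, 4, k_a=1.0, alpha_a=0.5)  # True  -> sample a new action
should_widen(20, 5, k_a=1.0, alpha_a=0.5)  # False -> select among children via SPUCT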

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_puct.spuct_selection(belief_node, exploration_constant, failure_dict, delta_prime, action_priors=None)[source]

Select an action child using the safety-constrained PUCT criterion.

The selection rule is:

a* = argmax_{a : f(a) <= Delta'} [Q_norm(b,a) + c * P(a|b) * sqrt(N(b)) / (1 + N(b,a))]

where the constraint f(a) <= Delta' masks actions whose estimated failure probability exceeds the adaptive threshold. If all actions are unsafe, selection falls back to the unconstrained PUCT rule.

Parameters:
  • belief_node (BeliefNode) – Current belief node with at least one action child.

  • exploration_constant (float) – Exploration constant c.

  • failure_dict (Dict[int, float]) – Maps id(action_node) to estimated failure probability.

  • delta_prime (float) – Adaptive failure threshold.

  • action_priors (Optional[ndarray]) – Prior probabilities P(a|b) aligned with belief_node.children. If None, uniform priors are used.

Return type:

ActionNode

Returns:

The action node with the highest safety-masked PUCT score.
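
A self-contained sketch of the masked scoring this rule describes, operating on plain arrays rather than the package's node types (all argument names here are illustrative):

import numpy as np

def spuct_scores(q_norm, priors, n_b, n_ba, c, failure_probs, delta_prime):
    """PUCT scores with children above the failure threshold masked out;
    if every child is unsafe, return the unconstrained scores."""
    scores = q_norm + c * priors * np.sqrt(n_b) / (1.0 + n_ba)
    safe = failure_probs <= delta_prime
    if not safe.any():
        return scores  # fallback: unconstrained selection
    return np.where(safe, scores, -np.inf)

scores = spuct_scores(
    q_norm=np.array([0.9, 0.4]), priors=np.array([0.5, 0.5]),
    n_b=10, n_ba=np.array([6, 3]), c=1.0,
    failure_probs=np.array([0.3, 0.01]), delta_prime=0.05,
)
int(np.argmax(scores))  # 1: the higher-value child is masked as unsafe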

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training module

Training utilities for ConstrainedZero network.

This module provides the loss function and training loop for the 3-head ConstrainedZero network. It extends the BetaZero training with an additional binary cross-entropy loss for the failure head.

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Functions:

compute_constrained_zero_loss: Combined value + policy + failure loss.
train_constrained_network: Multi-epoch training on a replay buffer.

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training.compute_constrained_zero_loss(network, belief_features, policy_targets, value_targets, failure_targets)[source]

Compute the ConstrainedZero combined loss.

L = MSE(v, g_t) + CrossEntropy(p, pi_t) + BCE(failure_logit, f_t)

Parameters:
  • network (ConstrainedZeroNetwork) – The ConstrainedZero 3-head network.

  • belief_features (Tensor) – Batch of belief feature vectors, shape (B, belief_dim).

  • policy_targets (Tensor) – Batch of policy targets, shape (B, policy_dim).

  • value_targets (Tensor) – Batch of scalar value targets, shape (B,).

  • failure_targets (Tensor) – Batch of binary failure targets, shape (B,).

Return type:

Tuple[Tensor, Dict[str, float]]

Returns:

Tuple of (total_loss, component_dict) where component_dict contains "value_loss", "policy_loss", and "failure_loss" as Python floats.
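
A minimal PyTorch sketch of this combined objective, written against the documented forward() outputs (policy_output, value, failure_logit) and assuming a discrete action space where policy_output holds raw logits; the reduction choices are illustrative:

import torch
import torch.nn.functional as F

def combined_loss(policy_logits, value, failure_logit,
                  policy_targets, value_targets, failure_targets):
    """L = MSE(v, g_t) + CrossEntropy(p, pi_t) + BCE(failure_logit, f_t)."""
    value_loss = F.mse_loss(value.squeeze(-1), value_targets)
    # Cross-entropy against a soft Q-weighted policy target distribution.
    policy_loss = -(policy_targets * F.log_softmax(policy_logits, dim=-1)).sum(-1).mean()
    failure_loss = F.binary_cross_entropy_with_logits(
        failure_logit.squeeze(-1), failure_targets.float())
    total = value_loss + policy_loss + failure_loss
    return total, {"value_loss": value_loss.item(),
                   "policy_loss": policy_loss.item(),
                   "failure_loss": failure_loss.item()}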

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training.train_constrained_network(network, buffer, n_epochs=10, batch_size=256, learning_rate=0.001, weight_decay=0.0001, track_gradients=False, input_mean=None, input_std=None, value_mean=None, value_std=None)[source]

Train the 3-head network for multiple epochs on buffered data.

Parameters:
  • network (ConstrainedZeroNetwork) – Network to train (modified in-place).

  • buffer (ConstrainedTrainingBuffer) – Replay buffer with constrained training examples.

  • n_epochs (int) – Number of full passes over the buffer.

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Adam learning rate.

  • weight_decay (float) – L2 regularisation coefficient.

  • track_gradients (bool) – When True, gradient and weight norms are computed per-batch/epoch and included in the returned metrics.

  • input_mean (Optional[ndarray]) – Per-feature mean for input normalisation (None = disabled).

  • input_std (Optional[ndarray]) – Per-feature std for input normalisation (None = disabled).

  • value_mean (Optional[float]) – Scalar mean for value normalisation (None = disabled).

  • value_std (Optional[float]) – Scalar std for value normalisation (None = disabled).

Returns:

"total_loss", "value_loss", "policy_loss", "failure_loss". When track_gradients is True, also includes "grad_norm/global", "grad_norm/trunk", "grad_norm/policy_head", "grad_norm/value_head", "grad_norm/failure_head", and "weight_norm/global".

Return type:

Dict[str, List[float]]
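
A hedged end-to-end usage sketch combining the documented buffer, network, and trainer APIs (batch sizes and epoch counts are illustrative):

>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training import train_constrained_network
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer import (
...     ConstrainedTrainingBuffer, ConstrainedTrainingExample,
... )
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network import ConstrainedZeroNetwork
>>> net = ConstrainedZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=2)
>>> buf = ConstrainedTrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> for _ in range(32):
...     buf.add(ConstrainedTrainingExample(
...         np.random.randn(4).astype(np.float32), np.array([0.5, 0.5]), 0.0, 0.0))
>>> metrics = train_constrained_network(net, buf, n_epochs=2, batch_size=16)
>>> set(metrics) >= {"total_loss", "value_loss", "policy_loss", "failure_loss"}
True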

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer module

Iteration-slot replay buffer for ConstrainedZero training examples.

This module extends the BetaZero training buffer with an additional failure target, used for training the 3-head ConstrainedZero network.

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Classes:

ConstrainedTrainingExample: Training datum with failure target.
ConstrainedTrainingBuffer: Buffer returning 4-tuple batches.

class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer.ConstrainedTrainingBuffer(n_buffer=1)[source]

Bases: TrainingBuffer

Iteration-slot replay buffer for ConstrainedZero training.

Extends TrainingBuffer to store and sample ConstrainedTrainingExample instances, returning a 4-tuple from sample_batch() that includes failure targets.

Parameters:

n_buffer (int) – Number of policy-iteration slots to retain. With the default n_buffer=1, only the current iteration's data is used for training (on-policy).

Example

>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer import (
...     ConstrainedTrainingBuffer, ConstrainedTrainingExample,
... )
>>> buf = ConstrainedTrainingBuffer(n_buffer=1)
>>> buf.begin_iteration()
>>> buf.add(ConstrainedTrainingExample(np.zeros(4), np.array([0.5, 0.5]), 1.0, 0.0))
>>> len(buf)
1
>>> batch = buf.sample_batch(1)
>>> len(batch)
4
>>> batch[0].shape
(1, 4)
add(example)[source]

Append a constrained example to the current iteration slot.

Return type:

None

Parameters:

example (ConstrainedTrainingExample)

sample_batch(batch_size)[source]

Sample a random mini-batch including failure targets.

Parameters:

batch_size (int) – Number of examples to sample (with replacement if batch_size > len(buffer)).

Return type:

Tuple[ndarray, ndarray, ndarray, ndarray]

Returns:

Tuple of (belief_features, policy_targets, value_targets, failure_targets):

  • belief_features: shape (batch_size, belief_dim)

  • policy_targets: shape (batch_size, policy_dim)

  • value_targets: shape (batch_size,)

  • failure_targets: shape (batch_size,)

class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_training_buffer.ConstrainedTrainingExample(belief_features, policy_target, value_target, failure_target)[source]

Bases: object

Single training datum for ConstrainedZero network training.

Parameters:
belief_features

Belief feature vector phi(b), shape (belief_dim,).

policy_target

Q-weighted policy target pi_qw, shape (n_actions,) (discrete).

value_target

Discounted return g_t.

failure_target

Binary episode-level failure indicator (1.0 if failure occurred).

belief_features: ndarray
failure_target: float
policy_target: ndarray
value_target: float

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero module

ConstrainedZero planner: neural MCTS for Chance-Constrained POMDPs.

This module implements the ConstrainedZero algorithm, which extends BetaZero to solve CC-POMDPs. It adds a 3-head network with a failure probability head, safety-constrained PUCT (SPUCT), adaptive failure threshold calibration via conformal inference, and constrained policy targets for training.

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Classes:

ConstrainedZero: Main planner extending BetaZero.

class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero.ConstrainedZero(environment, discount_factor, depth, name, action_sampler, failure_fn, delta_0=0.01, eta=1e-05, delta_compounding=1.0, k_a=1.0, alpha_a=0.5, k_o=1.0, alpha_o=0.5, exploration_constant=1.0, time_out_in_seconds=None, n_simulations=None, min_visit_count_per_action=1, network=None, belief_representation=None, state_dim=None, z_q=1.0, z_n=1.0, temperature=1.0, n_buffer=1, training_batch_size=256, training_epochs=10, learning_rate=0.001, weight_decay=0.0001, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2, track_gradients=False, normalize_inputs=True, normalize_values=True, log_path=None, debug=False, use_queue_logger=False)[source]

Bases: BetaZero

ConstrainedZero: Neural MCTS for Chance-Constrained POMDPs.

Extends BetaZero with:

  1. 3-head network: Adds a failure probability head alongside policy and value.

  2. SPUCT selection: Safety-constrained PUCT that masks unsafe actions.

  3. Adaptive Delta: Calibrates the failure threshold Delta' during tree search using online conformal inference.

  4. Failure propagation: Tracks failure probability per action node using p = p_immediate + delta_compounding * (1 - p_immediate) * p_next (see the sketch after this list).

  5. Constrained policy targets: Applies safety mask during target computation.
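
A small sketch of the failure-propagation recursion from item 4 above, composing an immediate failure estimate with the child subtree's estimate; the function and argument names are illustrative, not the planner's internals:

def propagate_failure(p_immediate, p_next, delta_compounding=1.0):
    """Probability of failing now, or failing later given no failure now
    (the later term optionally discounted by delta_compounding)."""
    return p_immediate + delta_compounding * (1.0 - p_immediate) * p_next

# Without discounting this is the probability of the union of failure events:
propagate_failure(p_immediate=0.1, p_next=0.2)  # 0.1 + 0.9 * 0.2 = 0.28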

Parameters:
failure_fn

User-provided function state -> bool defining failure.

delta_0

Nominal failure probability threshold.

eta

Learning rate for adaptive Delta calibration.

delta_compounding

Discount factor for failure propagation.

Example

>>> import numpy as np
>>> np.random.seed(42)
>>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP
>>> from POMDPPlanners.core.belief import get_initial_belief
>>> from POMDPPlanners.utils.action_samplers import DiscreteActionSampler
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero import ConstrainedZero
>>>
>>> env = TigerPOMDP(discount_factor=0.95)
>>> sampler = DiscreteActionSampler(env.get_actions())
>>> planner = ConstrainedZero(
...     environment=env,
...     discount_factor=0.95,
...     depth=3,
...     name="CZ_Tiger",
...     action_sampler=sampler,
...     n_simulations=20,
...     state_dim=1,
...     failure_fn=lambda s: False,
... )
>>> belief = get_initial_belief(env, n_particles=10)
>>> actions, run_data = planner.action(belief)
>>> actions[0] in env.get_actions()
True
get_metric_keys()[source]

Return the loss-metric key names produced by train_step().

Return type:

List[str]

network: ConstrainedZeroNetwork

POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network module

Three-head neural network for ConstrainedZero.

This module extends the BetaZero dual-head network with an additional failure probability head. The failure head outputs a raw logit; sigmoid is applied during prediction to produce a probability in [0, 1].

Reference:

Moss, R. J., Jamgochian, A., Fischer, J., Corso, A., & Kochenderfer, M. J. (2024). ConstrainedZero: Chance-Constrained POMDP Planning Using Learned Probabilistic Failure Surrogates and Adaptive Safety Constraints. Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence (IJCAI), 6752-6760. https://www.ijcai.org/proceedings/2024/746

Classes:

ConstrainedZeroNetwork: Shared-trunk network with policy, value, and failure heads.

class POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network.ConstrainedZeroNetwork(belief_dim, action_space_type, n_actions=None, action_dim=None, hidden_sizes=(128, 128), use_dropout=True, p_dropout=0.2)[source]

Bases: BetaZeroNetwork

Three-head neural network for ConstrainedZero.

Architecture:
  • Shared trunk: Linear(belief_dim, h) -> ReLU -> [Dropout] -> ... -> Linear(h, h) -> ReLU -> [Dropout]

  • Policy head: inherited from BetaZeroNetwork

  • Value head: inherited from BetaZeroNetwork

  • Failure head: Linear(h, h//2) -> ReLU -> Linear(h//2, 1)

The failure head outputs a raw logit. During predict(), sigmoid is applied to produce a failure probability in [0, 1].
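
A minimal PyTorch sketch of a failure head with the stated shape, where h stands for the trunk width (the last entry of hidden_sizes); this mirrors the architecture description rather than the exact module layout:

import torch.nn as nn

h = 128  # trunk width
failure_head = nn.Sequential(
    nn.Linear(h, h // 2),
    nn.ReLU(),
    nn.Linear(h // 2, 1),  # raw logit; sigmoid is applied in predict()
)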

Parameters:
  • belief_dim (int) – Dimensionality of the belief feature vector phi(b).

  • action_space_type (str) – "discrete" or "continuous".

  • n_actions (Optional[int]) – Number of discrete actions (required when action_space_type="discrete").

  • action_dim (Optional[int]) – Dimensionality of continuous actions (required when action_space_type="continuous").

  • hidden_sizes (Sequence[int]) – Tuple of hidden layer widths for the shared trunk.

  • use_dropout (bool) – If True, apply dropout after each ReLU in the shared trunk (default True).

  • p_dropout (float) – Dropout probability for trunk layers (default 0.2).

Example

>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network import ConstrainedZeroNetwork
>>> net = ConstrainedZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=3, use_dropout=False)
>>> policy, value, failure_prob = net.predict(np.zeros(4, dtype=np.float32))
>>> policy.shape
(3,)
>>> isinstance(value, float)
True
>>> 0.0 <= failure_prob <= 1.0
True
forward(belief_features)[source]

Forward pass returning raw policy, value, and failure logit.

Parameters:

belief_features (Tensor) – Tensor of shape (batch, belief_dim) or (belief_dim,).

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

Tuple of (policy_output, value, failure_logit) tensors.

predict(belief_features)[source]

Single-sample inference returning numpy policy, value, and failure probability.

Switches to eval mode before inference to disable dropout, then restores the original training mode.

Parameters:

belief_features (ndarray) – 1-D array of shape (belief_dim,).

Return type:

Tuple[ndarray, float, float]

Returns:

Tuple of (policy, value, failure_prob):

  • Discrete: policy is a probability vector summing to 1.

  • Continuous: policy is [mean, log_std].

  • value is a Python float.

  • failure_prob is a Python float in [0, 1].

predict_batch(belief_features_batch)[source]

Batched inference returning numpy policy, value, and failure probability arrays.

Switches to eval mode before inference to disable dropout, then restores the original training mode.

Parameters:

belief_features_batch (ndarray) – 2-D array of shape (N, belief_dim).

Return type:

Tuple[ndarray, ndarray, ndarray]

Returns:

Tuple of (policies, values, failure_probs):

  • Discrete: policies is an (N, n_actions) probability matrix.

  • Continuous: policies is (N, 2*action_dim) with [mean, log_std].

  • values is an (N,) array of floats.

  • failure_probs is an (N,) array of floats in [0, 1].
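
For completeness, a batched counterpart to the class example above; shapes follow the documented contract:

>>> import numpy as np
>>> from POMDPPlanners.planners.mcts_planners.constrained_zero.constrained_zero_network import ConstrainedZeroNetwork
>>> net = ConstrainedZeroNetwork(belief_dim=4, action_space_type="discrete", n_actions=3, use_dropout=False)
>>> policies, values, failure_probs = net.predict_batch(np.zeros((8, 4), dtype=np.float32))
>>> policies.shape, values.shape, failure_probs.shape
((8, 3), (8,), (8,))
>>> bool(np.all((failure_probs >= 0.0) & (failure_probs <= 1.0)))
True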