POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero package

Submodules

POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct module

Tests for the SPUCT action selection module.

This module tests the safety-constrained PUCT selection and progressive widening functions used by ConstrainedZero.

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.SimpleActionSampler[source]

Bases: ActionSampler

sample(belief_node=None)[source]

Sample a new action for progressive widening.

Parameters:

belief_node – Optional belief node context for informed sampling

Returns:

A sampled action compatible with the environment’s action space

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestComputeSafetyMask[source]

Bases: object

Tests for the _compute_safety_mask helper.

test_all_safe()[source]

Test mask when all actions are safe.

Purpose: Validates safety mask for all-safe case.

Given: Failure probabilities [0.05, 0.08] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1, 1].

Test type: unit

test_all_unsafe_fallback()[source]

Test mask when all actions are unsafe.

Purpose: Validates fallback to unconstrained when no safe action exists.

Given: Failure probabilities [0.5, 0.8] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1, 1] (fallback to unconstrained).

Test type: unit

test_exact_threshold()[source]

Test action at exactly the threshold is considered safe.

Purpose: Validates the boundary condition (f <= delta_prime).

Given: Failure probability [0.1] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1] (safe).

Test type: unit

test_some_unsafe()[source]

Test mask when some actions are unsafe.

Purpose: Validates that the safety mask excludes only the unsafe actions.

Given: Failure probabilities [0.05, 0.2] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1, 0].

Test type: unit
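The four cases above can be summarised in a minimal sketch. The function name `safety_mask_sketch` and its list-based signature are assumptions made for illustration; the real `_compute_safety_mask` helper may take arrays and different arguments.

```python
from typing import List, Sequence


def safety_mask_sketch(failure_probs: Sequence[float], delta_prime: float) -> List[int]:
    """Return 1 for actions whose failure probability is within the threshold.

    Hypothetical stand-in for _compute_safety_mask, written against the
    behaviour the tests describe rather than the library source.
    """
    # An action counts as safe when f <= delta_prime (the boundary is inclusive).
    mask = [1 if f <= delta_prime else 0 for f in failure_probs]
    # If no action is safe, fall back to the unconstrained case (all ones).
    if not any(mask):
        return [1] * len(mask)
    return mask


print(safety_mask_sketch([0.05, 0.2], 0.1))  # [1, 0]
```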

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestSpuctActionProgressiveWidening[source]

Bases: object

Tests for the spuct_action_progressive_widening function.

test_creates_new_action_on_empty_node()[source]

Test new action is created for a leaf belief node.

Purpose: Validates widening creates new action on first visit.

Given: A leaf belief node with no children. When: spuct_action_progressive_widening is called. Then: A new action node is created and returned.

Test type: unit

test_min_visit_count_at_root()[source]

Test that min_visit_count_per_action is enforced at the root.

Purpose: Validates that at depth 0, unvisited children are returned first.

Given: A root node with two children, one unvisited. When: spuct_action_progressive_widening is called with min_visit_count=1. Then: The unvisited child is returned.

Test type: unit

test_selects_via_spuct_when_not_widening()[source]

Test SPUCT selection when widening threshold not met.

Purpose: Validates SPUCT is used instead of creating new actions.

Given: A node with many children relative to k_a * N^alpha_a. When: spuct_action_progressive_widening is called. Then: An existing child is selected via SPUCT.

Test type: unit
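The widening decision these tests exercise compares the number of children against k_a * N^alpha_a. A minimal sketch of that check follows; the name `should_widen_sketch`, the argument order, and the use of a non-strict comparison are assumptions, and the library function may differ on each point.

```python
def should_widen_sketch(num_children: int, visit_count: int, k_a: float, alpha_a: float) -> bool:
    """Decide whether to sample a new action under progressive widening.

    Hypothetical helper: widen while the child count does not exceed
    k_a * N^alpha_a. Whether the real comparison is strict is an assumption.
    """
    return num_children <= k_a * visit_count ** alpha_a


# A leaf node with no children widens on its first visit.
print(should_widen_sketch(num_children=0, visit_count=0, k_a=1.0, alpha_a=0.5))
# A node with many children relative to k_a * N^alpha_a selects an existing child.
print(should_widen_sketch(num_children=10, visit_count=4, k_a=1.0, alpha_a=0.5))
```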

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestSpuctSelection[source]

Bases: object

Tests for the spuct_selection function.

test_all_unsafe_selects_best_q()[source]

Test all-unsafe fallback selects highest PUCT score.

Purpose: Validates unconstrained fallback when all actions unsafe.

Given: Two unsafe actions with Q=[1.0, 3.0] and many visits. When: spuct_selection is called. Then: The higher-Q action is selected (fallback to unconstrained).

Test type: unit

test_matches_puct_when_all_safe()[source]

Test SPUCT matches PUCT when all actions are safe.

Purpose: Validates SPUCT reduces to standard PUCT when safety mask is all ones.

Given: Two actions with Q=[1.0, 3.0], both safe. When: spuct_selection is called. Then: Selects the same action as standard PUCT (higher Q with many visits).

Test type: unit

test_safe_action_preferred_over_unsafe()[source]

Test SPUCT prefers safe actions over unsafe ones.

Purpose: Validates that unsafe actions are masked out.

Given: Two actions with equal Q/visits but action 0 is safe, action 1 is unsafe. When: spuct_selection is called. Then: The safe action (a_0) is selected.

Test type: unit

test_uniform_priors_when_none()[source]

Test SPUCT uses uniform priors when action_priors is None.

Purpose: Validates default prior behavior.

Given: Three actions with different Q-values, no explicit priors. When: spuct_selection is called with action_priors=None. Then: Returns a valid action (no error).

Test type: unit
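The selection behaviour described by these tests can be sketched as a PUCT argmax restricted to safe actions. Everything here is an assumption made for illustration: the function name, the flat-list arguments, the exploration constant `c`, and the exact PUCT bonus (the library may normalise Q-values or use a different exploration term).

```python
import math
from typing import Optional, Sequence


def spuct_select_sketch(
    q_values: Sequence[float],
    visit_counts: Sequence[int],
    failure_probs: Sequence[float],
    delta_prime: float,
    action_priors: Optional[Sequence[float]] = None,
    c: float = 1.0,
) -> int:
    """Return the index of the best safe action under a PUCT-style score."""
    n_actions = len(q_values)
    # Uniform priors when none are supplied.
    if action_priors is None:
        action_priors = [1.0 / n_actions] * n_actions
    # Safety mask with fallback to unconstrained selection if nothing is safe.
    safe = [f <= delta_prime for f in failure_probs]
    if not any(safe):
        safe = [True] * n_actions
    total_visits = sum(visit_counts)
    best_index, best_score = -1, -math.inf
    for i in range(n_actions):
        if not safe[i]:
            continue  # Unsafe actions are masked out of the argmax.
        bonus = c * action_priors[i] * math.sqrt(total_visits) / (1 + visit_counts[i])
        score = q_values[i] + bonus
        if score > best_score:
            best_index, best_score = i, score
    return best_index
```

With equal Q-values and visit counts, a call such as `spuct_select_sketch([1.0, 1.0], [10, 10], [0.05, 0.5], 0.1)` returns index 0, because action 1 exceeds the threshold and is excluded.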

POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training module

Tests for the ConstrainedZero training utilities.

This module tests the loss function and training loop for the 3-head ConstrainedZero network, including the additional failure BCE loss.

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training.TestComputeConstrainedZeroLoss[source]

Bases: object

Tests for the compute_constrained_zero_loss function.

test_bce_correct_for_known_inputs()[source]

Test BCE loss matches manual computation for extreme inputs.

Purpose: Validates the failure loss is proper BCE.

Given: A network with targets of all zeros. When: Loss is computed. Then: failure_loss equals BCE(logit, 0), which is log(1 + exp(logit)).

Test type: unit
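The identity this test relies on can be checked numerically. The sketch below uses a scalar, standard-library version of binary cross-entropy on logits; `bce_with_logits_sketch` is a hypothetical name, and the library presumably computes the same quantity on tensors.

```python
import math


def bce_with_logits_sketch(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    # max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# With target 0 the loss reduces to log(1 + exp(logit)).
for z in (-3.0, 0.0, 2.0):
    direct = math.log(1.0 + math.exp(z))
    print(z, bce_with_logits_sketch(z, 0.0), direct)
```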

test_continuous_loss_has_failure_component()[source]

Test failure loss works with continuous action space.

Purpose: Validates loss function for continuous networks.

Given: A continuous ConstrainedZeroNetwork. When: Loss is computed with failure targets. Then: All three loss components are present and positive.

Test type: unit

test_loss_has_failure_component()[source]

Test that failure loss contributes to total.

Purpose: Validates BCE failure loss is non-zero and part of total.

Given: A network and batch with known failure targets. When: Loss is computed. Then: failure_loss > 0 and total >= value + policy + failure.

Test type: unit

test_returns_total_loss_and_components()[source]

Test loss function returns total and component dict.

Purpose: Validates output structure of the loss function.

Given: A discrete ConstrainedZeroNetwork and random batch data. When: compute_constrained_zero_loss is called. Then: Returns (tensor, dict) with value_loss, policy_loss, failure_loss keys.

Test type: unit
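The return structure checked above, a total plus a dict of named components, can be sketched for a single example. The choice of squared error for the value head, cross-entropy for the policy head, and an unweighted sum are assumptions; only the three component keys and the BCE failure term come from the test descriptions, and the real function operates on a network and a batch.

```python
import math
from typing import Dict, Sequence, Tuple


def constrained_loss_sketch(
    value_pred: float,
    value_target: float,
    policy_log_probs: Sequence[float],
    policy_target: Sequence[float],
    failure_logit: float,
    failure_target: float,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical single-example version of a three-head loss."""
    # Assumed value loss: squared error.
    value_loss = (value_pred - value_target) ** 2
    # Assumed policy loss: cross-entropy against the target distribution.
    policy_loss = -sum(t * lp for t, lp in zip(policy_target, policy_log_probs))
    # Failure loss: binary cross-entropy on the raw logit (stable form).
    failure_loss = (
        max(failure_logit, 0.0)
        - failure_logit * failure_target
        + math.log1p(math.exp(-abs(failure_logit)))
    )
    components = {
        "value_loss": value_loss,
        "policy_loss": policy_loss,
        "failure_loss": failure_loss,
    }
    # Assumed unweighted sum; the library may add weights or regularisation.
    return sum(components.values()), components


total, parts = constrained_loss_sketch(
    0.5, 1.0, [math.log(0.7), math.log(0.3)], [1.0, 0.0], 0.0, 1.0
)
print(total, parts)
```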

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training.TestTrainConstrainedNetwork[source]

Bases: object

Tests for the train_constrained_network function.

test_all_metrics_present()[source]

Test training returns all four metric keys.

Purpose: Validates all expected metrics are returned.

Given: A network and buffer. When: Training completes. Then: Metrics dict has total_loss, value_loss, policy_loss, failure_loss.

Test type: unit

test_returns_failure_loss_metric()[source]

Test that training metrics include failure_loss.

Purpose: Validates failure_loss is tracked during training.

Given: A network and buffer with examples. When: Training is run. Then: metrics dict contains “failure_loss” key with per-epoch values.

Test type: unit

test_track_gradients_adds_norm_keys()[source]

Test track_gradients=True adds gradient and weight norm keys.

Purpose: Validates that enabling track_gradients returns all five gradient norm metrics plus the global weight norm.

Given: A discrete network and buffer with 20 examples. When: train_constrained_network is called with track_gradients=True. Then: Metrics dict contains all _CONSTRAINED_GRAD_NORM_KEYS and ‘weight_norm/global’, each with one value per epoch.

Test type: unit

test_track_gradients_false_excludes_norm_keys()[source]

Test track_gradients=False returns only the four base loss keys.

Purpose: Validates that the default (track_gradients=False) behaviour is unchanged: no extra keys are added to the metrics dict.

Given: A discrete network and buffer. When: train_constrained_network is called with track_gradients=False. Then: Metrics dict contains exactly the four base loss keys.

Test type: unit

test_track_gradients_includes_failure_head()[source]

Test that grad_norm/failure_head is present when track_gradients=True.

Purpose: Validates that the failure_head gradient norm is tracked separately; this head exists in ConstrainedZero but not in BetaZero.

Given: A discrete network and buffer. When: train_constrained_network is called with track_gradients=True. Then: ‘grad_norm/failure_head’ is in the metrics and is non-negative.

Test type: unit

test_training_reduces_loss()[source]

Test that training reduces total loss over epochs.

Purpose: Validates the training loop is functional.

Given: A discrete network and a buffer with 50 examples. When: Training is run for 10 epochs. Then: The final total_loss is less than the initial total_loss.

Test type: integration
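The metrics layout these tests check, four base loss keys with one value per epoch plus optional norm keys, can be sketched as follows. `collect_metrics_sketch` and its inputs are hypothetical; the only key names taken from the test descriptions are the four loss keys and ‘grad_norm/failure_head’, and the contents of _CONSTRAINED_GRAD_NORM_KEYS are not reproduced here.

```python
import math
from typing import Dict, List, Optional, Sequence

BASE_KEYS = ("total_loss", "value_loss", "policy_loss", "failure_loss")


def l2_norm(values: Sequence[float]) -> float:
    """Euclidean norm of a flat sequence of gradient entries."""
    return math.sqrt(sum(v * v for v in values))


def collect_metrics_sketch(
    epoch_losses: Sequence[Dict[str, float]],
    failure_head_grads: Optional[Sequence[Sequence[float]]] = None,
) -> Dict[str, List[float]]:
    """Accumulate per-epoch values into a dict of lists (hypothetical helper)."""
    metrics: Dict[str, List[float]] = {key: [] for key in BASE_KEYS}
    for epoch, losses in enumerate(epoch_losses):
        for key in BASE_KEYS:
            metrics[key].append(losses[key])
        # Norm keys appear only when gradient tracking is requested.
        if failure_head_grads is not None:
            norm = l2_norm(failure_head_grads[epoch])
            metrics.setdefault("grad_norm/failure_head", []).append(norm)
    return metrics
```

Without `failure_head_grads` the result holds exactly the four base keys, which mirrors the track_gradients=False case.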

POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer module

Tests for the ConstrainedTrainingBuffer module.

This module tests the ConstrainedTrainingBuffer and ConstrainedTrainingExample, which extend the BetaZero training buffer with failure targets.

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer.TestConstrainedTrainingBuffer[source]

Bases: object

Tests for the ConstrainedTrainingBuffer class.

test_add_and_length()[source]

Test adding examples and checking buffer length.

Purpose: Validates basic add/length functionality.

Given: An empty ConstrainedTrainingBuffer after begin_iteration(). When: Adding 3 examples. Then: Length is 3.

Test type: unit

test_begin_iteration_discards_old_data_with_n_buffer_1()[source]

Test begin_iteration evicts previous iteration’s data when n_buffer=1.

Purpose: Validates on-policy behaviour: with n_buffer=1 only the current iteration’s data is retained after begin_iteration() is called.

Given: A ConstrainedTrainingBuffer(n_buffer=1) with 5 examples from iter 0. When: begin_iteration() is called and 2 new examples are added. Then: Buffer length equals 2 (only the new iteration’s examples).

Test type: unit

test_clear()[source]

Test clear empties the buffer.

Purpose: Validates clear() works for the constrained buffer.

Given: A buffer with 5 examples. When: clear() is called. Then: Buffer length is 0.

Test type: unit

test_failure_targets_propagate()[source]

Test failure targets are correctly stored and retrieved.

Purpose: Validates failure targets survive the buffer roundtrip.

Given: A buffer with examples having known failure targets (all 1.0). When: sample_batch is called. Then: All failure targets in the batch are 1.0.

Test type: unit

test_inherits_from_training_buffer()[source]

Test ConstrainedTrainingBuffer inherits from TrainingBuffer.

Purpose: Validates the subclass relationship.

Given: A ConstrainedTrainingBuffer instance. When: Checking isinstance. Then: It is an instance of TrainingBuffer.

Test type: unit

test_mixed_failure_targets()[source]

Test buffer handles mixed failure targets correctly.

Purpose: Validates buffer stores both 0.0 and 1.0 failure targets.

Given: A buffer with alternating failure targets. When: All examples are sampled. Then: The failure targets contain both 0.0 and 1.0.

Test type: unit

test_sample_batch_returns_four_arrays()[source]

Test sample_batch returns 4-tuple with failure targets.

Purpose: Validates the 4-tuple output from sample_batch.

Given: A buffer with 10 examples (belief_dim=4, policy_dim=3). When: sample_batch(5) is called. Then: Returns (beliefs, policies, values, failures) with correct shapes.

Test type: unit
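The iteration-based eviction and the 4-tuple batch described above can be sketched with a bounded deque of per-iteration lists. The class name `IterationBufferSketch`, the `add` signature taking four separate values, and the use of plain lists instead of arrays are all assumptions; the real ConstrainedTrainingBuffer stores ConstrainedTrainingExample objects and may sample differently.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

Example = Tuple[List[float], List[float], float, float]


class IterationBufferSketch:
    """Hypothetical buffer that keeps examples from the last n_buffer iterations."""

    def __init__(self, n_buffer: int = 1) -> None:
        # A bounded deque drops the oldest iteration automatically.
        self._iterations: Deque[List[Example]] = deque(maxlen=n_buffer)

    def begin_iteration(self) -> None:
        # Must be called before add(); starts a fresh bucket.
        self._iterations.append([])

    def add(self, belief: List[float], policy: List[float], value: float, failure: float) -> None:
        self._iterations[-1].append((belief, policy, value, failure))

    def __len__(self) -> int:
        return sum(len(bucket) for bucket in self._iterations)

    def clear(self) -> None:
        self._iterations.clear()

    def sample_batch(self, batch_size: int):
        pool = [ex for bucket in self._iterations for ex in bucket]
        batch = random.sample(pool, min(batch_size, len(pool)))
        beliefs = [ex[0] for ex in batch]
        policies = [ex[1] for ex in batch]
        values = [ex[2] for ex in batch]
        failures = [ex[3] for ex in batch]
        return beliefs, policies, values, failures
```

With `n_buffer=1`, adding five examples, calling `begin_iteration()`, and adding two more leaves a length of 2, matching the on-policy case in the tests.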

class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer.TestConstrainedTrainingExample[source]

Bases: object

Tests for the ConstrainedTrainingExample dataclass.

test_has_failure_target_field()[source]

Test ConstrainedTrainingExample has failure_target.

Purpose: Validates the dataclass includes the failure field.

Given: A ConstrainedTrainingExample with failure_target=0.5. When: Accessing the failure_target attribute. Then: The value is 0.5.

Test type: unit

test_inherits_standard_fields()[source]

Test ConstrainedTrainingExample has standard BetaZero fields.

Purpose: Validates belief_features, policy_target, and value_target exist.

Given: A ConstrainedTrainingExample. When: Accessing standard fields. Then: All fields are accessible with correct values.

Test type: unit
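The field layout these tests check can be sketched with two dataclasses. The class names carry a `Sketch` suffix because they are stand-ins, and modelling the relationship as inheritance is an assumption; the field names belief_features, policy_target, value_target, and failure_target come from the test descriptions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TrainingExampleSketch:
    """Stand-in for the standard BetaZero training example."""

    belief_features: List[float]
    policy_target: List[float]
    value_target: float


@dataclass
class ConstrainedTrainingExampleSketch(TrainingExampleSketch):
    """Adds the failure target used by the third network head."""

    failure_target: float


example = ConstrainedTrainingExampleSketch(
    belief_features=[0.1, 0.2],
    policy_target=[0.5, 0.5],
    value_target=1.0,
    failure_target=0.5,
)
print(example.failure_target)  # 0.5
```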

POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_zero module

POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_zero_network module