POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero package
Submodules
POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct module
Tests for the SPUCT action selection module.
This module tests the safety-constrained PUCT selection and progressive widening functions used by ConstrainedZero.
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.SimpleActionSampler
Bases: ActionSampler
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestComputeSafetyMask
Bases: object
Tests for the _compute_safety_mask helper.
- test_all_safe()
Test the mask when all actions are safe.
Purpose: Validates the safety mask for the all-safe case.
Given: Failure probabilities [0.05, 0.08] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1, 1].
Test type: unit
- test_all_unsafe_fallback()
Test the mask when all actions are unsafe.
Purpose: Validates the fallback to unconstrained selection when no safe action exists.
Given: Failure probabilities [0.5, 0.8] and delta_prime=0.1. When: _compute_safety_mask is called. Then: Mask is [1, 1] (fallback to unconstrained).
Test type: unit
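The two cases above pin down the contract of _compute_safety_mask: actions within the failure budget delta_prime are marked safe, and an all-zero mask falls back to all ones. The following is a minimal plain-Python sketch of that behaviour, not the package's implementation; the name compute_safety_mask and the inclusive comparison (p <= delta_prime rather than p < delta_prime) are assumptions, since neither test exercises the boundary.

```python
from typing import List, Sequence


def compute_safety_mask(failure_probs: Sequence[float], delta_prime: float) -> List[int]:
    """Hypothetical stand-in for _compute_safety_mask.

    Marks an action as safe (1) when its estimated failure probability is at
    most delta_prime; if no action is safe, returns all ones so that selection
    falls back to unconstrained PUCT.
    """
    # Assumption: the threshold is inclusive (p <= delta_prime).
    mask = [1 if p <= delta_prime else 0 for p in failure_probs]
    if not any(mask):
        # No safe action exists: fall back to the unconstrained mask.
        return [1] * len(failure_probs)
    return mask


print(compute_safety_mask([0.05, 0.08], 0.1))  # [1, 1]  (all safe)
print(compute_safety_mask([0.5, 0.8], 0.1))    # [1, 1]  (fallback)
print(compute_safety_mask([0.05, 0.5], 0.1))   # [1, 0]  (mixed)
```

The mixed case is not covered by the two tests above; it is included only to show what the mask looks like when the constraint actually bites.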
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestSpuctActionProgressiveWidening
Bases: object
Tests for the spuct_action_progressive_widening function.
- test_creates_new_action_on_empty_node()
Test that a new action is created for a leaf belief node.
Purpose: Validates that widening creates a new action on the first visit.
Given: A leaf belief node with no children. When: spuct_action_progressive_widening is called. Then: A new action node is created and returned.
Test type: unit
- test_min_visit_count_at_root()
Test that min_visit_count_per_action is enforced at the root.
Purpose: Validates that at depth 0, unvisited children are returned first.
Given: A root node with two children, one unvisited. When: spuct_action_progressive_widening is called with min_visit_count=1. Then: The unvisited child is returned.
Test type: unit
- test_selects_via_spuct_when_not_widening()
Test SPUCT selection when the widening threshold is not met.
Purpose: Validates that SPUCT selection is used instead of creating a new action.
Given: A node with many children relative to k_a * N^alpha_a. When: spuct_action_progressive_widening is called. Then: An existing child is selected via SPUCT.
Test type: unit
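The three tests above describe a common shape for action progressive widening: enforce a minimum visit count at the root, add a new action while the child count is within the budget k_a * N^alpha_a, and otherwise select among existing children. The sketch below illustrates that control flow only. The node classes, the function name, the inclusive <= widening test, and the order of the checks are assumptions rather than the package's code, and the select_existing callback is a hypothetical placeholder for SPUCT.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ActionNode:  # hypothetical action node
    action: float
    visits: int = 0


@dataclass
class BeliefNode:  # hypothetical belief node
    depth: int = 0
    visits: int = 0
    children: List[ActionNode] = field(default_factory=list)


def action_progressive_widening(
    node: BeliefNode,
    sample_action: Callable[[], float],
    select_existing: Callable[[BeliefNode], ActionNode],
    k_a: float = 1.0,
    alpha_a: float = 0.5,
    min_visit_count: int = 0,
) -> ActionNode:
    # 1. At the root, return any child that has not reached min_visit_count.
    if node.depth == 0:
        for child in node.children:
            if child.visits < min_visit_count:
                return child
    # 2. Widen while |children| <= k_a * N^alpha_a (a leaf with N=0 always widens).
    if len(node.children) <= k_a * node.visits ** alpha_a:
        child = ActionNode(action=sample_action())
        node.children.append(child)
        return child
    # 3. Otherwise pick among existing children (SPUCT in ConstrainedZero).
    return select_existing(node)


def most_visited(node: BeliefNode) -> ActionNode:
    # Placeholder selection rule used only for this demo.
    return max(node.children, key=lambda c: c.visits)


leaf = BeliefNode(depth=1)
new_child = action_progressive_widening(leaf, lambda: 0.3, most_visited)
print(len(leaf.children), new_child.action)  # 1 0.3
```

Checking the root visit requirement before the widening test means an under-visited existing action is revisited even when the budget would allow a new one; whether the package orders these checks the same way is not stated in the tests.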
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_puct.TestSpuctSelection
Bases: object
Tests for the spuct_selection function.
- test_all_unsafe_selects_best_q()
Test that the all-unsafe fallback selects the action with the highest PUCT score.
Purpose: Validates the unconstrained fallback when all actions are unsafe.
Given: Two unsafe actions with Q=[1.0, 3.0] and many visits. When: spuct_selection is called. Then: The higher-Q action is selected (fallback to unconstrained).
Test type: unit
- test_matches_puct_when_all_safe()
Test SPUCT matches PUCT when all actions are safe.
Purpose: Validates SPUCT reduces to standard PUCT when safety mask is all ones.
Given: Two actions with Q=[1.0, 3.0], both safe. When: spuct_selection is called. Then: Selects the same action as standard PUCT (higher Q with many visits).
Test type: unit
- test_safe_action_preferred_over_unsafe()
Test SPUCT prefers safe actions over unsafe ones.
Purpose: Validates that unsafe actions are masked out.
Given: Two actions with equal Q/visits but action 0 is safe, action 1 is unsafe. When: spuct_selection is called. Then: The safe action (a_0) is selected.
Test type: unit
- test_uniform_priors_when_none()
Test SPUCT uses uniform priors when action_priors is None.
Purpose: Validates default prior behavior.
Given: Three actions with different Q-values, no explicit priors. When: spuct_selection is called with action_priors=None. Then: Returns a valid action (no error).
Test type: unit
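Taken together, these tests describe SPUCT as PUCT restricted to the actions allowed by the safety mask, with uniform priors by default and an unconstrained fallback. A minimal sketch, assuming the standard PUCT score Q(b,a) + c * P(b,a) * sqrt(N(b)) / (1 + N(b,a)), an inclusive safety threshold, and N(b) taken as the sum of child visit counts; the function name spuct_select, its argument layout, and the exploration constant c are illustrative, not the package's signature.

```python
import math
from typing import List, Optional, Sequence


def spuct_select(
    q: Sequence[float],
    visits: Sequence[int],
    failure_probs: Sequence[float],
    delta_prime: float,
    c: float = 1.0,
    priors: Optional[Sequence[float]] = None,
) -> int:
    """Return the index of the masked-PUCT maximiser (hypothetical helper)."""
    n_actions = len(q)
    if priors is None:
        priors = [1.0 / n_actions] * n_actions  # uniform priors by default
    mask: List[int] = [1 if p <= delta_prime else 0 for p in failure_probs]
    if not any(mask):
        mask = [1] * n_actions  # no safe action: unconstrained fallback
    # Assumption: N(b) is the sum of child visit counts.
    total_visits = sum(visits)
    best_index, best_score = -1, -math.inf
    for a in range(n_actions):
        if not mask[a]:
            continue  # skip unsafe actions instead of zeroing their score
        score = q[a] + c * priors[a] * math.sqrt(total_visits) / (1 + visits[a])
        if score > best_score:
            best_index, best_score = a, score
    return best_index


print(spuct_select([1.0, 3.0], [100, 100], [0.0, 0.0], 0.1))  # 1 (plain PUCT)
print(spuct_select([1.0, 3.0], [100, 100], [0.9, 0.9], 0.1))  # 1 (fallback)
print(spuct_select([2.0, 2.0], [10, 10], [0.0, 0.9], 0.1))    # 0 (safe wins)
```

Skipping masked actions, rather than multiplying their score by the mask, matters when Q-values can be negative: a zeroed-out unsafe action would otherwise outrank a safe action whose score is below zero.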
POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training module
Tests for the ConstrainedZero training utilities.
This module tests the loss function and training loop for the 3-head ConstrainedZero network, including the additional failure BCE loss.
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training.TestComputeConstrainedZeroLoss
Bases: object
Tests for the compute_constrained_zero_loss function.
- test_bce_correct_for_known_inputs()
Test BCE loss matches manual computation for extreme inputs.
Purpose: Validates the failure loss is proper BCE.
Given: A network and a batch whose failure targets are all zeros. When: Loss is computed. Then: failure_loss equals BCE(logit, 0), which is log(1 + exp(logit)).
Test type: unit
- test_continuous_loss_has_failure_component()
Test that the failure loss works with a continuous action space.
Purpose: Validates the loss function for continuous-action networks.
Given: A continuous ConstrainedZeroNetwork. When: Loss is computed with failure targets. Then: All three loss components are present and positive.
Test type: unit
- test_loss_has_failure_component()
Test that the failure loss contributes to the total loss.
Purpose: Validates that the BCE failure loss is non-zero and included in the total.
Given: A network and batch with known failure targets. When: Loss is computed. Then: failure_loss > 0 and total >= value + policy + failure.
Test type: unit
- test_returns_total_loss_and_components()
Test that the loss function returns the total loss and a dict of its components.
Purpose: Validates the output structure of the loss function.
Given: A discrete ConstrainedZeroNetwork and random batch data. When: compute_constrained_zero_loss is called. Then: Returns (tensor, dict) with value_loss, policy_loss, failure_loss keys.
Test type: unit
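The loss tests above rely on the failure term being binary cross-entropy on a logit, which for a zero target reduces to softplus: -log(1 - sigmoid(z)) = log(1 + exp(z)). The sketch below computes a three-part loss in plain Python to make that identity and the (total, components) return shape concrete. Mean squared error for the value head, soft-target cross-entropy for the policy head, and an unweighted sum with no regularization are assumptions; the suite only asserts total >= value + policy + failure, which leaves room for extra terms.

```python
import math
from typing import Dict, List, Sequence, Tuple


def bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable BCE: max(z, 0) - z * y + log(1 + exp(-|z|)).
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating to avoid overflow.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def constrained_loss(
    value_pred: Sequence[float],
    value_target: Sequence[float],
    policy_logits: Sequence[Sequence[float]],
    policy_target: Sequence[Sequence[float]],
    failure_logit: Sequence[float],
    failure_target: Sequence[float],
) -> Tuple[float, Dict[str, float]]:
    n = len(value_pred)
    value_loss = sum((v - t) ** 2 for v, t in zip(value_pred, value_target)) / n
    policy_loss = -sum(
        sum(p * lp for p, lp in zip(pi, log_softmax(logits)))
        for logits, pi in zip(policy_logits, policy_target)
    ) / n
    failure_loss = sum(bce_with_logits(z, y) for z, y in zip(failure_logit, failure_target)) / n
    total = value_loss + policy_loss + failure_loss
    return total, {"value_loss": value_loss, "policy_loss": policy_loss, "failure_loss": failure_loss}


# With a zero failure target, BCE equals softplus(logit).
print(math.isclose(bce_with_logits(2.0, 0.0), math.log(1 + math.exp(2.0))))  # True
```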
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training.TestTrainConstrainedNetwork
Bases: object
Tests for the train_constrained_network function.
- test_all_metrics_present()
Test that training returns all four metric keys.
Purpose: Validates all expected metrics are returned.
Given: A network and buffer. When: Training completes. Then: Metrics dict has total_loss, value_loss, policy_loss, failure_loss.
Test type: unit
- test_returns_failure_loss_metric()
Test that training metrics include failure_loss.
Purpose: Validates failure_loss is tracked during training.
Given: A network and buffer with examples. When: Training is run. Then: metrics dict contains “failure_loss” key with per-epoch values.
Test type: unit
- test_track_gradients_adds_norm_keys()
Test track_gradients=True adds gradient and weight norm keys.
Purpose: Validates that enabling track_gradients returns all five gradient norm metrics plus the global weight norm.
Given: A discrete network and buffer with 20 examples. When: train_constrained_network is called with track_gradients=True. Then: Metrics dict contains all _CONSTRAINED_GRAD_NORM_KEYS and ‘weight_norm/global’, each with one value per epoch.
Test type: unit
- test_track_gradients_false_excludes_norm_keys()
Test track_gradients=False returns only the four base loss keys.
Purpose: Validates that the default (track_gradients=False) behaviour is unchanged — no extra keys are added to the metrics dict.
Given: A discrete network and buffer. When: train_constrained_network is called with track_gradients=False. Then: Metrics dict contains exactly the four base loss keys.
Test type: unit
- test_track_gradients_includes_failure_head()
Test that grad_norm/failure_head is present when track_gradients=True.
Purpose: Validates that the failure_head gradient norm is tracked separately; the failure head is specific to ConstrainedZero and absent from BetaZero.
Given: A discrete network and buffer. When: train_constrained_network is called with track_gradients=True. Then: ‘grad_norm/failure_head’ is in the metrics and is non-negative.
Test type: unit
- test_training_reduces_loss()
Test that training reduces total loss over epochs.
Purpose: Validates the training loop is functional.
Given: A discrete network and a buffer with 50 examples. When: Training is run for 10 epochs. Then: The final total_loss is less than the initial total_loss.
Test type: integration
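The training-loop tests mostly constrain the shape of the returned metrics: four base loss keys, one list entry per epoch, and, only when track_gradients=True, five gradient-norm keys plus ‘weight_norm/global’. A sketch of that bookkeeping follows. Apart from ‘grad_norm/failure_head’ and ‘weight_norm/global’, the gradient-norm key names and the per-head grouping are hypothetical, since the suite refers to them only through _CONSTRAINED_GRAD_NORM_KEYS; using L2 norms is also an assumption.

```python
import math
from typing import Dict, List, Optional, Sequence

BASE_LOSS_KEYS = ("total_loss", "value_loss", "policy_loss", "failure_loss")
# Hypothetical names for the five gradient-norm metrics; only
# "grad_norm/failure_head" is confirmed by the test descriptions.
HEAD_GROUPS = ("trunk", "value_head", "policy_head", "failure_head")
GRAD_NORM_KEYS = ("grad_norm/global",) + tuple(f"grad_norm/{g}" for g in HEAD_GROUPS)


def l2_norm(values: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def init_metrics(track_gradients: bool) -> Dict[str, List[float]]:
    keys = list(BASE_LOSS_KEYS)
    if track_gradients:
        keys += list(GRAD_NORM_KEYS) + ["weight_norm/global"]
    return {key: [] for key in keys}


def record_epoch(
    metrics: Dict[str, List[float]],
    losses: Dict[str, float],
    grads_by_group: Optional[Dict[str, Sequence[float]]] = None,
    weights: Optional[Sequence[float]] = None,
) -> None:
    for key in BASE_LOSS_KEYS:
        metrics[key].append(losses[key])
    if "weight_norm/global" in metrics:  # only present when track_gradients=True
        all_grads = [g for group in HEAD_GROUPS for g in grads_by_group[group]]
        metrics["grad_norm/global"].append(l2_norm(all_grads))
        for group in HEAD_GROUPS:
            metrics[f"grad_norm/{group}"].append(l2_norm(grads_by_group[group]))
        metrics["weight_norm/global"].append(l2_norm(weights))


metrics = init_metrics(track_gradients=True)
grads = {"trunk": [0.1, -0.2], "value_head": [0.3], "policy_head": [0.0], "failure_head": [0.4]}
for epoch in range(3):
    losses = {key: 1.0 / (epoch + 1) for key in BASE_LOSS_KEYS}  # placeholder values
    record_epoch(metrics, losses, grads, weights=[1.0, 2.0])
print(len(metrics))  # 10
```

Building the key set once in init_metrics keeps the track_gradients=False path byte-for-byte identical to the base four keys, which is exactly what test_track_gradients_false_excludes_norm_keys checks.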
POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer module
Tests for the ConstrainedTrainingBuffer module.
This module tests the ConstrainedTrainingBuffer and ConstrainedTrainingExample, which extend the BetaZero training buffer with failure targets.
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer.TestConstrainedTrainingBuffer
Bases: object
Tests for the ConstrainedTrainingBuffer class.
- test_add_and_length()
Test adding examples and checking buffer length.
Purpose: Validates basic add/length functionality.
Given: An empty ConstrainedTrainingBuffer after begin_iteration(). When: Adding 3 examples. Then: Length is 3.
Test type: unit
- test_begin_iteration_discards_old_data_with_n_buffer_1()
Test begin_iteration evicts previous iteration’s data when n_buffer=1.
Purpose: Validates on-policy behaviour: with n_buffer=1 only the current iteration’s data is retained after begin_iteration() is called.
Given: A ConstrainedTrainingBuffer(n_buffer=1) with 5 examples from iteration 0. When: begin_iteration() is called and 2 new examples are added. Then: Buffer length equals 2 (only the new iteration’s examples).
Test type: unit
- test_clear()
Test that clear() empties the buffer.
Purpose: Validates clear() works for the constrained buffer.
Given: A buffer with 5 examples. When: clear() is called. Then: Buffer length is 0.
Test type: unit
- test_failure_targets_propagate()
Test failure targets are correctly stored and retrieved.
Purpose: Validates failure targets survive the buffer roundtrip.
Given: A buffer with examples having known failure targets (all 1.0). When: sample_batch is called. Then: All failure targets in the batch are 1.0.
Test type: unit
- test_inherits_from_training_buffer()
Test ConstrainedTrainingBuffer inherits from TrainingBuffer.
Purpose: Validates the subclass relationship.
Given: A ConstrainedTrainingBuffer instance. When: Checking isinstance. Then: It is an instance of TrainingBuffer.
Test type: unit
- test_mixed_failure_targets()
Test buffer handles mixed failure targets correctly.
Purpose: Validates buffer stores both 0.0 and 1.0 failure targets.
Given: A buffer with alternating failure targets. When: All examples are sampled. Then: The failure targets contain both 0.0 and 1.0.
Test type: unit
- test_sample_batch_returns_four_arrays()
Test that sample_batch returns a 4-tuple that includes failure targets.
Purpose: Validates the 4-tuple output from sample_batch.
Given: A buffer with 10 examples (belief_dim=4, policy_dim=3). When: sample_batch(5) is called. Then: Returns (beliefs, policies, values, failures) with correct shapes.
Test type: unit
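The buffer tests imply a sliding window over training iterations: begin_iteration() opens a new slot, n_buffer bounds how many iterations are retained, and sample_batch returns beliefs, policies, values, and failure targets as four parallel collections. A minimal sketch under those assumptions follows, using collections.deque(maxlen=n_buffer) for eviction and plain lists in place of arrays; the class name SlidingWindowBuffer and the add(...) signature are hypothetical.

```python
import random
from collections import deque
from typing import Deque, List, Optional, Sequence, Tuple

Example = Tuple[List[float], List[float], float, float]  # belief, policy, value, failure


class SlidingWindowBuffer:
    """Hypothetical buffer keeping the examples of the last n_buffer iterations."""

    def __init__(self, n_buffer: int = 1) -> None:
        self._iterations: Deque[List[Example]] = deque(maxlen=n_buffer)

    def begin_iteration(self) -> None:
        # Appending to a full deque silently drops the oldest iteration.
        self._iterations.append([])

    def add(self, belief: Sequence[float], policy: Sequence[float],
            value: float, failure: float) -> None:
        # Requires begin_iteration() to have been called first.
        self._iterations[-1].append((list(belief), list(policy), float(value), float(failure)))

    def clear(self) -> None:
        self._iterations.clear()

    def __len__(self) -> int:
        return sum(len(examples) for examples in self._iterations)

    def sample_batch(self, batch_size: int, rng: Optional[random.Random] = None):
        if rng is None:
            rng = random.Random()
        pool = [ex for examples in self._iterations for ex in examples]
        batch = rng.sample(pool, min(batch_size, len(pool)))
        # Transpose the list of 4-tuples into four parallel lists.
        beliefs, policies, values, failures = (list(column) for column in zip(*batch))
        return beliefs, policies, values, failures


buf = SlidingWindowBuffer(n_buffer=1)
buf.begin_iteration()
for i in range(5):
    buf.add([0.0] * 4, [1 / 3] * 3, 0.0, float(i % 2))
buf.begin_iteration()  # evicts iteration 0 because n_buffer=1
for _ in range(2):
    buf.add([1.0] * 4, [1 / 3] * 3, 1.0, 1.0)
print(len(buf))  # 2
```

Storing one list per iteration, rather than a flat deque of examples, is what lets n_buffer count iterations instead of individual examples.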
- class POMDPPlanners.tests.test_planners.test_mcts_planners.test_constrained_zero.test_constrained_training_buffer.TestConstrainedTrainingExample
Bases: object
Tests for the ConstrainedTrainingExample dataclass.
- test_has_failure_target_field()
Test that ConstrainedTrainingExample has a failure_target field.
Purpose: Validates the dataclass includes the failure field.
Given: A ConstrainedTrainingExample with failure_target=0.5. When: Accessing the failure_target attribute. Then: The value is 0.5.
Test type: unit
- test_inherits_standard_fields()
Test that ConstrainedTrainingExample has the standard BetaZero fields.
Purpose: Validates belief_features, policy_target, and value_target exist.
Given: A ConstrainedTrainingExample. When: Accessing standard fields. Then: All fields are accessible with correct values.
Test type: unit
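The two dataclass tests describe ConstrainedTrainingExample as BetaZero's training example plus one extra field. A minimal sketch of that relationship with standard-library dataclasses; the base class name TrainingExample and the field types are assumptions inferred from the field names used in the tests.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class TrainingExample:  # hypothetical BetaZero-style example
    belief_features: List[float]
    policy_target: List[float]
    value_target: float


@dataclass
class ConstrainedTrainingExample(TrainingExample):
    failure_target: float  # target in [0, 1] for the failure head's BCE loss


example = ConstrainedTrainingExample([0.1, 0.2], [0.5, 0.5], 1.0, failure_target=0.5)
print([f.name for f in fields(example)])
# ['belief_features', 'policy_target', 'value_target', 'failure_target']
```

Because the new field follows the inherited ones and has no default, it stays a required argument without violating the dataclass rule that fields without defaults cannot follow fields with defaults.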