# Source code for POMDPPlanners.core.policy

"""Module for POMDP policy abstractions and execution tracking.

This module provides the foundational interface for POMDP policies, including
abstract base classes for policy implementations and data structures for
tracking policy execution and performance metrics.

Classes:
    Policy: Abstract base class for all POMDP policies
    PolicySpaceInfo: Space type information for policy compatibility
    PolicyInfoVariable: Named tuple for policy execution metrics
    PolicyRunData: Container for policy execution information
"""

import importlib
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import pkg_resources

from POMDPPlanners.utils.config_to_id import config_to_id, NumpyEncoder
from POMDPPlanners.utils.logger import get_logger

from POMDPPlanners.core.environment import Environment, SpaceType

if TYPE_CHECKING:
    from POMDPPlanners.core.belief import Belief

    # from POMDPPlanners.core.environment import Environment, SpaceType


@dataclass
class PolicySpaceInfo:
    """Space-type requirements used to check policy/environment compatibility.

    A policy declares the action and observation space types it was designed
    for; environments can then be validated against this declaration before
    the policy is used.

    Attributes:
        action_space: Action space type the policy requires
            (discrete, continuous, or mixed).
        observation_space: Observation space type the policy requires
            (discrete, continuous, or mixed).
    """

    action_space: "SpaceType"
    observation_space: "SpaceType"
class PolicyInfoVariable(NamedTuple):
    """Single named metric recorded during policy execution.

    Each instance is one key-value pair of performance data collected while
    the policy selects an action.

    Attributes:
        name: Descriptive metric name (e.g., "nodes_expanded", "planning_time").
        value: Numeric value of the metric.
    """

    name: str
    value: Union[float, int]
class PolicyRunData(NamedTuple):
    """Aggregate of everything recorded during one action-selection call.

    Bundles the performance metrics and execution details a policy produces
    while choosing an action.

    Attributes:
        info_variables: Policy-specific metrics collected during execution.
    """

    info_variables: List[PolicyInfoVariable]
# Module-level helper functions for Policy save/load def _serialize_value(value: Any) -> Any: # pylint: disable=too-many-return-statements """Serialize value for JSON compatibility. Args: value: Value to serialize Returns: JSON-serializable representation of value """ if value is None: return None if isinstance(value, Path): return str(value) if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, (np.integer, np.floating)): return value.item() if isinstance(value, Enum): return value.value if isinstance(value, (str, int, float, bool)): return value if isinstance(value, (list, tuple)): return [_serialize_value(v) for v in value] if isinstance(value, dict): return {str(k): _serialize_value(v) for k, v in value.items()} if isinstance(value, logging.Logger): return None # Skip loggers # For unknown types, try to convert to string return str(value) def _deserialize_value(value: Any, target_type: type) -> Any: """Deserialize value to target type. Args: value: Serialized value target_type: Target type for deserialization Returns: Value converted to target type """ if value is None: return None # Handle Path objects if target_type == Path: return Path(value) if value is not None else None # Handle Optional types if hasattr(target_type, "__origin__") and target_type.__origin__ is Union: # Get the non-None type from Optional[T] args = [ arg for arg in target_type.__args__ if arg is not type(None) ] # pylint: disable=unidiomatic-typecheck if args: return _deserialize_value(value, args[0]) return value def _extract_constructor_params(policy: "Policy") -> Dict[str, Any]: """Extract constructor parameters from policy instance. Uses inspect.signature() to discover constructor parameters and walks through the class hierarchy to capture all parameters. 
Args: policy: Policy instance to extract parameters from Returns: Dictionary of parameter names to values """ params = {} # Walk through class hierarchy (Policy → PathSimulationPolicy → Concrete) for cls in inspect.getmro(policy.__class__): if cls == object or not issubclass(cls, Policy): break sig = inspect.signature(cls.__init__) for param_name, _ in sig.parameters.items(): if param_name == "self": continue if param_name == "environment": # Skip environment - handle separately continue # Get current value from instance if hasattr(policy, param_name): value = getattr(policy, param_name) # Skip action_sampler - handle separately if param_name == "action_sampler": continue params[param_name] = _serialize_value(value) return params def _serialize_action_sampler(action_sampler: Any) -> Dict[str, Any]: """Serialize ActionSampler object. Uses ActionSampler's __getstate__ method for serialization. Args: action_sampler: ActionSampler instance to serialize Returns: Dictionary with class info and state """ return { "class": f"{action_sampler.__class__.__module__}.{action_sampler.__class__.__name__}", "module": action_sampler.__class__.__module__, "state": action_sampler.__getstate__(), } def _deserialize_action_sampler(sampler_data: Dict[str, Any]) -> Any: """Reconstruct ActionSampler from serialized data. Args: sampler_data: Dictionary containing ActionSampler class and state Returns: Reconstructed ActionSampler instance """ # Import ActionSampler class module_name = sampler_data["module"] class_name = sampler_data["class"].split(".")[-1] module = importlib.import_module(module_name) sampler_class = getattr(module, class_name) # Use __reduce__ pattern: create instance, then restore state sampler = sampler_class.__new__(sampler_class) sampler.__setstate__(sampler_data["state"]) return sampler def _get_default_filepath(policy: "Policy", base_dir: Path = Path("saved_policies")) -> Path: """Generate default filepath for saving policy. 
Args: policy: Policy instance base_dir: Base directory for saved policies Returns: Path with structure: {base_dir}/{env_name}/{policy_class}/{policy_name}_{timestamp}.json """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") env_name = policy.environment.name policy_class_name = policy.__class__.__name__ policy_name = policy.name filepath = base_dir / env_name / policy_class_name / f"{policy_name}_{timestamp}.json" return filepath def _get_package_version() -> str: """Get POMDPPlanners package version. Returns: Package version string or "unknown" """ try: return pkg_resources.get_distribution("POMDPPlanners").version except Exception: # pylint: disable=broad-exception-caught # Catch all exceptions to ensure function always returns a version string return "unknown"
class Policy(ABC):
    """Abstract base class for POMDP policies.

    This class defines the interface for POMDP policies that select actions
    based on belief states. All concrete policy implementations must inherit
    from this class and implement the action selection and space information
    methods.

    Note:
        This is an abstract base class and cannot be instantiated directly.
        Subclasses must implement the action() and get_space_info() methods.

    Attributes:
        environment: The POMDP environment this policy operates in
        discount_factor: Discount factor for future rewards
        name: Unique identifier for the policy
        log_path: Optional directory for logging output
        debug: Flag to enable debug logging
        use_queue_logger: Flag to route log records through a queue-based logger
    """

    def __init__(
        self,
        environment: "Environment",
        discount_factor: float,
        name: str,
        log_path: Optional[Path] = None,
        debug: bool = False,
        use_queue_logger: bool = False,
    ):
        """Initialize the POMDP policy.

        Args:
            environment: Environment that this policy will operate in
            discount_factor: Discount factor for future rewards
                (0 < discount_factor <= 1)
            name: Unique identifier for this policy instance
            log_path: Optional directory for logging output. Defaults to None.
            debug: Enable debug logging. Defaults to False.
            use_queue_logger: Route log records through a queue-based logger
                (passed to get_logger as use_queue). Defaults to False.

        Raises:
            ValueError: If the policy's space requirements are incompatible
                with the environment.
        """
        self.environment = environment
        self.discount_factor = discount_factor
        self.name = name
        self.log_path = log_path
        self.debug = debug
        self.use_queue_logger = use_queue_logger

        self._verify_environment_compatibility()

        # Initialize logger with the policy's name and user-specified settings
        self.logger.info("Initialized policy: %s (debug=%s)", self.name, self.debug)

    def _verify_environment_compatibility(self) -> None:
        """Verify that the policy is compatible with the environment.

        Raises:
            ValueError: If the policy assumes a discrete action or observation
                space while the environment's corresponding space is continuous
                or mixed.
        """
        policy_space_info = self.get_space_info()
        environment_space_info = self.environment.space_info
        if (
            policy_space_info.action_space == SpaceType.DISCRETE
            and environment_space_info.action_space in [SpaceType.CONTINUOUS, SpaceType.MIXED]
        ):
            raise ValueError(
                f"Policy {self.name} is not compatible with the environment {self.environment.name} because the policy assumes discrete action space and the environment assumes continuous action space"
            )
        if (
            policy_space_info.observation_space == SpaceType.DISCRETE
            and environment_space_info.observation_space in [SpaceType.CONTINUOUS, SpaceType.MIXED]
        ):
            raise ValueError(
                f"Policy {self.name} is not compatible with the environment {self.environment.name} because the policy assumes discrete observation space and the environment assumes continuous observation space"
            )

    @property
    def logger(self) -> logging.Logger:
        """Get logger instance for this policy.

        The logger is implemented as a property to maintain pickle
        compatibility, as logger objects cannot be pickled directly.

        Returns:
            Configured logger instance with hierarchical naming
        """
        return get_logger(
            name=f"policy.{self.name}",
            level=logging.INFO,
            output_dir=self.log_path,
            debug=self.debug,
            use_queue=self.use_queue_logger,
        )

    @property
    def config_id(self) -> str:
        """Generate a deterministic identifier based on policy configuration."""

        def serialize_value(value):  # pylint: disable=too-many-return-statements
            """Helper function to serialize values in a deterministic way."""
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, (str, int, float, bool)):
                return value
            if isinstance(value, (list, tuple)):
                return [serialize_value(v) for v in value]
            if isinstance(value, dict):
                # Sort items so the serialization is deterministic
                return {str(k): serialize_value(v) for k, v in sorted(value.items())}
            if isinstance(value, logging.Logger):
                # Exclude logger from config serialization
                return None
            if hasattr(value, "__dict__"):
                # Serialize arbitrary objects through their attribute dict
                # (loggers were already filtered out above)
                return serialize_value(value.__dict__)
            return str(value)

        config_dict = {}
        for key, value in self.__dict__.items():
            # Private attributes and callables are not part of the configuration
            if key.startswith("_") or callable(value):
                continue
            config_dict[key] = serialize_value(value)
        # Sort keys so the identifier is independent of attribute insertion order
        config_dict = dict(sorted(config_dict.items()))
        config_dict["environment"] = self.environment.config_id
        return config_to_id(config_dict)

    def __hash__(self) -> int:
        return hash(self.config_id)
    @abstractmethod
    def action(self, belief: "Belief") -> Tuple[List[Any], PolicyRunData]:
        """Select action(s) based on the current belief state.

        This is the core method that implements the policy's decision-making
        logic. It takes a belief state and returns the selected action(s) along
        with execution information and performance metrics.

        Args:
            belief: Current belief state representing uncertainty over states

        Returns:
            Tuple containing:
                - List of selected actions (typically single action, but
                  supports multiple)
                - PolicyRunData with execution metrics and performance
                  information

        Note:
            Subclasses must implement this method with their specific planning
            or decision-making algorithm.
        """
    @classmethod
    @abstractmethod
    def get_space_info(cls) -> PolicySpaceInfo:
        """Get space type requirements for this policy class.

        This class method specifies what types of action and observation spaces
        this policy implementation can handle, enabling compatibility checking
        with environments.

        Returns:
            PolicySpaceInfo specifying required action and observation space
            types

        Note:
            Subclasses must implement this method to declare their space
            compatibility. This is used for validation when pairing policies
            with environments.
        """
    @classmethod
    @abstractmethod
    def get_info_variable_names(cls) -> List[str]:
        """Get names of policy info variables that this policy produces.

        This class method returns the names of metrics and performance data
        that the policy tracks during execution via PolicyInfoVariable objects.
        It enables users to discover what metrics are available for
        hyperparameter optimization before running simulations.

        Returns:
            List of info variable names that this policy produces during
            action selection

        Note:
            Subclasses must implement this method to declare what metrics they
            track. Use an Enum to ensure consistency between the names returned
            here and the names used when creating PolicyInfoVariable objects in
            the action() method.
        """
    def save(self, filepath: Optional[Union[str, Path]] = None) -> Path:
        """Save policy configuration to JSON file.

        Saves only constructor parameters needed to reconstruct the policy,
        not the full internal state. This enables human-readable policy
        configurations that can be versioned, inspected, and modified.

        Args:
            filepath: Path where to save the policy configuration.
                If None, uses default location:
                saved_policies/{env_name}/{policy_class}/{policy_name}_{timestamp}.json

        Returns:
            Path where policy was saved

        Raises:
            ValueError: If policy parameters cannot be serialized
            IOError: If file cannot be written

        Example:
            >>> from POMDPPlanners.environments import TigerPOMDP
            >>> from POMDPPlanners.planners import POMCP
            >>> env = TigerPOMDP(discount_factor=0.95)
            >>> planner = POMCP(environment=env, discount_factor=0.95,
            ...                 depth=10, exploration_constant=1.0,
            ...                 name="test", n_simulations=100)
            >>> # Save with default path
            >>> filepath = planner.save()
            >>> # Or save to custom path
            >>> filepath = planner.save("my_policy.json")
        """
        if filepath is None:
            filepath = _get_default_filepath(self)
        filepath = Path(filepath)

        try:
            # Extract policy parameters
            params = _extract_constructor_params(self)

            # Serialize environment
            env_data = self.environment.to_dict()

            # Build save dictionary
            save_data = {
                "metadata": {
                    "saved_at": datetime.now().isoformat(),
                    "pomdpplanners_version": _get_package_version(),
                    "policy_class": f"{self.__class__.__module__}.{self.__class__.__name__}",
                    "policy_config_id": self.config_id,
                    "format_version": "1.0",
                },
                "environment": env_data,
                "policy": {"params": params},
            }

            # Handle action_sampler if present
            action_sampler = getattr(self, "action_sampler", None)
            if action_sampler is not None:
                save_data["action_sampler"] = _serialize_action_sampler(action_sampler)

            # Write to file with NumpyEncoder
            filepath.parent.mkdir(parents=True, exist_ok=True)
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(save_data, f, indent=2, cls=NumpyEncoder)

            return filepath
        except Exception as e:
            # NOTE(review): this broad handler also wraps I/O failures, so the
            # IOError promised in the docstring actually surfaces as ValueError
            # — confirm whether callers rely on catching OSError/IOError here.
            raise ValueError(f"Failed to save policy: {str(e)}") from e
    @classmethod
    def load(cls, filepath: Union[str, Path]) -> "Policy":
        """Load policy configuration from JSON file.

        Reconstructs policy instance from saved constructor parameters. Creates
        both the environment and policy from the saved configuration.

        Args:
            filepath: Path to the saved policy configuration file

        Returns:
            Reconstructed policy instance

        Raises:
            FileNotFoundError: If filepath does not exist
            ValueError: If JSON format is invalid or unsupported
            ImportError: If policy/environment classes cannot be imported

        Example:
            >>> import tempfile
            >>> from pathlib import Path
            >>> from POMDPPlanners.planners import POMCP
            >>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP
            >>> # Create and save a policy
            >>> env = TigerPOMDP(discount_factor=0.95)
            >>> planner = POMCP(environment=env, discount_factor=0.95, depth=10,
            ...                 exploration_constant=1.0, name="test",
            ...                 n_simulations=100)
            >>> with tempfile.TemporaryDirectory() as tmpdir:
            ...     filepath = Path(tmpdir) / "test_policy.json"
            ...     _ = planner.save(filepath)
            ...     # Load the policy back
            ...     loaded_planner = POMCP.load(filepath)
            ...     print(loaded_planner.depth)
            10
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Policy file not found: {filepath}")

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Validate format version
            format_version = data.get("metadata", {}).get("format_version")
            if format_version != "1.0":
                raise ValueError(f"Unsupported format version: {format_version}")

            # Get policy class from its fully qualified dotted path
            policy_class_path = data["metadata"]["policy_class"]
            module_name = ".".join(policy_class_path.split(".")[:-1])
            class_name = policy_class_path.split(".")[-1]
            module = importlib.import_module(module_name)
            policy_class = getattr(module, class_name)

            # Reconstruct environment
            environment = Environment.from_dict(data["environment"])

            # Get policy parameters
            policy_params = data["policy"]["params"].copy()
            policy_params["environment"] = environment

            # Handle action_sampler if present
            if "action_sampler" in data:
                sampler = _deserialize_action_sampler(data["action_sampler"])
                policy_params["action_sampler"] = sampler

            # Deserialize parameter types and filter to only valid constructor params
            sig = inspect.signature(policy_class.__init__)
            filtered_params = {}
            for param_name, param in sig.parameters.items():
                if param_name == "self":
                    continue
                if param_name in policy_params:
                    value = policy_params[param_name]
                    # Deserialize if type annotation available
                    if param.annotation != inspect.Parameter.empty:
                        value = _deserialize_value(value, param.annotation)
                    filtered_params[param_name] = value

            # Construct policy with only valid parameters
            policy = policy_class(**filtered_params)

            # Optionally warn about config_id mismatch
            loaded_config_id = policy.config_id
            saved_config_id = data["metadata"].get("policy_config_id")
            if loaded_config_id != saved_config_id:
                warnings.warn(
                    f"Loaded policy config_id ({loaded_config_id}) differs from saved "
                    f"config_id ({saved_config_id}). This may indicate parameter mismatch."
                )

            return policy
        except Exception as e:
            # NOTE(review): ValueError and ImportError raised inside this try
            # are re-wrapped as ValueError, so the ImportError promised in the
            # docstring never reaches callers — confirm intended behavior.
            raise ValueError(f"Failed to load policy from {filepath}: {str(e)}") from e
class TrainablePolicy(Policy):
    """Abstract base class for policies that support offline training.

    Extends :class:`Policy` with hooks that separate the **model** (what to
    compute) from the **trainer** (how to run the training loop), following
    the PyTorch Lightning pattern. Concrete subclasses implement these hooks,
    and :class:`~POMDPPlanners.training.PolicyTrainer` drives the overall
    training loop by calling them.
    """

    @abstractmethod
    def begin_collecting(self) -> None:
        """Notify the policy that a data-collection phase is starting."""

    @abstractmethod
    def end_collecting(self) -> None:
        """Notify the policy that the current data-collection phase is over."""

    @abstractmethod
    def prepare_episode(self) -> None:
        """Clear per-episode scratch state ahead of a new episode."""

    @abstractmethod
    def finalize_episode(self, history: Any) -> None:
        """Fold a finished episode into the replay buffer.

        Args:
            history: The :class:`~POMDPPlanners.core.simulation.History`
                returned by the episode runner.
        """

    @abstractmethod
    def train_step(self) -> Dict[str, List[float]]:
        """Run one training pass over the current replay buffer.

        Returns:
            Mapping from loss-metric key to the list of loss values produced
            during training.
        """

    @abstractmethod
    def buffer_size(self) -> int:
        """Return how many examples the replay buffer currently holds."""

    @abstractmethod
    def collect_episodes_batched(
        self,
        initial_belief_fn: Callable[[], Any],
        n_episodes: int,
        episode_length: int,
    ) -> None:
        """Gather training data via fast batched (network-only) rollouts.

        Args:
            initial_belief_fn: Callable returning a fresh initial belief.
            n_episodes: Number of episodes to collect.
            episode_length: Maximum steps per episode.
        """

    @abstractmethod
    def get_metric_keys(self) -> List[str]:
        """Return the loss-metric key names produced by :meth:`train_step`."""

    def get_network(self) -> Optional[Any]:
        """Return the underlying trainable network, or ``None`` if not applicable.

        Override in concrete policies to enable weight-histogram logging in
        :class:`~POMDPPlanners.training.callbacks.TensorBoardCallback`.
        """
        return None