Source code for POMDPPlanners.core.environment

"""Module for POMDP environment abstractions.

This module provides the foundational classes and interfaces for defining
POMDP environments, including abstract base classes for state transitions,
observation models, and reward functions.

Classes:
    Environment: Abstract base class for POMDP environments
    DiscreteActionsEnvironment: Specialized for discrete action spaces
    ObservationModel: Abstract observation model interface
    StateTransitionModel: Abstract state transition interface
    EnvironmentGenerator: Factory pattern for environment creation
    SpaceType: Enumeration for action/observation space types
    SpaceInfo: Data class containing space type information
"""

import importlib
import inspect
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from POMDPPlanners.core.distributions import Distribution
from POMDPPlanners.core.serialization import (
    deserialize_value as deserialize_value_base,
    register_deserializer,
    register_serializer,
    serialize_value as serialize_value_base,
)
from POMDPPlanners.utils.config_to_id import config_to_id
from POMDPPlanners.utils.logger import get_logger

if TYPE_CHECKING:
    from POMDPPlanners.core.simulation import History, MetricValue, StepData


def _serialize_space_info(space_info: Any) -> dict:
    """Serialize SpaceInfo to plain dict without type markers.

    Maintains backward compatibility with existing saved environments.
    Format: {"action_space": "discrete", "observation_space": "continuous"}

    Args:
        space_info: SpaceInfo instance to serialize

    Returns:
        Plain dict with action_space and observation_space string values
    """
    return {
        "action_space": space_info.action_space.value,
        "observation_space": space_info.observation_space.value,
    }


def _deserialize_space_info(data: dict) -> Any:
    """Deserialize SpaceInfo from plain dict format.

    Handles dicts with action_space and observation_space keys without
    requiring __type__ markers for backward compatibility.

    Args:
        data: Dict with action_space and observation_space keys

    Returns:
        SpaceInfo instance

    Raises:
        ValueError: If data cannot be deserialized to SpaceInfo
    """
    if isinstance(data, dict) and "action_space" in data and "observation_space" in data:
        # Import SpaceType here to avoid circular dependency
        return SpaceInfo(
            action_space=SpaceType(data["action_space"]),
            observation_space=SpaceType(data["observation_space"]),
        )
    raise ValueError(f"Cannot deserialize SpaceInfo from {data}")


[docs] class SpaceType(Enum): """Enumeration for categorizing action and observation spaces. This enum is used to classify the mathematical structure of action and observation spaces in POMDP environments. Attributes: DISCRETE: Finite, countable spaces (e.g., {0, 1, 2, ...}) CONTINUOUS: Real-valued continuous spaces (e.g., R^n) MIXED: Combination of discrete and continuous elements """ DISCRETE = "discrete" CONTINUOUS = "continuous" MIXED = "mixed"
[docs] @dataclass class SpaceInfo: """Data class containing space type information for an environment. This class encapsulates the space type classifications for both actions and observations in a POMDP environment. Attributes: action_space: The type of action space (discrete, continuous, or mixed) observation_space: The type of observation space (discrete, continuous, or mixed) Example: Creating space info for different environment types: >>> # Discrete actions, continuous observations >>> space_info = SpaceInfo( ... action_space=SpaceType.DISCRETE, ... observation_space=SpaceType.CONTINUOUS ... ) """ action_space: SpaceType observation_space: SpaceType
# Register SpaceInfo serialization handlers at module load time # This enables centralized serialization system to handle SpaceInfo automatically register_serializer(SpaceInfo, _serialize_space_info) register_deserializer(SpaceInfo, _deserialize_space_info)
[docs] class ObservationModel(Distribution, ABC): """Abstract base class for POMDP observation models. This class defines the interface for observation models that generate observations given a next state and action. Inherits from Distribution to provide sampling and probability calculation capabilities. Note: This is an abstract base class and cannot be instantiated directly. Subclasses must implement the sample() method. Attributes: next_state: The state after taking an action action: The action that was taken """ def __init__(self, next_state: Any, action: Any): """Initialize the observation model. Args: next_state: The resulting state after taking an action action: The action that was executed """ self.next_state = next_state self.action = action
[docs] @abstractmethod def sample(self, n_samples: int = 1) -> List[Any]: """Sample observations from the observation model. Args: n_samples: Number of observation samples to generate. Defaults to 1. Returns: List of sampled observations of length n_samples. Note: Subclasses must implement this method according to their specific observation generation logic. """
[docs] def probability(self, values: List[Any]) -> np.ndarray: """Calculate observation probabilities for given values. Args: values: List of observation values to calculate probabilities for Returns: Array of probabilities corresponding to the input values Raises: NotImplementedError: This method is not implemented by default. Subclasses should override if probability calculation is needed. """ raise NotImplementedError("The method is not implemented for this observation model.")
[docs] class StateTransitionModel(Distribution, ABC): """Abstract base class for POMDP state transition models. This class defines the interface for state transition models that generate next states given a current state and action. Inherits from Distribution to provide sampling and probability calculation capabilities. Note: This is an abstract base class and cannot be instantiated directly. Subclasses must implement the sample() method. Attributes: state: The current state action: The action to be taken """ def __init__(self, state: Any, action: Any): """Initialize the state transition model. Args: state: The current state action: The action to be executed from the current state """ self.state = state self.action = action
[docs] @abstractmethod def sample(self, n_samples: int = 1) -> List[Any]: """Sample next states from the transition model. Args: n_samples: Number of next state samples to generate. Defaults to 1. Returns: List of sampled next states of length n_samples. Note: Subclasses must implement this method according to their specific state transition dynamics. """
[docs] def probability(self, values: List[Any]) -> np.ndarray: """Calculate transition probabilities for given next states. Args: values: List of next state values to calculate probabilities for Returns: Array of transition probabilities corresponding to the input values Raises: NotImplementedError: This method is not implemented by default. Subclasses should override if probability calculation is needed. """ raise NotImplementedError("The method is not implemented for this state transition model.")
[docs] class Environment(ABC): """Abstract base class for POMDP environments. This is the core abstract class that all POMDP environments must inherit from. It defines the essential interface for POMDP environments including state transitions, observations, rewards, and terminal conditions. Note: This is an abstract base class and cannot be instantiated directly. Subclasses must implement all abstract methods. Attributes: discount_factor: Discount factor for future rewards name: Environment identifier string space_info: Information about action and observation space types reward_range: Optional tuple containing (min_reward, max_reward) output_dir: Optional directory for logging output debug: Flag to enable debug logging """ def __init__( self, discount_factor: float, name: str, space_info: SpaceInfo, reward_range: Optional[Tuple[float, float]] = None, output_dir: Optional[Path] = None, debug: bool = False, use_queue_logger: bool = False, ): """Initialize the POMDP environment. Args: discount_factor: Discount factor for future rewards (0 < discount_factor <= 1) name: Unique identifier for the environment space_info: Information about action and observation space types reward_range: Optional tuple containing (min_reward, max_reward) for the environment. Defaults to None. If provided, will be validated. output_dir: Optional directory for logging output. Defaults to None. debug: Enable debug logging. Defaults to False. use_queue_logger: Whether to use queue-based logging. Defaults to True. """ self.discount_factor = discount_factor self.name = name self.space_info = space_info self.reward_range = self._validate_reward_range(reward_range) self.output_dir = output_dir self.debug = debug self.use_queue_logger = use_queue_logger self.logger.info( "Initializing %s environment with discount factor %s", self.name, self.discount_factor ) self.logger.debug( "Space info: action_space=%s, observation_space=%s", self.space_info.action_space, self.space_info.observation_space, ) if self.reward_range is not None: self.logger.debug("Reward range: %s", self.reward_range) def _validate_reward_range( self, reward_range: Optional[Tuple[float, float]] ) -> Optional[Tuple[float, float]]: """Validate reward_range if provided. Args: reward_range: Optional tuple containing (min_reward, max_reward) Returns: Validated reward_range tuple or None if input was None Raises: ValueError: If reward_range structure or values are invalid TypeError: If reward_range values are not numeric """ if reward_range is None: return None # Validate structure if not isinstance(reward_range, tuple) or len(reward_range) != 2: raise ValueError("reward_range must be a tuple of exactly two float values") min_reward, max_reward = reward_range # Check that both values are numeric (float or int) if not isinstance(min_reward, (int, float)) or not isinstance(max_reward, (int, float)): raise TypeError("reward_range values must be numeric (int or float)") # Convert to float to ensure consistency min_reward, max_reward = float(min_reward), float(max_reward) # Check for NaN values if np.isnan(min_reward) or np.isnan(max_reward): raise ValueError("reward_range values cannot be NaN") # Check that min_reward <= max_reward (allowing inf values) if min_reward > max_reward: raise ValueError( f"reward_range minimum ({min_reward}) must be less than or equal to maximum ({max_reward})" ) return (min_reward, max_reward) @property def logger(self) -> logging.Logger: """Get logger instance for this environment. The logger is implemented as a property to maintain pickle compatibility, as logger objects cannot be pickled directly. Returns: Configured logger instance with hierarchical naming """ return get_logger( name=f"environment.{self.name}", output_dir=self.output_dir, debug=self.debug, use_queue=self.use_queue_logger, ) def __eq__(self, other): if not isinstance(other, Environment): return False if self.__class__ != other.__class__: return False def _compare_values(v1, v2): # pylint: disable=too-many-return-statements """Helper function to compare values, handling numpy arrays specially.""" if isinstance(v1, np.ndarray) or isinstance(v2, np.ndarray): if not (isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray)): return False return np.array_equal(v1, v2) if isinstance(v1, (list, tuple)) and isinstance(v2, (list, tuple)): if len(v1) != len(v2): return False return all(_compare_values(x1, x2) for x1, x2 in zip(v1, v2)) if isinstance(v1, dict) and isinstance(v2, dict): if v1.keys() != v2.keys(): return False return all(_compare_values(v1[k], v2[k]) for k in v1) return v1 == v2 # Compare all public attributes (excluding callables and private) for key, value in self.__dict__.items(): if key.startswith("_") or callable(value): continue if not hasattr(other, key): return False other_value = getattr(other, key) if not _compare_values(value, other_value): return False # Check for any attributes in other that aren't in self for key in other.__dict__: if key.startswith("_") or callable(getattr(other, key)): continue if not hasattr(self, key): return False return True @property def config_id(self) -> str: """Generate a deterministic identifier based on environment configuration. Note: Uses custom serialization logic (not centralized serialize_value) to ensure: - Deterministic dict key ordering for consistent hashing - Compact format without __type__ markers - Recursive handling of nested objects Changing this serialization format would invalidate all cached results. """ def serialize_value(value): # pylint: disable=too-many-return-statements if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, (str, int, float, bool)): return value if isinstance(value, (list, tuple)): return [serialize_value(v) for v in value] if isinstance(value, dict): return {str(k): serialize_value(v) for k, v in sorted(value.items())} if isinstance(value, SpaceInfo): return { "action_space": serialize_value(value.action_space), "observation_space": serialize_value(value.observation_space), } if isinstance(value, Enum): return value.value if hasattr(value, "__dict__"): # Skip logger objects if isinstance(value, logging.Logger): return None return serialize_value(value.__dict__) return str(value) config_dict = {} for key, value in self.__dict__.items(): # Skip logger and private attributes if key.startswith("_") or callable(value) or isinstance(value, logging.Logger): continue serialized_value = serialize_value(value) if serialized_value is not None: # Skip None values (like logger) config_dict[key] = serialized_value config_dict = dict(sorted(config_dict.items())) return config_to_id(config_dict) def __hash__(self) -> int: return hash(self.config_id)
[docs] @abstractmethod def state_transition_model(self, state: Any, action: Any) -> StateTransitionModel: """Get the state transition model for a given state-action pair. Args: state: Current state action: Action to be executed Returns: State transition model that can sample next states Note: Subclasses must implement this method to define state dynamics. """
[docs] @abstractmethod def observation_model(self, next_state: Any, action: Any) -> ObservationModel: """Get the observation model for a given next state and action. Args: next_state: The resulting state after taking an action action: The action that was executed Returns: Observation model that can sample observations Note: Subclasses must implement this method to define observation generation. """
[docs] @abstractmethod def reward(self, state: Any, action: Any) -> float: """Calculate the immediate reward for a state-action pair. Args: state: Current state action: Action executed from the state Returns: Immediate reward value Note: Subclasses must implement this method to define reward structure. """
[docs] def reward_batch(self, states: Union[np.ndarray, Sequence[Any]], action: Any) -> np.ndarray: """Calculate rewards for a batch of states given a single action. Provides a loop-based default that subclasses can override with vectorized numpy implementations for better performance. Args: states: Sequence of states of length ``N``. action: Action executed from each state. Returns: 1-D array of reward values with shape ``(N,)``. """ return np.array([self.reward(states[i], action) for i in range(len(states))])
[docs] @abstractmethod def is_terminal(self, state: Any) -> bool: """Check if a state is terminal. Args: state: State to check for terminal condition Returns: True if the state is terminal, False otherwise Note: Subclasses must implement this method to define terminal conditions. """
[docs] @abstractmethod def initial_state_dist(self) -> Distribution: """Get the initial state distribution. Returns: Distribution over initial states Note: Subclasses must implement this method to define the starting distribution. """
[docs] @abstractmethod def initial_observation_dist(self) -> Distribution: """Get the initial observation distribution. Returns: Distribution over initial observations Note: Subclasses must implement this method to define initial observations. """
[docs] @abstractmethod def is_equal_observation(self, observation1: Any, observation2: Any) -> bool: """Check if two observations are equal. Args: observation1: First observation to compare observation2: Second observation to compare Returns: True if observations are considered equal, False otherwise Note: Subclasses must implement this method to define observation equality. This is particularly important for discrete observation spaces. """
[docs] def sample_next_step(self, state: Any, action: Any) -> Tuple[Any, Any, float]: """Sample a complete state transition step. This convenience method combines state transition, observation generation, and reward calculation in a single operation. Args: state: Current state action: Action to execute Returns: Tuple containing: - next_state: Sampled next state - next_observation: Sampled observation - reward: Immediate reward """ next_state = self.state_transition_model(state=state, action=action).sample()[0] next_observation = self.observation_model(next_state=next_state, action=action).sample()[0] reward = self.reward(state=state, action=action) return next_state, next_observation, reward
[docs] def cache_visualization(self, history: "List[StepData]", cache_path: Path) -> None: """Cache visualization data for an episode history. This method can be overridden by subclasses to provide environment-specific visualization caching capabilities. Args: history: List of step data from an episode cache_path: Path where visualization data should be cached """
[docs] def get_metric_names(self) -> List[str]: """Get names of environment-specific metrics. This method returns the names of custom metrics that this environment computes in the compute_metrics() method. It enables users to discover what metrics are available for hyperparameter optimization. Returns: List of metric names that this environment produces. Default implementation returns empty list for environments without custom metrics. Note: Subclasses that override compute_metrics() should also override this method to return the names of metrics they produce. Use an Enum to ensure consistency between the names returned here and the names used in compute_metrics(). """ return []
[docs] def compute_metrics( self, histories: "List[History]" ) -> "List[MetricValue]": # pylint: disable=unused-argument """Compute environment-specific metrics from episode histories. This method can be overridden by subclasses to provide custom metric calculations beyond standard return and episode length. Args: histories: List of episode histories to analyze Returns: List of computed metrics with confidence intervals """ return []
[docs] def to_dict(self) -> Dict[str, Any]: """Serialize environment to dictionary format. Extracts environment class information and constructor parameters to enable JSON serialization and reconstruction. Returns: Dictionary with structure: - class: Full class path (module.ClassName) - module: Module name - params: Constructor parameters - config_id: Deterministic configuration identifier Example: >>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP >>> env = TigerPOMDP(discount_factor=0.95) >>> env_dict = env.to_dict() >>> 'class' in env_dict and 'params' in env_dict True Note: Uses centralized serialization system with registered SpaceInfo handler. """ # Get environment class information env_class = self.__class__ env_module = env_class.__module__ env_class_name = env_class.__name__ # Extract constructor parameters sig = inspect.signature(env_class.__init__) params = {} for param_name, _ in sig.parameters.items(): if param_name == "self": continue if hasattr(self, param_name): value = getattr(self, param_name) # Use centralized serialization (SpaceInfo handled by registered handler) serialized_value = serialize_value_base(value) if serialized_value is not None: # Skip None values (like logger) params[param_name] = serialized_value return { "class": f"{env_module}.{env_class_name}", "module": env_module, "params": params, "config_id": self.config_id, }
[docs] @classmethod def from_dict(cls, data: Dict[str, Any]) -> "Environment": """Reconstruct environment from dictionary. Dynamically imports the environment class and instantiates it with the saved parameters. Args: data: Dictionary containing environment serialization data with keys: class, module, params, config_id Returns: Reconstructed environment instance Raises: ImportError: If environment class cannot be imported ValueError: If required data fields are missing TypeError: If parameters are invalid for environment constructor Example: >>> from POMDPPlanners.environments.tiger_pomdp import TigerPOMDP >>> env = TigerPOMDP(discount_factor=0.95) >>> env_dict = env.to_dict() >>> reconstructed_env = Environment.from_dict(env_dict) >>> reconstructed_env.discount_factor 0.95 """ def deserialize_value( value, target_type, param_name="" ): # pylint: disable=too-many-branches """Deserialize value with environment-specific handling. Handles environment-specific patterns before delegating to centralized system: - List[Tuple[...]] / Set[Tuple[...]] for obstacles, rock positions - Matrix parameters (covariance matrices) with parameter name detection Note: SpaceInfo is handled automatically by registered handler in centralized system. """ # Unwrap Optional[T] types first unwrapped_type = target_type if hasattr(target_type, "__origin__") and target_type.__origin__ is Union: # Get non-None type from Optional # pylint: disable=unidiomatic-typecheck args = [arg for arg in target_type.__args__ if arg is not type(None)] if args: unwrapped_type = args[0] # Environment-specific pattern: List[Tuple[...]] and Set[Tuple[...]] # Used by PushPOMDP (obstacles) and RockSamplePOMDP (rock_positions) # Handles multiple serialized formats for compatibility if hasattr(unwrapped_type, "__origin__"): if unwrapped_type.__origin__ in (list, set): # Check if the element type is a tuple args = getattr(unwrapped_type, "__args__", ()) if args and hasattr(args[0], "__origin__") and args[0].__origin__ is tuple: # This is List[Tuple[...]] or Set[Tuple[...]] if isinstance(value, list) and value: # Format 1: Tuple markers like {'__type__': 'tuple', 'values': [x, y]} if isinstance(value[0], dict) and value[0].get("__type__") == "tuple": return [deserialize_value_base(elem, None) for elem in value] # First deserialize the value (might be ndarray marker or plain list) deserialized = deserialize_value_base(value, None) # Format 2: NumPy array shape (2, N) → [(x1,y1), (x2,y2), ...] if isinstance(deserialized, np.ndarray): if deserialized.ndim == 2 and deserialized.shape[0] == 2: return list(zip(deserialized[0], deserialized[1])) # Format 3: 2D list [[x1,x2,...], [y1,y2,...]] → [(x1,y1), ...] elif isinstance(deserialized, list) and deserialized: if len(deserialized) == 2 and isinstance(deserialized[0], list): return list(zip(deserialized[0], deserialized[1])) # Environment-specific pattern: Matrix parameter name detection # Ensures covariance matrices are always numpy arrays matrix_param_names = [ "noise_cov", "_cov", "cov_matrix", "state_transition_cov_matrix", "observation_cov_matrix", ] if any(name in param_name.lower() for name in matrix_param_names): result = deserialize_value_base(value, target_type) if not isinstance(result, np.ndarray): result = np.array(result) return result # Handle numpy array type annotations if target_type == np.ndarray or ( hasattr(target_type, "__name__") and "ndarray" in target_type.__name__ ): result = deserialize_value_base(value, target_type) if not isinstance(result, np.ndarray): result = np.array(result) return result # Delegate to centralized deserialization for all other types return deserialize_value_base(value, target_type) # Validate required fields if "class" not in data or "module" not in data or "params" not in data: raise ValueError("Environment data missing required fields: class, module, or params") # Import environment class dynamically module_name = data["module"] class_name = data["class"].split(".")[-1] try: module = importlib.import_module(module_name) env_class = getattr(module, class_name) except (ImportError, AttributeError) as e: raise ImportError( f"Failed to import environment class {data['class']}: {str(e)}" ) from e # Deserialize parameters with type hints sig = inspect.signature(env_class.__init__) params = {} for param_name, param in sig.parameters.items(): if param_name == "self": continue if param_name in data["params"]: value = data["params"][param_name] # Try to deserialize with type annotation if available if param.annotation != inspect.Parameter.empty: value = deserialize_value(value, param.annotation, param_name) else: value = deserialize_value(value, type(value), param_name) params[param_name] = value # Reconstruct environment try: return env_class(**params) except TypeError as e: raise TypeError( f"Failed to construct {class_name} with params {params}: {str(e)}" ) from e
[docs] class DiscreteActionsEnvironment(Environment): """Abstract base class for POMDP environments with discrete action spaces. This class extends the base Environment class with additional functionality specific to environments that have finite, enumerable action sets. Note: This is an abstract base class and cannot be instantiated directly. Subclasses must implement all abstract methods from Environment plus the get_actions() method. """ def __init__( self, discount_factor: float, name: str, space_info: SpaceInfo, reward_range: Optional[Tuple[float, float]] = None, output_dir: Optional[Path] = None, debug: bool = False, use_queue_logger: bool = False, ): """Initialize the discrete actions environment. Args: discount_factor: Discount factor for future rewards (0 < discount_factor <= 1) name: Unique identifier for the environment space_info: Information about action and observation space types reward_range: Optional tuple containing (min_reward, max_reward) for the environment. Defaults to None. If provided, will be validated. output_dir: Optional directory for logging output. Defaults to None. debug: Enable debug logging. Defaults to False. """ super().__init__( discount_factor=discount_factor, name=name, space_info=space_info, reward_range=reward_range, output_dir=output_dir, debug=debug, use_queue_logger=use_queue_logger, ) self.logger.debug("Initialized DiscreteActionsEnvironment")
[docs] @abstractmethod def state_transition_model(self, state: Any, action: Any) -> StateTransitionModel: pass
[docs] @abstractmethod def observation_model(self, next_state: Any, action: Any) -> ObservationModel: pass
[docs] @abstractmethod def reward(self, state: Any, action: Any) -> float: pass
[docs] @abstractmethod def is_terminal(self, state: Any) -> bool: pass
[docs] @abstractmethod def initial_state_dist(self) -> Distribution: pass
[docs] @abstractmethod def initial_observation_dist(self) -> Distribution: pass
[docs] @abstractmethod def get_actions(self) -> List[Any]: """Get all possible actions in the discrete action space. Returns: List containing all valid actions that can be executed Note: Subclasses must implement this method to enumerate all possible actions. This is used by planning algorithms that need to iterate over actions. """
[docs] @abstractmethod def is_equal_observation(self, observation1: Any, observation2: Any) -> bool: pass
[docs] class EnvironmentGenerator(ABC): """Abstract base class for environment generators. This class implements the factory pattern for creating environment instances. It's useful for generating environments with randomized parameters or for creating multiple environment variants. Note: This is an abstract base class and cannot be instantiated directly. Subclasses must implement the generate_environment() method. Attributes: name: Identifier for the generator """ def __init__(self, name: str): """Initialize the environment generator. Args: name: Unique identifier for this generator """ self.name = name
[docs] @abstractmethod def generate_environment(self) -> Environment: """Generate a new environment instance. Returns: Newly created environment instance Note: Subclasses must implement this method to define environment creation logic. This may involve randomization, parameter sampling, or deterministic generation. """