# Source code for POMDPPlanners.core.belief.gaussian_mixture_belief

"""Gaussian Mixture belief state representation for POMDP environments.

This module provides a Gaussian Mixture Model (GMM) belief state that
represents the posterior as a weighted mixture of multivariate Gaussians.
Updates are delegated to a :class:`GaussianMixtureBeliefUpdater` instance,
following the same dependency injection pattern as
:class:`~POMDPPlanners.core.belief.GaussianBelief`.

Classes:
    GaussianMixtureBeliefUpdater: ABC for GMM belief update strategies.
    GaussianMixtureBelief: GMM belief with pluggable updater.
"""

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import numpy as np
from scipy.special import logsumexp

from POMDPPlanners.core.belief.base_belief import Belief
from POMDPPlanners.core.environment import Environment
from POMDPPlanners.utils.config_to_id import config_to_id
from POMDPPlanners.utils.multivariate_normal import (
    CovarianceParameterizedMultivariateNormal,
)


class GaussianMixtureBeliefUpdater(ABC):
    """Strategy interface for updating a Gaussian mixture belief.

    Concrete subclasses implement one belief-update cycle that maps the
    current mixture ``(means, covariances, weights)`` together with the
    executed ``action`` and received ``observation`` to a new mixture
    ``(new_means, new_covariances, new_weights)``.

    Note:
        Abstract base class — it cannot be instantiated directly; subclass
        it and implement :meth:`update` and :attr:`config_id`.
    """

    @abstractmethod
    def update(
        self,
        means: List[np.ndarray],
        covariances: List[np.ndarray],
        weights: np.ndarray,
        action: Any,
        observation: Any,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]:
        """Run one belief-update cycle on the mixture components.

        Args:
            means: List of k mean vectors, each of shape (d,).
            covariances: List of k covariance matrices, each of shape (d, d).
            weights: Mixture weights of shape (k,).
            action: Action that was executed.
            observation: Observation that was received.

        Returns:
            A tuple ``(new_means, new_covariances, new_weights)``.
        """

    @property
    @abstractmethod
    def config_id(self) -> str:
        """Return a deterministic identifier for this updater configuration."""
class GaussianMixtureBelief(Belief):
    """Gaussian Mixture Model belief state representation.

    Represents the belief as a weighted mixture of multivariate normal
    distributions:

        p(x) = sum_k w_k * N(x; mu_k, Sigma_k)

    The update mechanism delegates to a
    :class:`GaussianMixtureBeliefUpdater` instance, allowing flexibility in
    how mixture components are updated, pruned, or merged.

    This belief type is compatible with PFT_DPW, Sparse-PFT, and
    SparseSampling planners. It is NOT compatible with POMCP/POMCP_DPW
    planners because it does not support incremental particle accumulation
    via ``inplace_update()``.

    Attributes:
        means: List of mean vectors, one per component.
        covariances: List of covariance matrices, one per component.
        weights: Array of mixture weights summing to 1.
        updater: GaussianMixtureBeliefUpdater that computes the Bayesian
            belief update.
        n_terminal_check_samples: Number of Monte Carlo samples for terminal
            checks.

    Example:
        >>> import numpy as np
        >>> np.random.seed(42)
        >>>
        >>> # Define a simple updater that shrinks covariances
        >>> from POMDPPlanners.core.belief.gaussian_mixture_belief import (
        ...     GaussianMixtureBeliefUpdater,
        ... )
        >>> class ShrinkUpdater(GaussianMixtureBeliefUpdater):
        ...     def update(self, means, covs, weights, action, obs):
        ...         return means, [c * 0.9 for c in covs], weights
        ...     @property
        ...     def config_id(self):
        ...         return "shrink"
        >>>
        >>> # Create a 2-component GMM belief in 2D
        >>> means = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
        >>> covs = [np.eye(2), np.eye(2)]
        >>> weights = np.array([0.5, 0.5])
        >>> belief = GaussianMixtureBelief(
        ...     means=means, covariances=covs, weights=weights,
        ...     updater=ShrinkUpdater(),
        ... )
        >>>
        >>> # Sample a state
        >>> state = belief.sample()
        >>> len(state) == 2
        True
        >>>
        >>> # Update belief
        >>> new_belief = belief.update(
        ...     action=0, observation=np.array([1.0, 1.0]), pomdp=None
        ... )
        >>> new_belief.n_components == 2
        True
    """

    def __init__(
        self,
        means: List[np.ndarray],
        covariances: List[np.ndarray],
        weights: np.ndarray,
        updater: GaussianMixtureBeliefUpdater,
        n_terminal_check_samples: int = 50,
    ):
        """Initialize Gaussian Mixture belief.

        Args:
            means: List of k mean vectors, each of shape (d,).
            covariances: List of k positive definite covariance matrices,
                each of shape (d, d).
            weights: Mixture weights of shape (k,) that must sum to 1.
            updater: A :class:`GaussianMixtureBeliefUpdater` instance whose
                ``update(means, covariances, weights, action, observation)``
                method returns ``(new_means, new_covariances, new_weights)``.
            n_terminal_check_samples: Number of Monte Carlo samples drawn for
                terminal state checks. Defaults to 50.

        Raises:
            ValueError: If inputs are inconsistent (mismatched counts,
                dimensions, or invalid weights).
        """
        # Canonicalize to float ndarrays first so validation sees uniform
        # shapes/dtypes regardless of what the caller passed in.
        means = [np.asarray(m, dtype=float) for m in means]
        covariances = [np.asarray(c, dtype=float) for c in covariances]
        weights = np.asarray(weights, dtype=float)
        self._validate_inputs(means, covariances, weights)

        self.means = means
        self.covariances = covariances
        self.weights = weights
        self.updater = updater
        self.n_terminal_check_samples = n_terminal_check_samples
        # Pre-build one covariance-parameterized sampler per component;
        # sampling / log-pdf evaluation then only needs the per-call mean.
        self._mvns = [
            CovarianceParameterizedMultivariateNormal(c) for c in covariances
        ]

    @staticmethod
    def _validate_inputs(
        means: List[np.ndarray],
        covariances: List[np.ndarray],
        weights: np.ndarray,
    ) -> None:
        """Raise ValueError unless means/covariances/weights form a valid GMM."""
        if len(means) == 0:
            raise ValueError("Must have at least one component")
        if len(means) != len(covariances):
            raise ValueError(
                f"Number of means ({len(means)}) does not match "
                f"number of covariances ({len(covariances)})"
            )
        if weights.ndim != 1 or len(weights) != len(means):
            raise ValueError(
                f"weights must be a 1D array of length {len(means)}, "
                f"got shape {weights.shape}"
            )
        if not np.isclose(weights.sum(), 1.0):
            raise ValueError(f"weights must sum to 1, got {weights.sum()}")
        if np.any(weights < 0):
            raise ValueError("weights must be non-negative")
        # All components must share the dimensionality of the first mean.
        d = len(means[0])
        for i, (m, c) in enumerate(zip(means, covariances)):
            if m.ndim != 1:
                raise ValueError(f"means[{i}] must be 1D, got {m.ndim}D")
            if len(m) != d:
                raise ValueError(f"means[{i}] has length {len(m)}, expected {d}")
            if c.ndim != 2 or c.shape != (d, d):
                raise ValueError(
                    f"covariances[{i}] must have shape ({d}, {d}), "
                    f"got {c.shape}"
                )

    def sample(self) -> np.ndarray:
        """Sample a state from the Gaussian mixture belief.

        Selects a component according to the mixture weights, then draws a
        sample from that component's Gaussian distribution.

        Returns:
            A state vector of shape (d,).
        """
        k = np.random.choice(len(self.weights), p=self.weights)
        return self._mvns[k].sample(self.means[k], n_samples=1)[0]

    def update(
        self,
        action: Any,
        observation: Any,
        pomdp: Optional[Environment] = None,
        state: Optional[Any] = None,
    ) -> "GaussianMixtureBelief":
        """Update belief using the provided updater.

        Args:
            action: Action that was executed.
            observation: Observation that was received.
            pomdp: Unused. Kept for interface compatibility with
                :class:`~POMDPPlanners.core.belief.base_belief.Belief`.
            state: Ignored for Gaussian mixture beliefs.

        Returns:
            New GaussianMixtureBelief with updated components and weights.
        """
        new_means, new_covs, new_weights = self.updater.update(
            self.means, self.covariances, self.weights, action, observation
        )
        return GaussianMixtureBelief(
            means=new_means,
            covariances=new_covs,
            weights=new_weights,
            updater=self.updater,
            n_terminal_check_samples=self.n_terminal_check_samples,
        )

    @property
    def config_id(self) -> str:
        """Generate a deterministic identifier based on belief configuration."""
        # Sort components so the id is invariant to component ordering.
        sorted_data = sorted(
            zip(
                [m.tolist() for m in self.means],
                [c.tolist() for c in self.covariances],
                self.weights.tolist(),
            )
        )
        config_dict = {
            "components": sorted_data,
            "n_terminal_check_samples": self.n_terminal_check_samples,
            "updater": self.updater.config_id,
        }
        return config_to_id(config_dict)

    @property
    def dim(self) -> int:
        """Return the dimensionality of the belief state."""
        return len(self.means[0])

    @property
    def n_components(self) -> int:
        """Return the number of mixture components."""
        return len(self.means)

    def entropy(self, n_samples: int = 1000) -> float:
        """Estimate the differential entropy via Monte Carlo sampling.

        There is no closed-form expression for the entropy of a Gaussian
        mixture, so this method uses the approximation:

            H ~ -mean(log p(x_i)),  x_i ~ p(x)

        Args:
            n_samples: Number of Monte Carlo samples. Defaults to 1000.

        Returns:
            Estimated differential entropy in nats.
        """
        samples = np.array([self.sample() for _ in range(n_samples)])
        log_pdf_values = self._log_pdf(samples)
        return -float(np.mean(log_pdf_values))

    def _log_pdf(self, values: np.ndarray) -> np.ndarray:
        """Evaluate the mixture log-density at each row of ``values``.

        Args:
            values: Array of shape (n, d) (or (d,), promoted to (1, d)).

        Returns:
            Array of shape (n,) with log p(x) for each input point.
        """
        values = np.atleast_2d(values)
        n = values.shape[0]
        k = self.n_components
        log_components = np.empty((n, k))
        # Validation allows zero weights, and np.log(0) would emit a
        # divide-by-zero RuntimeWarning. Suppress it: the resulting -inf is
        # handled correctly by logsumexp (a zero-weight component simply
        # contributes nothing). Computing log-weights once also hoists the
        # loop-invariant np.log out of the per-component loop.
        with np.errstate(divide="ignore"):
            log_weights = np.log(self.weights)
        for j in range(k):
            log_components[:, j] = log_weights[j] + self._mvns[j].log_pdf(
                values, self.means[j]
            )
        return np.asarray(logsumexp(log_components, axis=1))