Source code for syntropy.discrete.shannon

import numpy as np
import math
from .utils import get_marginal_distribution, flatten_nested_tuple

from typing import Any

# A discrete joint distribution: maps each joint state (a tuple holding one
# entry per variable) to the probability of that state.
DiscreteDist = dict[tuple[Any, ...], float]


def shannon_entropy(joint_distribution: DiscreteDist) -> tuple[dict, float]:
    """
    Compute the Shannon entropy, in bits, of the distribution :math:`P(x)`.

    .. math::

        H(X) = -\\sum_{x\\in\\mathcal{X}} P(x) \\log P(x)

    To get the entropy of only some of the variables, marginalise the joint
    distribution first (for example with
    :func:`syntropy.discrete.utils.get_marginal_distribution`) and pass the
    resulting marginal to this function.

    Parameters
    ----------
    joint_distribution : dict
        The joint probability distribution. Keys are tuples giving the state
        of each element; values are the probabilities of those states.

    Returns
    -------
    ptw : dict
        The pointwise entropy (surprisal) of every state whose probability is
        strictly positive. Other states are left out.
    avg : float
        The average entropy, i.e. the probability-weighted sum of the
        pointwise values.
    """
    surprisal: dict = {}
    for state, prob in joint_distribution.items():
        # Only states carrying probability mass have a defined surprisal.
        if prob > 0.0:
            surprisal[state] = -math.log2(prob)

    expected: float = sum(
        joint_distribution[state] * bits for state, bits in surprisal.items()
    )
    return surprisal, expected
def conditional_entropy(
    idxs_x: tuple[int, ...], idxs_y: tuple[int, ...], joint_distribution: DiscreteDist
) -> tuple[dict, float]:
    """
    Compute the conditional entropy, in bits, of X given Y.

    .. math::

        H(X|Y) = H(X,Y) - H(Y)

    Parameters
    ----------
    idxs_x : tuple
        The indices of the variables whose entropy is computed.
    idxs_y : tuple
        The indices of the variables to condition on.
    joint_distribution : dict
        The joint probability distribution. Keys are tuples giving the state
        of each element; values are the probabilities of those states.

    Returns
    -------
    ptw : dict
        The pointwise conditional entropy. Keys are ``(x_state, y_state)``
        pairs; only joint states with strictly positive probability appear.
    avg : float
        The average conditional entropy.
    """
    split: int = len(idxs_x)
    # NOTE(review): slicing at ``split`` assumes the marginal's state tuples
    # follow the order of the indices passed in (X first, then Y) -- confirm
    # against get_marginal_distribution.
    p_xy: DiscreteDist = get_marginal_distribution(idxs_x + idxs_y, joint_distribution)
    p_y: DiscreteDist = get_marginal_distribution(idxs_y, joint_distribution)

    local: dict = {}
    expected: float = 0
    for state, mass in p_xy.items():
        if not mass > 0.0:
            continue
        x_state, y_state = state[:split], state[split:]
        # -log2 P(x|y), with P(x|y) = P(x,y) / P(y).
        bits: float = -math.log2(mass / p_y[y_state])
        local[(x_state, y_state)] = bits
        expected += mass * bits

    return local, expected
def mutual_information(
    idxs_x: tuple[int, ...],
    idxs_y: tuple[int, ...],
    joint_distribution: DiscreteDist,
) -> tuple[dict, float]:
    """
    Compute the mutual information, in bits, between X and Y.

    .. math::

        I(X;Y) &= H(X) + H(Y) - H(X,Y) \\\\
               &= H(X) - H(X|Y) \\\\
               &= H(Y) - H(Y|X) \\\\
               &= H(X,Y) - H(X|Y) - H(Y|X)

    Parameters
    ----------
    idxs_x : tuple
        The indices of the X variable(s).
    idxs_y : tuple
        The indices of the Y variable(s).
    joint_distribution : dict
        The joint probability distribution. Keys are tuples giving the state
        of each element; values are the probabilities of those states.

    Returns
    -------
    ptw : dict
        The pointwise mutual information. Keys are ``(x_state, y_state)``
        pairs; only joint states with strictly positive probability appear.
    avg : float
        The average mutual information.
    """
    split: int = len(idxs_x)
    # NOTE(review): slicing at ``split`` assumes the marginal's state tuples
    # follow the order of the indices passed in (X first, then Y) -- confirm
    # against get_marginal_distribution.
    p_xy: DiscreteDist = get_marginal_distribution(idxs_x + idxs_y, joint_distribution)
    p_y: DiscreteDist = get_marginal_distribution(idxs_y, joint_distribution)
    p_x: DiscreteDist = get_marginal_distribution(idxs_x, joint_distribution)

    local: dict = {}
    expected: float = 0.0
    for state, mass in p_xy.items():
        if not mass > 0.0:
            continue
        x_state, y_state = state[:split], state[split:]
        px: float = p_x[x_state]
        py: float = p_y[y_state]
        # log2 of the joint probability over the product of the marginals.
        pmi = math.log2(mass / (px * py))
        local[(x_state, y_state)] = pmi
        expected += pmi * mass

    return local, expected
def conditional_mutual_information(
    idxs_x: tuple, idxs_y: tuple, idxs_z: tuple, joint_distribution: DiscreteDist
) -> tuple[dict, float]:
    """
    Compute the mutual information between X and Y conditioned on Z.

    .. math::

        I(X;Y|Z) &= H(X|Z) + H(Y|Z) - H(X,Y|Z) \\\\
                 &= I(X;Y,Z) - I(X;Z)

    The value is assembled from three calls to :func:`conditional_entropy`
    using the first identity above.

    Parameters
    ----------
    idxs_x : tuple
        The indices of the X variable(s).
    idxs_y : tuple
        The indices of the Y variable(s).
    idxs_z : tuple
        The indices of the variables to condition on.
    joint_distribution : dict
        The joint probability distribution. Keys are tuples giving the state
        of each element; values are the probabilities of those states.

    Returns
    -------
    ptw : dict
        The pointwise conditional mutual information. Keys are
        ``(x_state, y_state, z_state)`` triples, one for every entry of the
        pointwise :math:`H(X,Y|Z)` result.
    avg : float
        The average conditional mutual information.
    """
    split: int = len(idxs_x)
    idxs_xy: tuple[int, ...] = idxs_x + idxs_y

    h_x_z, mean_x_z = conditional_entropy(idxs_x, idxs_z, joint_distribution)
    h_y_z, mean_y_z = conditional_entropy(idxs_y, idxs_z, joint_distribution)
    h_xy_z, mean_xy_z = conditional_entropy(idxs_xy, idxs_z, joint_distribution)

    local: dict = {}
    for (xy_state, z_state), h_joint in h_xy_z.items():
        # The joint (X, Y) state is X's entries followed by Y's entries.
        x_state = xy_state[:split]
        y_state = xy_state[split:]
        local[(x_state, y_state, z_state)] = (
            h_x_z[(x_state, z_state)] + h_y_z[(y_state, z_state)] - h_joint
        )

    expected: float = mean_x_z + mean_y_z - mean_xy_z
    return local, expected
def kullback_leibler_divergence(
    posterior_distribution: DiscreteDist, prior_distribution: DiscreteDist
) -> tuple[dict, float]:
    """
    Compute the Kullback-Leibler divergence, in bits, of a posterior
    distribution P(X) from a prior distribution Q(X).

    .. math::

        D_{KL}(P||Q) = \\sum_{x} P(x) \\log \\frac{P(x)}{Q(x)}

    States whose posterior probability is exactly zero contribute nothing to
    the sum (the :math:`0 \\log 0 = 0` convention) and are reported with a
    pointwise value of ``0.0``.

    Parameters
    ----------
    posterior_distribution : dict
        The joint distribution of the posterior distribution P(X).
    prior_distribution : dict
        The joint distribution of the prior distribution Q(X).

    Returns
    -------
    ptw : dict
        The pointwise log-ratio :math:`\\log_2 P(x)/Q(x)` for each state of
        the posterior distribution.
    avg : float
        The average Kullback-Leibler divergence.

    Raises
    ------
    AssertionError
        If the posterior contains a state that is not a key of the prior.
    ZeroDivisionError
        If the prior assigns probability zero to a state whose posterior
        probability is non-zero.
    """
    # Raised explicitly (not via ``assert``) so the check is not stripped
    # under ``python -O``; the exception type is kept for existing callers.
    if not set(prior_distribution.keys()).issuperset(posterior_distribution.keys()):
        raise AssertionError(
            "The support set of the prior must be a superset of the posterior"
        )

    avg: float = 0.0
    ptw: dict = {}
    for state, p_post in posterior_distribution.items():
        if p_post == 0.0:
            # log2(0) is undefined; a zero-mass state adds 0 to the divergence.
            ptw[state] = 0.0
            continue
        log_ratio: float = math.log2(p_post / prior_distribution[state])
        ptw[state] = log_ratio
        avg += p_post * log_ratio

    return ptw, avg