Source code for syntropy.discrete.utils

# import pickle
import numpy as np
import itertools
from typing import Any


def make_powerset(iterable):
    """
    Lazily enumerate every subset of ``iterable``, ordered by subset size.

    make_powerset([1, 2, 3]) --> () (1,) (2,) (3,) (1, 2) (1, 3) (2, 3) (1, 2, 3)

    Parameters
    ----------
    iterable : iterable
        The elements to build subsets from. It is consumed once, up front.

    Returns
    -------
    itertools.chain
        An iterator (not a list) over the subsets, each given as a tuple.
    """
    pool: tuple = tuple(iterable)
    # One combinations() iterator per subset size, from 0 (the empty subset)
    # up to len(pool) (the full set).
    subsets_by_size = (
        itertools.combinations(pool, size) for size in range(len(pool) + 1)
    )
    return itertools.chain.from_iterable(subsets_by_size)
def flatten_nested_tuple(x: tuple[tuple[Any, ...], ...]) -> tuple[Any, ...]:
    """
    Flatten a tuple of tuples by one level.

    Parameters
    ----------
    x : tuple[tuple[Any, ...], ...]
        The nested tuple. Only the outermost level of nesting is removed.

    Returns
    -------
    tuple[Any, ...]
        The elements of each inner tuple, concatenated in order.
    """
    return tuple(element for inner in x for element in inner)
def clean_distribution(joint_distribution: dict[tuple, float]) -> dict:
    """
    Drop every state whose probability is not strictly positive.

    Parameters
    ----------
    joint_distribution : dict[tuple, float]
        The joint probability distribution.
        Keys are tuples corresponding to the state of each element.
        The values are the probabilities.

    Returns
    -------
    dict
        A new dictionary holding only the states with probability > 0.
        The input dictionary is not modified.
    """
    support: dict = {}
    for state, prob in joint_distribution.items():
        if prob > 0.0:
            support[state] = prob
    return support
def reduce_state(state: tuple, source: tuple) -> tuple:
    """
    Project a state onto the elements selected by ``source``.

    Parameters
    ----------
    state : tuple
        The particular state of each variable.
    source : tuple
        The indices of the variables to keep, in the order they should
        appear in the result.

    Returns
    -------
    tuple
        The reduced state: ``state[i]`` for each ``i`` in ``source``.
    """
    kept = [state[index] for index in source]
    return tuple(kept)
def construct_joint_distribution(data: np.ndarray) -> dict[tuple, float]:
    """
    Given a channels x time, discrete Numpy array,
    computes the probability distribution that describes the data.

    Parameters
    ----------
    data : np.ndarray
        The data: assumed to be in elements x time format.
        Must be two-dimensional and must not have a floating-point dtype.

    Returns
    -------
    dict
        The joint probability distribution.
        Keys are tuples corresponding to the state of each element
        (one entry per row of ``data``).
        The values are the empirical probabilities, i.e. the fraction of
        columns in which that state occurs.

    Raises
    ------
    AssertionError
        If ``data`` has any floating-point dtype.
    ValueError
        If ``data`` is not two-dimensional.
    """
    # Reject every floating dtype (float16/32/64/...), not just float64.
    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; AssertionError is kept so callers see the same type.
    if np.issubdtype(data.dtype, np.floating):
        raise AssertionError("The array must be discrete-valued variables.")
    if data.ndim != 2:
        raise ValueError(
            f"data must be a 2-D (elements x time) array, got {data.ndim} dimension(s)."
        )

    unq: np.ndarray
    counts: np.ndarray
    # Unique columns: each column is the joint state at one time point.
    unq, counts = np.unique(data, return_counts=True, axis=-1)

    total = counts.sum()  # hoisted: identical for every state
    return {tuple(state): count / total for state, count in zip(unq.T, counts)}
def get_marginal_distribution(
    idxs: tuple, joint_distribution: dict[tuple, float]
) -> dict:
    """
    Marginal distribution over the variables selected by ``idxs``.

    Every joint state is projected onto ``idxs`` and the probabilities of
    states that share a projection are summed.
    The opposite of the marginalize_out() function.

    Parameters
    ----------
    idxs : tuple
        The indices of the variables to retain.
    joint_distribution : dict[tuple, float]
        The joint probability distribution.
        Keys are tuples corresponding to the state of each element.
        The values are the probabilities.

    Returns
    -------
    dict
        The marginal distribution, keyed by the reduced states.
    """
    marginal: dict = {}
    for full_state, mass in joint_distribution.items():
        projected = reduce_state(full_state, idxs)
        accumulated = marginal.get(projected, 0.0)
        marginal[projected] = accumulated + mass
    return marginal
def marginalize_out(idxs: tuple, joint_distribution: dict[tuple, float]) -> dict:
    """
    Returns a distribution with the variables indexed by idxs
    marginalized out.

    Parameters
    ----------
    idxs : tuple
        The indices of the variables to be marginalized out.
    joint_distribution : dict[tuple, float]
        The joint probability distribution.
        Keys are tuples corresponding to the state of each element.
        The values are the probabilities.

    Returns
    -------
    dict
        A joint probability distribution dictionary over the remaining
        variables. An empty input distribution yields an empty dictionary.
    """
    # An empty distribution has no states to read the variable count from;
    # its marginal is empty too, matching get_marginal_distribution.
    if not joint_distribution:
        return {}

    # The number of variables is the length of any state; peek at the first
    # key without materialising the whole key list.
    n_vars: int = len(next(iter(joint_distribution)))
    dropped = set(idxs)
    residuals: tuple = tuple(i for i in range(n_vars) if i not in dropped)
    return get_marginal_distribution(residuals, joint_distribution)
def get_all_marginal_distributions(
    joint_distribution: dict[tuple, float],
) -> dict[tuple, dict]:
    """
    Computes the set of all marginal probability distributions.

    If the original distribution has variables:

    :math:`P(X_1, X_2, X_3)`

    Returns a dictionary of dictionaries for each:

    :math:`P(X_1,), P(X_2,), P(X_3,), P(X_1, X_2), P(X_1, X_3), P(X_2, X_3), P(X_1,X_2,X_3)`

    Parameters
    ----------
    joint_distribution : dict[tuple, float]
        The joint probability distribution.
        Keys are tuples corresponding to the state of each element.
        The values are the probabilities.

    Returns
    -------
    dict[tuple, dict]
        A dictionary of dictionaries: each key is a non-empty tuple of
        variable indices, each value is the associated marginal
        distribution. An empty input distribution yields an empty dictionary.
    """
    # An empty distribution has no states to read the variable count from.
    if not joint_distribution:
        return {}

    # The number of variables is the length of any state; peek at the first
    # key without materialising the whole key list.
    n_vars: int = len(next(iter(joint_distribution)))
    return {
        source: get_marginal_distribution(source, joint_distribution)
        for source in make_powerset(range(n_vars))
        if source  # skip the empty subset: it indexes no variables
    }
def product_distribution(
    A: dict[tuple[int, ...], float], B: dict[tuple[int, ...], float]
) -> dict[tuple[int, ...], float]:
    """
    Compute the product of two independent distributions.

    Each pair of states (one from ``A``, one from ``B``) is concatenated
    into a single joint state whose probability is the product of the two
    individual probabilities.

    Parameters
    ----------
    A : dict[tuple[int, ...], float]
        The first distribution
    B : dict[tuple[int, ...], float]
        The second distribution

    Returns
    -------
    dict[tuple[int, ...], float]
        Joint distribution over concatenated states
    """
    # itertools.product iterates A in the outer position and B in the inner
    # one, so the joint states come out in the same order as a nested loop.
    state_pairs = itertools.product(A.items(), B.items())
    return {
        left_state + right_state: left_prob * right_prob
        for (left_state, left_prob), (right_state, right_prob) in state_pairs
    }
def generate_closed_distribution(N: int, seed: int = None) -> dict[tuple[int, ...], float]:
    """
    Generate a random closed discrete probability distribution on N binary elements.

    A distribution is closed iff H(X_i | X^{-i}) = 0 for all i, meaning every
    variable is fully determined by the others. This requires the support to
    have minimum Hamming distance >= 2.

    The support is built greedily: all 2^N binary states are visited in a
    random order and a state is kept only if it differs from every state
    already kept in at least two positions. The kept states then receive
    random weights normalised to sum to one.

    Parameters
    ----------
    N : int
        Number of binary elements.
    seed : int, optional
        Random seed for reproducibility.

    Returns
    -------
    dict[tuple[int, ...], float]
        Probability distribution as {state: probability} mapping.
    """
    rng = np.random.default_rng(seed)

    # Every one of the 2^N binary states, visited in a random order.
    shuffled_states = list(itertools.product([0, 1], repeat=N))
    rng.shuffle(shuffled_states)

    chosen: list = []
    for candidate in shuffled_states:
        # Keep the candidate only if it is at Hamming distance >= 2 from
        # every state already in the support.
        far_from_all = all(
            sum(x != y for x, y in zip(candidate, member)) >= 2
            for member in chosen
        )
        if far_from_all:
            chosen.append(candidate)

    # Random positive weights, normalised into probabilities.
    probs = rng.random(len(chosen))
    probs /= probs.sum()
    return dict(zip(chosen, probs.tolist()))