import numpy as np
import networkx as nx
from typing import Callable
from .utils import make_powerset, reduce_state
from ..lattices import LATTICE_2, LATTICE_3, LATTICE_4
BOTTOM_2: tuple = ((0,), (1,))
BOTTOM_3: tuple = ((0,), (1,), (2,))
BOTTOM_4: tuple = ((0,), (1,), (2,), (3,))
PATHS_2: dict = nx.shortest_paths.shortest_path_length(
LATTICE_2, source=None, target=BOTTOM_2
)
LAYERS_2: dict = {
val: {key for key in PATHS_2.keys() if PATHS_2[key] == val}
for val in set(PATHS_2.values())
}
PATHS_3: dict = nx.shortest_paths.shortest_path_length(
LATTICE_3, source=None, target=BOTTOM_3
)
LAYERS_3: dict = {
val: {key for key in PATHS_3.keys() if PATHS_3[key] == val}
for val in set(PATHS_3.values())
}
PATHS_4: dict = nx.shortest_paths.shortest_path_length(
LATTICE_4, source=None, target=BOTTOM_4
)
LAYERS_4: dict = {
val: {key for key in PATHS_4.keys() if PATHS_4[key] == val}
for val in set(PATHS_4.values())
}
def local_precompute_sources(joint_distribution: dict[tuple, float]) -> dict:
"""
A utility function that computes the local entropy of each subset of
elements. This speeds up the computation using the hmin function
considerably,
Parameters
----------
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
Returns
-------
dict
For each state, the joint entropy of every subset.
"""
N: int = len(list(joint_distribution.keys())[0])
sources: list = list(make_powerset(range(N)))
sources.remove(())
local_entropies: dict = {
state: {source: {} for source in sources} for state in joint_distribution.keys()
}
state: tuple
for state in local_entropies.keys():
source: tuple
for source in local_entropies[state]:
# For a given state x, and a given source s, computes the total probability mass of the joint
# distribution consistent with the state of the source.
# E.g. if state = (0,0,0) and source = (0,1), probability_mass = P(X0=0 AND X1=0)
probability_mass = sum(
[
joint_distribution[key]
for key in joint_distribution.keys()
if reduce_state(key, source) == reduce_state(state, source)
]
)
if probability_mass > 0:
local_entropies[state][source] = -np.log2(probability_mass)
else:
local_entropies[state][
source
] = 0 # Set log2(0) to 0 since impossible events contain no information.
return local_entropies
def hmin_discrete_redundancy(
atom: tuple, state: tuple, sources: dict, joint_distribution: dict[tuple, float]
) -> float:
"""
For a collection of sources :math:`\\alpha=\\{a_1, a_2, \\ldots, a_k\\}`, computes
the redundnat entropy shared by all sources as:
:math:`h_{\\cap}^{min}(\\alpha) = \\min_{i}h(a_i)`
See:
Finn, C., & Lizier, J. T. (2020).
Generalised Measures of Multivariate Information Content.
Entropy, 22(2), Article 2.
https://doi.org/10.3390/e22020216
Parameters
----------
atom : tuple
The partial entropy atom.
state : tuple
The state of the system.
sources : dict
The pre-computed local entropies constructed by the
precompute_local_entropies() function.
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
Returns
-------
float
The redundant entropy shared by every source in the atom.
"""
if joint_distribution[state] == 0:
return 0
else:
return min(sources[state][source] for source in atom)
def imin_discrete_redundancy(
atom: tuple,
state: tuple,
inputs: tuple,
target: tuple,
sources: dict,
joint_distribution: dict[tuple, float],
) -> float:
"""
For a collection of sources :math:`\\alpha = \\{a_{1}, a_{2}, \\ldots, a_{k}\\}` and a target :math:`t` the redundancy is defined as:
:math:`i_{min}(\\alpha;t) = \\min_{i}h(a_i) - \\min_{i}h(a_i|t)`
See:
Finn, C., & Lizier, J. T. (2020).
Generalised Measures of Multivariate Information Content.
Entropy, 22(2), Article 2.
https://doi.org/10.3390/e22020216
Parameters
----------
atom : tuple
The partial information atom.
state : tuple
The joint-state of the systems.
inputs : tuple
The indices of the input variables.
target : tuple
The indices of the target variable. May be multivariate.
sources : dict
The pre-computed local entropies
for the joint state of the source and the target.
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
Returns
-------
float
The local redundant mutual information.
"""
i_plus: float = np.inf
i_minus: float = np.inf
h_target: float = sources[state][target]
for source in atom:
input_source: tuple = tuple(inputs[x] for x in source)
h_input_source: float = sources[state][input_source]
if i_plus > h_input_source:
i_plus = h_input_source
joint_source: tuple = tuple(sorted(input_source + target))
h_conditional: float = sources[state][joint_source] - h_target
if i_minus > h_conditional:
i_minus = h_conditional
return i_plus - i_minus
def hsx_discrete_redundancy(
atom: tuple, state: tuple, joint_distribution: dict[tuple, float]
) -> float:
"""
Computes the redundant entropy shared by a set of sources using the :math:`h_{sx}` function.
For a collection of sources :math:`\\alpha = \\{a_{1}, a_{2}, \\ldots, a_{k}\\}`,
the redundancy is defined as
:math:`h^{sx}_{\cap}(\\alpha) = -\\log_{2} P(a_{1} \\cup a_{2} \\cup \\ldots \\cup a_{k})`.
See:
Varley, T. F., Pope, M., Maria Grazia, P., Joshua, F., & Sporns, O. (2023).
Partial entropy decomposition reveals higher-order information structures in human brain activity.
Proceedings of the National Academy of Sciences, 120(30), e2300888120.
https://doi.org/10.1073/pnas.2300888120
Parameters
----------
atom : tuple
The partial entropy atom.
state : tuple
The state of the system.
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
Returns
-------
float
The redundant entropy shared by every source in the atom..
"""
if joint_distribution[state] == 0:
return 0.0
else:
state_set: set = set()
source: tuple
for source in atom:
state_source: tuple = reduce_state(state, source)
state_set.update(
{
key
for key in joint_distribution.keys()
if reduce_state(key, source) == state_source
}
)
redundant_entropy: float = -np.log2(
sum(joint_distribution[s] for s in state_set)
)
return redundant_entropy
def isx_discrete_redundancy(
atom: tuple,
state: tuple,
inputs: tuple,
target: tuple,
joint_distribution: dict[tuple, float],
) -> float:
"""
For a collection of sources :math:`\\alpha = \\{a_{1}, a_{2}, \\ldots, a_{k}\\}` and a target :math:`t` the redundancy is defined as:
:math:`i_{sx}(\\alpha;t) = \\log_{2}\\frac{P(t) - P(t \\cap (\\bar{a}_1 \\cap \\ldots \\cap \\bar{a}_k)}{1 - P(\\bar{a}_1 \\cap \\ldots \\cap \\bar{a}_{k})} - \\log_2 P(t)`
Parameters
----------
atom : tuple
The partial information atom.
state : tuple
The joint-state of the systems.
inputs : tuple
The indices of the input variables.
target : tuple
The indices of the target variable. May be multivariate.
sources : dict
The pre-computed local entropies
for the joint state of the source and the target.
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
Returns
-------
float
The redundant mutual informations.
"""
# i_sx^+
# -log2(1 / a1 U a2 U ... U ak)
state_set: set = set()
source: tuple
for source in atom:
input_source = tuple(inputs[x] for x in source)
state_source: tuple = reduce_state(state, input_source)
state_set.update(
{
key
for key in joint_distribution.keys()
if reduce_state(key, input_source) == state_source
}
)
i_plus = np.log2(1 / sum(joint_distribution[x] for x in state_set))
# i_sx^-
# -log2( P(t) / t ^ (a1 U a2 U ... U ak))
target_state: tuple = reduce_state(state, target)
p_target: float = sum(
[
joint_distribution[key]
for key in joint_distribution.keys()
if reduce_state(key, target) == target_state
]
)
state_set: set = set()
for source in atom:
input_source = tuple(inputs[x] for x in source)
state_source: tuple = reduce_state(state, input_source)
state_set.update(
{
key
for key in joint_distribution.keys()
if reduce_state(key, input_source) == state_source
}
)
state_set = {key for key in state_set if reduce_state(key, target) == target_state}
i_minus = np.log2(p_target / sum([joint_distribution[x] for x in state_set]))
return i_plus - i_minus
def mobius_inversion(
decomposition: str,
joint_distribution: dict[tuple, float],
redundancy: str,
inputs: tuple = (None,),
target: tuple = (None,),
) -> tuple[dict, dict]:
"""
Computes the Mobius inversion on the antichain lattice.
Parameters
----------
decomposition : str
Which decomposition to use:
For partial information decomposition use "pid".
For partial entropy decomposition use "ped".
joint_distribution: dict[tuple, float]
The joint probability distribution.
Keys are tuples corresponding to the state of each element.
The valules are the probabilities.
redundancy : str
The redundancy function.
inputs : tuple, optional
The indicies of the inputs. The default is (None,)
target : tuple, optional
The (potentially multivariate) indices of the target.
The default is (None,).
Returns
-------
(dict, dict)
The pointwise and average partial information atom dictionaries.
"""
decomposition_lower = decomposition.lower()
assert decomposition_lower in {
"pid",
"ped",
}, "You must specify a decomposition: PID or PED."
if decomposition_lower == "pid":
assert redundancy in {
"isx",
"imin",
}, "The implemented redundancy functions are 'imin' and 'hsx'."
assert target != (None,), "You must specify a target."
if redundancy == "isx":
redundancy_function = isx_discrete_redundancy
elif redundancy == "imin":
redundancy_function = imin_discrete_redundancy
total_str: str = "total_information"
partial_str: str = "pi"
elif decomposition_lower == "ped":
assert redundancy in {
"hsx",
"hmin",
}, "The implemented redundancy functions are 'hxs' and 'hmin'."
if redundancy == "hsx":
redundancy_function = hsx_discrete_redundancy
elif redundancy == "hmin":
redundancy_function = hmin_discrete_redundancy
total_str: str = "total_entropy"
partial_str: str = "pe"
if inputs == (None,): # If no inputs are specified, all elements are inputs
N = {len(x) for x in joint_distribution.keys()}.pop()
else:
N = len(inputs)
if N == 2:
lattice, bottom, layers = LATTICE_2, BOTTOM_2, LAYERS_2
elif N == 3: # |
lattice, bottom, layers = LATTICE_3, BOTTOM_3, LAYERS_3
else: # |
lattice, bottom, layers = LATTICE_4, BOTTOM_4, LAYERS_4
# Pre-computing all the local h_{\partial}(source) values
# this saves a lot of time.
if "min" in redundancy:
sources: dict = local_precompute_sources(joint_distribution)
ptw: dict = {state: dict() for state in joint_distribution.keys()}
for state in joint_distribution.keys():
for layer in layers:
for atom in layers[layer]:
if decomposition_lower == "pid":
if redundancy == "imin":
args = {
"atom": atom,
"state": state,
"inputs": inputs,
"target": target,
"sources": sources,
"joint_distribution": joint_distribution,
}
elif redundancy == "isx":
args = {
"atom": atom,
"state": state,
"inputs": inputs,
"target": target,
"joint_distribution": joint_distribution,
}
lattice.nodes[atom][total_str] = redundancy_function(**args)
elif decomposition_lower == "ped":
if redundancy == "hmin":
args = {
"atom": atom,
"state": state,
"sources": sources,
"joint_distribution": joint_distribution,
}
elif redundancy == "hsx":
args = {
"atom": atom,
"state": state,
"joint_distribution": joint_distribution,
}
lattice.nodes[atom][total_str] = redundancy_function(**args)
if atom == bottom:
lattice.nodes[atom][partial_str] = lattice.nodes[atom][total_str]
else:
lattice.nodes[atom][partial_str] = lattice.nodes[atom][
total_str
] - sum(
[
lattice.nodes[d][partial_str]
for d in lattice.nodes[atom]["descendants"]
]
)
local_ptw: dict = {
node: lattice.nodes[node][partial_str] for node in lattice.nodes
}
ptw[state] = local_ptw
avg = {}
for node in lattice.nodes:
avg[node] = sum(
[joint_distribution[state] * ptw[state][node] for state in ptw.keys()]
)
return ptw, avg
[docs]
def partial_entropy_decomposition(
redundancy: str, joint_distribution: dict[tuple[int, ...], float]
) -> (dict, dict):
"""
Computes the partial entropy decomposition of a joint distribution
with up to four elements.
The available redundancy functions are h_min from Finn and Lizier
and h_sx from Varley et al.,
.. math::
h_{\\min}(\\alpha) &= \\min(\\alpha_i) \\\\
h_{sx}(\\alpha) &= \\log\\frac{1}{P(\\alpha_1\\cup ... \\cup\\alpha_N)}
Parameters
----------
redundancy : str
The redundnacy function to use.
"hmin" for the measure from Finn and Lizier,
"hsx" for the measure from Varley et al.,
joint_distribution : dict[tuple,float]
The joint distribution object.
Returns
-------
ptw : dict[tuple, dict]
A dictionary of dictionaries,
The outer dictionary has one key for each joint state.
Each inner dictionary is the lookup of partial entropy atoms.
avg : dict[tuple, float]
The expected value for each partial entropy atom.
References
----------
Finn, C., & Lizier, J. T. (2020).
Generalised Measures of Multivariate Information Content.
Entropy, 22(2), Article 2.
https://doi.org/10.3390/e22020216
Varley, T. F., Pope, M., Maria Grazia, P., Joshua, F., & Sporns, O. (2023).
Partial entropy decomposition reveals higher-order
information structures in human brain activity.
Proceedings of the National Academy of Sciences,
120(30), e2300888120.
https://doi.org/10.1073/pnas.2300888120
"""
ptw, avg = mobius_inversion(
decomposition="ped",
joint_distribution=joint_distribution,
redundancy=redundancy,
)
return ptw, avg
[docs]
def representational_complexity(avg: dict, comparator: Callable = min) -> float:
"""
Computes the representational complexity of a given partial information or entropy lattice.
The representational complexity is a measure of how
much partial information atoms of a given degree of synergy
contribute to the overall mutual information or entropy.
Parameters
----------
avg : dict[tuple, float]
The dictionary of partial information/entropy atoms.
Returned from any of the above functions.
comparator : function, optional
Whether to consider the minimum complexity of an atom.
or the maximum complexity of an atom.
Options are: min, max, np.min, np.max.
The default is min, following the original work
by Ehrlich et al.,.
Returns
-------
float
The representational complexity.
References
----------
Ehrlich, D. A., Schneider, A. C., Priesemann, V., Wibral, M., & Makkeh, A. (2023).
A Measure of the Complexity of Neural Representations
based on Partial Information Decomposition.
Transactions on Machine Learning Research.
https://openreview.net/forum?id=R8TU3pfzFr
"""
assert comparator in (min, max, np.min, np.max), "The comparator must be min or max"
rc: float = 0.0
for atom in avg.keys():
rc += avg[atom] * comparator(len(source) for source in atom)
return rc / sum(avg.values())