Source code for syntropy.neural.multivariate_mi
import scipy.stats as stats
import torch
from .shannon import differential_entropy
[docs]
def total_correlation(
    idxs: tuple[int, ...],
    data: torch.Tensor,
    context: None | tuple[int, ...] = None,
    data_test: None | torch.Tensor = None,
    flow_kwargs: dict | None = None,
    train_kwargs: dict | None = None,
    verbose: bool = False,
) -> float:
    """
    Computes the total correlation of the data using normalizing flow estimators.

    The total correlation is the sum of the single-variable entropies minus
    the joint entropy of all the variables in ``idxs``. Every entropy is
    obtained from a separate call to ``differential_entropy``; only the first
    element of its return value is used (assumed to be the entropy estimate,
    as in ``higher_order_information`` -- confirm against ``.shannon``).

    Parameters
    ----------
    idxs : tuple[int, ...]
        The tuple of indices the total correlation is computed for.
    data : torch.Tensor
        The training data, in samples x features format.
    context : None | tuple[int, ...]
        If not None, the indices of the conditioning variables.
    data_test : None | torch.Tensor
        If not None, the testing data in samples x features format.
    flow_kwargs : dict | None
        Arguments for the utils.initalize_flow function.
    train_kwargs : dict | None
        Arguments for the utils.train_flow function.
    verbose : bool
        Whether to print the training progress.

    Returns
    -------
    float
        The sum of the marginal entropies minus the joint entropy.
    """
    flow_kwargs = flow_kwargs or {}
    train_kwargs = train_kwargs or {}

    # Joint entropy of all requested variables. It is kept apart from the
    # marginals so that a single-index ``idxs`` cannot overwrite it.
    h_joint, _ = differential_entropy(
        idxs=idxs,
        data=data,
        context=context,
        data_test=data_test,
        verbose=verbose,
        train_kwargs=train_kwargs,
        flow_kwargs=flow_kwargs,
    )

    # One entropy estimate per individual variable, keyed by its 1-tuple.
    marginals: dict[tuple[int, ...], float] = {}
    for i in idxs:
        h_i, _ = differential_entropy(
            idxs=(i,),
            data=data,
            context=context,
            data_test=data_test,
            verbose=verbose,
            flow_kwargs=flow_kwargs,
            train_kwargs=train_kwargs,
        )
        marginals[(i,)] = h_i

    # The stored values are already scalars (the tuple returned by
    # differential_entropy was unpacked above), so they are summed directly
    # rather than indexed.
    return sum(marginals.values()) - h_joint
[docs]
def higher_order_information(
    idxs: tuple[int],
    data: torch.Tensor,
    context: None | tuple[int] = None,
    data_test: None | torch.Tensor = None,
    flow_kwargs: dict = None,
    train_kwargs: dict = None,
    verbose: bool = False,
) -> dict[str, float]:
    """
    Estimates the O-information, S-information, dual total correlation and
    total correlation of the data in one pass.

    The joint entropy, every single-variable entropy and every
    leave-one-variable-out entropy are each estimated once with
    ``differential_entropy`` and then shared between the four measures,
    which is cheaper than computing each measure independently.

    Parameters
    ----------
    idxs : tuple[int, ...]
        The tuple of indices the measures are computed for.
    data : torch.Tensor
        The training data, in samples x features format.
    context : None | tuple[int]
        If not None, the indices of the conditioning variables.
    data_test : None | torch.Tensor
        If not None, the testing data in samples x features format.
    flow_kwargs : dict
        Arguments for the utils.initalize_flow function.
    train_kwargs : dict
        Arguments for the utils.train_flow function.
    verbose : bool
        Whether to print the training progress.

    Returns
    -------
    dict[str, float]
        A dictionary with the keys ``"o_information"``, ``"s_information"``,
        ``"dual_total_correlation"`` and ``"total_correlation"``.
    """
    flow_kwargs = flow_kwargs or {}
    train_kwargs = train_kwargs or {}

    # Keyword arguments common to every entropy estimate below.
    shared = {
        "data": data,
        "context": context,
        "data_test": data_test,
        "verbose": verbose,
        "flow_kwargs": flow_kwargs,
        "train_kwargs": train_kwargs,
    }
    n_vars = len(idxs)

    # Joint entropy of the full set of variables.
    joint_h, _ = differential_entropy(idxs=idxs, **shared)

    # For each variable: its own entropy, then the entropy of the set with
    # that variable removed.
    marginal_h: dict[tuple[int, ...], float] = {}
    leave_one_out_h: dict[tuple[int, ...], float] = {}
    for held_out in idxs:
        single = (held_out,)
        estimate, _ = differential_entropy(idxs=single, **shared)
        marginal_h[single] = estimate

        remainder = tuple(j for j in idxs if j != held_out)
        estimate, _ = differential_entropy(idxs=remainder, **shared)
        leave_one_out_h[remainder] = estimate

    # Total correlation of the full set: marginals minus joint.
    full_tc = sum(marginal_h.values()) - joint_h

    # Total correlation of each leave-one-out subset, accumulated starting
    # from the negated subset entropy.
    leave_one_out_tc = {
        subset: sum((marginal_h[(j,)] for j in subset), -subset_h)
        for subset, subset_h in leave_one_out_h.items()
    }
    residual_sum = sum(leave_one_out_tc.values())

    # All three derived measures are linear in the full total correlation
    # and the summed leave-one-out total correlations.
    return {
        "o_information": -(((n_vars - 2) * full_tc) - residual_sum),
        "s_information": (n_vars * full_tc) - residual_sum,
        "dual_total_correlation": ((n_vars - 1) * full_tc) - residual_sum,
        "total_correlation": full_tc,
    }