Source code for syntropy.neural.multivariate_mi
import torch
from .shannon import differential_entropy
# %%
[docs]
def total_correlation(
idxs: tuple[int],
data: torch.Tensor,
data_test: None | torch.Tensor = None,
flow_kwargs: dict = None,
train_kwargs: dict = None,
verbose: bool = False,
) -> tuple[torch.Tensor, float]:
"""
Computes the total correlation of the data using normalizing flow estimators.
Parameters
----------
idxs : tuple[int, ...]
The tuple of indices the differential entropy is computed for.
data : torch.Tensor
The training data, in samples x features format.
data_test : None | torch.Tensor
If not None, the testing data in samples x features format.
flow_kwargs : dict
Arguments for the utils.initalize_flow function.
train_kwargs : dict
Arguments for the utils.train_flow function.
verbose : bool
Whether to print the training progress.
Returns
-------
torch.Tensor
float
"""
flow_kwargs = flow_kwargs or {}
train_kwargs = train_kwargs or {}
lookup: dict[tuple[int, ...], torch.Tensor] = {}
h: torch.Tensor
h, _, _ = differential_entropy(
idxs=idxs,
data=data,
data_test=data_test,
verbose=verbose,
train_kwargs=train_kwargs,
flow_kwargs=flow_kwargs,
)
lookup[idxs] = h
h_i: torch.Tensor
for i in idxs:
h_i, _, _ = differential_entropy(
idxs=(i,),
data=data,
data_test=data_test,
verbose=verbose,
flow_kwargs=flow_kwargs,
train_kwargs=train_kwargs,
)
lookup[(i,)] = h_i
tc: torch.Tensor = torch.zeros(data.shape[0])
for key in lookup.keys():
if len(key) == 1:
tc += lookup[key]
else:
tc -= lookup[key]
return tc, tc.mean()
[docs]
def higher_order_information(
idxs: tuple[int],
data: torch.Tensor,
data_test: None | torch.Tensor = None,
flow_kwargs: dict = None,
train_kwargs: dict = None,
verbose: bool = False,
) -> dict[str, float]:
"""
Computes the O-information, S-information, total correlation, and dual total correlation for the data.
Computing them all as a set is more efficient than computing each one independently.
Parameters
----------
idxs : tuple[int, ...]
The tuple of indices the differential entropy is computed for.
data : torch.Tensor
The training data, in samples x features format.
context : None | tuple[int]
If not None, the indices of the conditioning variables.
data_test : None | torch.Tensor
If not None, the testing data in samples x features format.
flow_kwargs : dict
Arguments for the utils.initalize_flow function.
train_kwargs : dict
Arguments for the utils.train_flow function.
verbose : bool
Whether to print the training progress.
Returns
-------
dict[str, dict[str, float | torch.Tensor]]
"""
flow_kwargs = flow_kwargs or {}
train_kwargs = train_kwargs or {}
lookup_marginals: dict[tuple[int, ...], torch.Tensor] = {}
lookup_residuals: dict[tuple[int, ...], torch.Tensor] = {}
N: int = len(idxs)
h, _, _ = differential_entropy(
idxs=idxs,
data=data,
data_test=data_test,
verbose=verbose,
train_kwargs=train_kwargs,
flow_kwargs=flow_kwargs,
)
for i in idxs:
h_i, _, _ = differential_entropy(
idxs=(i,),
data=data,
data_test=data_test,
verbose=verbose,
flow_kwargs=flow_kwargs,
train_kwargs=train_kwargs,
)
lookup_marginals[(i,)] = h_i
idxs_r = tuple(j for j in idxs if j != i)
h_r, _, _ = differential_entropy(
idxs=idxs_r,
data=data,
data_test=data_test,
verbose=verbose,
flow_kwargs=flow_kwargs,
train_kwargs=train_kwargs,
)
lookup_residuals[idxs_r] = h_r
tc: torch.Tensor = -h
for key in lookup_marginals.keys():
tc += lookup_marginals[key]
residual_tcs: torch.Tensor = torch.zeros(data.shape[0])
for key in lookup_residuals.keys():
tc_r = -lookup_residuals[key]
for i in key:
tc_r += lookup_marginals[(i,)]
residual_tcs += tc_r
oinfo: torch.Tensor = ((2 - N) * tc) + residual_tcs
dtc: torch.Tensor = ((N - 1) * tc) - residual_tcs
sinfo: torch.Tensor = (N * tc) - residual_tcs
results: dict[str, dict[str, float | torch.Tensor]] = {
"sinfo": {"ptw": sinfo, "avg": sinfo.mean()},
"dtc": {"ptw": dtc, "avg": dtc.mean()},
"oinfo": {"ptw": oinfo, "avg": oinfo.mean()},
"tc": {"ptw": tc, "avg": tc.mean()},
}
return results