sakura.utils.kl_divergence.KLDivergence.kl_divergence
- KLDivergence.kl_divergence(encoded_samples, distribution_fn, target, device='cpu')
Compute the KL divergence between encoded samples and samples drawn from a distribution function.
- Parameters:
encoded_samples (torch.Tensor) – Samples from the encoded distribution
distribution_fn (Callable) – Function that draws samples from the distribution to compare against; called with arguments (batch_size, n_dim)
device (Literal['cpu', 'cuda'], optional) – Torch device used for the computation, defaults to ‘cpu’
- Returns:
KL divergence between the encoded and drawn distributions
- Return type:
torch.Tensor
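This page does not say how the divergence is estimated from samples. As a hedged illustration only, the sketch below assumes one common approach: fit an independent (diagonal) Gaussian to each sample set and sum the closed-form per-dimension Gaussian KL. This is not confirmed to be what sakura does. It uses plain Python lists in place of torch tensors, ignores the `target` and `device` parameters, and `standard_normal` and `gaussian_kl_from_samples` are hypothetical names. `standard_normal` only mimics the documented `distribution_fn` contract of being called with `(batch_size, n_dim)`.

```python
import math
import random
from typing import Callable, List, Tuple

# batch_size rows, each a list of n_dim floats (stand-in for a torch.Tensor)
Samples = List[List[float]]


def standard_normal(batch_size: int, n_dim: int) -> Samples:
    """Hypothetical distribution_fn: draws batch_size samples from N(0, I)."""
    return [[random.gauss(0.0, 1.0) for _ in range(n_dim)] for _ in range(batch_size)]


def _mean_var(samples: Samples, dim: int) -> Tuple[float, float]:
    """Sample mean and (biased) variance of one column of the samples."""
    column = [row[dim] for row in samples]
    mean = sum(column) / len(column)
    var = sum((x - mean) ** 2 for x in column) / len(column)
    return mean, var


def gaussian_kl_from_samples(
    encoded_samples: Samples,
    distribution_fn: Callable[[int, int], Samples],
    eps: float = 1e-8,
) -> float:
    """Assumed estimator: KL(P || Q) with P fitted to encoded_samples and
    Q fitted to samples from distribution_fn, both as diagonal Gaussians."""
    batch_size = len(encoded_samples)
    n_dim = len(encoded_samples[0])
    # Draw a matching batch from the comparison distribution
    drawn = distribution_fn(batch_size, n_dim)
    kl = 0.0
    for d in range(n_dim):
        mu_p, var_p = _mean_var(encoded_samples, d)
        mu_q, var_q = _mean_var(drawn, d)
        # eps keeps the log and division finite for degenerate columns
        var_p += eps
        var_q += eps
        # Closed-form KL between 1-D Gaussians, summed over dimensions
        kl += 0.5 * (math.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)
    return kl


if __name__ == "__main__":
    random.seed(0)
    shifted = [[random.gauss(3.0, 1.0) for _ in range(2)] for _ in range(5000)]
    print(gaussian_kl_from_samples(shifted, standard_normal))
```

With encoded samples from N(3, 1) in two dimensions and `standard_normal` as the comparison distribution, the estimate should land near 9 (4.5 per dimension, since each term is 0.5 * (0 + (1 + 9) / 1 - 1)), and near 0 when both sample sets are standard normal. Because the Gaussian fit ignores higher moments, this estimator would understate the divergence for strongly non-Gaussian encodings.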