sakura.utils.kl_divergence.KLDivergence.kl_divergence

KLDivergence.kl_divergence(encoded_samples, distribution_fn, target, device='cpu')

Compute the KL divergence between the encoded samples and samples drawn from distribution_fn.

Parameters:
  • encoded_samples (torch.Tensor) – Samples from the encoded distribution

  • distribution_fn (Callable) – Function that draws samples from the comparison distribution; called with the arguments (batch_size, n_dim)

  • device (Literal['cpu', 'cuda'], optional) – Torch computation device; defaults to 'cpu'
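A callable matching the documented (batch_size, n_dim) contract for distribution_fn might look like the following sketch. standard_normal_fn is a hypothetical example, and the sketch assumes encoded_samples is a 2-D [batch_size, n_dim] tensor; the library may expect a different layout.

```python
import torch


def standard_normal_fn(batch_size: int, n_dim: int) -> torch.Tensor:
    # Hypothetical distribution_fn: draws batch_size samples from an
    # n_dim-dimensional standard normal N(0, I).
    return torch.randn(batch_size, n_dim)


# Assumption: encoded samples are laid out as [batch_size, n_dim].
encoded_samples = torch.randn(128, 2)
drawn = standard_normal_fn(*encoded_samples.shape)
print(drawn.shape)
```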

Returns:

KL divergence between the distributions

Return type:

torch.Tensor
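One common way to estimate a KL divergence from two sample sets is to fit a simple parametric model to each and use the closed form. The sketch assumes [N, D] inputs and approximates each distribution by a diagonal Gaussian. This is a swapped-in technique for illustration, not necessarily how sakura computes the value. gaussian_kl_from_samples is a hypothetical helper.

```python
import torch


def gaussian_kl_from_samples(p_samples: torch.Tensor,
                             q_samples: torch.Tensor,
                             eps: float = 1e-6) -> torch.Tensor:
    """Estimate KL(P || Q) by fitting a diagonal Gaussian to each [N, D] sample set."""
    # Per-dimension mean and (unbiased) variance; eps guards against zero variance.
    mu_p, var_p = p_samples.mean(dim=0), p_samples.var(dim=0) + eps
    mu_q, var_q = q_samples.mean(dim=0), q_samples.var(dim=0) + eps
    # Closed-form KL between univariate Gaussians, applied per dimension:
    # 0.5 * (log(var_q / var_p) + (var_p + (mu_p - mu_q)^2) / var_q - 1)
    kl_per_dim = 0.5 * (torch.log(var_q / var_p)
                        + (var_p + (mu_p - mu_q) ** 2) / var_q
                        - 1.0)
    # Dimensions are treated as independent, so the total KL is their sum.
    return kl_per_dim.sum()


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
encoded = torch.randn(10_000, 2, device=device) + 1.0  # mean shifted by 1 in each dim
drawn = torch.randn(10_000, 2, device=device)          # standard normal reference
kl = gaussian_kl_from_samples(encoded, drawn)
print(kl.item())
```

With unit variances and a mean shift of 1 in each of 2 dimensions, the exact KL is 2 × 0.5 = 1.0, so the estimate should land close to that. A diagonal-Gaussian fit ignores correlations and non-Gaussian shape; nonparametric estimators such as k-nearest-neighbour methods avoid that assumption at higher cost.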