sakura.utils.sliced_wasserstein.SlicedWasserstein.sliced_wasserstein_distance

SlicedWasserstein.sliced_wasserstein_distance(encoded_samples, distribution_fn, num_projections=50, p=2, device='cpu')

Compute the sliced Wasserstein distance (SWD) between encoded samples and samples drawn from distribution_fn.

Parameters:
  • encoded_samples (torch.Tensor) – Samples from the encoded distribution

  • distribution_fn (Callable) – Function that draws samples from the distribution to compare against; called with arguments batch_size and n_dim

  • num_projections (int, optional) – Number of projections used to approximate the sliced Wasserstein distance, defaults to 50

  • p (int, optional) – Order p of the Wasserstein-p distance, defaults to 2

  • device (Literal['cpu', 'cuda'], optional) – PyTorch device on which the computation runs, defaults to ‘cpu’
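A callable suitable for distribution_fn, going only by the description above, accepts batch_size and n_dim and returns a tensor of samples. The two functions below are hypothetical examples (their names are not part of sakura); the sketch assumes the library calls them positionally and expects a tensor of shape (batch_size, n_dim).

```python
import torch


def standard_normal_fn(batch_size: int, n_dim: int) -> torch.Tensor:
    """Hypothetical distribution_fn: i.i.d. N(0, I) samples of shape (batch_size, n_dim)."""
    return torch.randn(batch_size, n_dim)


def unit_sphere_fn(batch_size: int, n_dim: int) -> torch.Tensor:
    """Hypothetical distribution_fn: points drawn uniformly on the unit sphere in n_dim dimensions."""
    x = torch.randn(batch_size, n_dim)
    # Normalizing an isotropic Gaussian vector yields a uniformly distributed direction
    return x / x.norm(dim=1, keepdim=True)


samples = standard_normal_fn(128, 2)
print(samples.shape)  # torch.Size([128, 2])
```

Either function could then be passed as distribution_fn=standard_normal_fn. Whether the returned tensor is moved to device by the library is not stated on this page, so a caller using device='cuda' may need to create the samples on that device.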

Returns:

Sliced Wasserstein distances, of size (num_projections, 1)

Return type:

torch.Tensor
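To illustrate what a per-projection result of size (num_projections, 1) can contain, here is a minimal sketch of the standard sliced Wasserstein estimator. It is not the sakura implementation: the helper name sliced_wasserstein_sketch is hypothetical, it takes the reference samples directly instead of a distribution_fn, and it assumes both sample sets have the same number of rows. For each random unit direction it projects both sets to 1D, sorts them (sorting gives the optimal 1D coupling), and returns the mean |difference|^p, i.e. the Wasserstein-p distance raised to the power p along that direction.

```python
import torch


def sliced_wasserstein_sketch(
    encoded_samples: torch.Tensor,
    target_samples: torch.Tensor,
    num_projections: int = 50,
    p: int = 2,
) -> torch.Tensor:
    """Hypothetical helper: per-projection W_p^p between two equal-size sample sets."""
    n_dim = encoded_samples.shape[1]
    # Random directions normalized onto the unit sphere: (n_dim, num_projections)
    theta = torch.randn(
        n_dim, num_projections,
        device=encoded_samples.device, dtype=encoded_samples.dtype,
    )
    theta = theta / theta.norm(dim=0, keepdim=True)
    # Project both sets onto every direction: (batch_size, num_projections)
    proj_enc = encoded_samples @ theta
    proj_tgt = target_samples @ theta
    # In 1D the optimal transport plan matches sorted samples
    sorted_enc, _ = torch.sort(proj_enc, dim=0)
    sorted_tgt, _ = torch.sort(proj_tgt, dim=0)
    # Mean |difference|^p over samples gives W_p^p for each direction: (num_projections,)
    per_projection = (sorted_enc - sorted_tgt).abs().pow(p).mean(dim=0)
    # Reshape to (num_projections, 1), mirroring the documented return size
    return per_projection.unsqueeze(1)


torch.manual_seed(0)
x = torch.randn(256, 2)
d = sliced_wasserstein_sketch(x, x + 3.0, num_projections=50, p=2)
print(d.shape)  # torch.Size([50, 1])
```

Under this convention, averaging the column over projections (d.mean()) estimates SW_p^p. Whether sakura's returned values are already raised to the power p, or are meant to be reduced this way, is not stated here and should be checked against the source.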