sakura.utils.sliced_wasserstein.SlicedWasserstein.sliced_wasserstein_distance
- SlicedWasserstein.sliced_wasserstein_distance(encoded_samples, distribution_fn, num_projections=50, p=2, device='cpu')
Compute the sliced Wasserstein distance (SWD) between the encoded samples and samples drawn from distribution_fn.
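For reference, this is the standard Monte Carlo estimator of the sliced Wasserstein distance. It assumes n encoded samples x_j in R^d, n samples y_j drawn with distribution_fn(n, d), and L = num_projections directions θ_l drawn uniformly from the unit sphere. It is the textbook formulation, not a statement of how sakura implements it. In one dimension the optimal transport plan matches sorted values, so each projection reduces to a sort. Let a^l_(i) and b^l_(i) be the i-th smallest values of ⟨θ_l, x_j⟩ and ⟨θ_l, y_j⟩:

```latex
\widehat{SW}_p^{\,p}(X, Y) = \frac{1}{L} \sum_{l=1}^{L} \underbrace{\frac{1}{n} \sum_{i=1}^{n} \left| a^{l}_{(i)} - b^{l}_{(i)} \right|^{p}}_{\text{cost of projection } l}
```

Each underbraced term is one per-projection cost, and averaging over l gives the estimate of SW_p^p. The documented return shape (num_projections, 1) suggests that the function returns those per-projection costs rather than their mean. That reading is inferred from the docstring and has not been confirmed.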
- Parameters:
encoded_samples (torch.Tensor) – Samples from encoded distribution
distribution_fn (Callable) – Function that draws samples from the target distribution (args: batch_size, n_dim)
num_projections (int, optional) – Number of random projections used to approximate the sliced Wasserstein distance, defaults to 50
p (int, optional) – Exponent for Wasserstein-p distance, defaults to 2
device (Literal['cpu', 'cuda'], optional) – Device on which the torch computation runs, defaults to 'cpu'
- Returns:
Sliced Wasserstein distances of size (num_projections, 1)
- Return type:
torch.Tensor