sakura.utils.sliced_wasserstein.SlicedWasserstein.rand_projections

SlicedWasserstein.rand_projections(embedding_dim, num_samples=50)

Generates num_samples random projection vectors, each L2-normalized so that it lies on the unit sphere in the latent space.

Parameters:
  • embedding_dim (int) – Dimensionality of the latent space

  • num_samples (int, optional) – Number of random projection vectors to generate, defaults to 50

Returns:

Normalized projection vectors of shape (num_samples, embedding_dim)

Return type:

torch.Tensor
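
Example:

The following is a minimal sketch of how vectors matching this description could be produced. It is not the sakura source. The name rand_projections_sketch is hypothetical, and drawing from a standard Gaussian before normalizing is an assumption (a common way to sample uniformly from the unit sphere), not something this page confirms.

```python
import torch


def rand_projections_sketch(embedding_dim, num_samples=50):
    """Hypothetical stand-in for rand_projections (assumed behavior)."""
    # Assumption: sample from a standard normal, whose direction is
    # uniformly distributed on the sphere once the length is removed.
    directions = torch.randn(num_samples, embedding_dim)
    # Divide each row by its L2 norm so every vector has unit length.
    return directions / directions.norm(dim=1, keepdim=True)


projections = rand_projections_sketch(embedding_dim=16, num_samples=50)
row_norms = projections.norm(dim=1)  # each entry should be close to 1
```

The result has one row per projection vector, so a batch of latent codes of shape (batch, embedding_dim) can be projected onto all directions at once with a matrix product against the transpose of the returned tensor.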