sakura.utils.distributions.randn

sakura.utils.distributions.randn(dim_size)

Creates a function that generates standard normal random samples.

Parameters:

dim_size (int) – Dimensionality of each sample (number of columns in the output)

Returns:

Tensor of shape (batch_size, dim_size), where batch_size is the number of samples drawn, filled with standard normal samples (mean 0, variance 1)

Return type:

torch.FloatTensor
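The description says randn(dim_size) creates a function, while the return shape also involves a batch_size. One reading consistent with both is a sampler factory: dim_size is fixed up front and batch_size is supplied on each call. The sketch below assumes that design. make_randn_sampler is a hypothetical stand-in, not the sakura implementation. It relies only on torch.randn, which draws i.i.d. standard normal values.

```python
import torch


def make_randn_sampler(dim_size: int):
    """Hypothetical stand-in for sakura.utils.distributions.randn(dim_size).

    Assumption: the returned callable takes batch_size and draws from N(0, 1).
    """

    def sample(batch_size: int) -> torch.Tensor:
        # torch.randn fills a (batch_size, dim_size) tensor with i.i.d.
        # standard normal values; the default dtype is float32.
        return torch.randn(batch_size, dim_size)

    return sample


sampler = make_randn_sampler(dim_size=3)
z = sampler(batch_size=4)
print(z.shape)  # torch.Size([4, 3])
print(z.dtype)  # torch.float32
```

Fixing dim_size in the closure lets a training loop request batches of any size without repeating the latent dimensionality. This is a common pattern for noise priors such as those used by VAEs or GAN generators.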