sakura.utils.distributions.swiss_roll

sakura.utils.distributions.swiss_roll(batch_size, n_dim=2, n_labels=10, label_indices=None)

Generates samples from a Swiss roll manifold with optional label conditioning.

The Swiss roll here is a spiral-shaped distribution in the 2D plane. The spiral is divided into n_labels segments, and each sample's label determines which segment it is drawn from.
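The segment-based sampling described above can be sketched in plain Python. This is a hypothetical, framework-free illustration, not sakura's implementation. The function name swiss_roll_sketch, the rng argument, the list-of-tuples return value, and the spiral constants (radius scale 3.0 and 4π radians of rotation) are all assumptions. Each label owns a slice of width 1/n_labels of the spiral's parameter range, so samples that share a label land on the same stretch of the curve.

```python
import math
import random


def swiss_roll_sketch(batch_size, n_labels=10, label_indices=None, rng=None):
    """Hypothetical sketch of label-conditioned Swiss roll sampling in 2D.

    Returns a list of (x, y) tuples. The radius scale (3.0) and the total
    rotation (4*pi radians) are illustrative assumptions, not sakura's values.
    """
    rng = rng or random.Random()
    if label_indices is None:
        # No labels given: assign each sample a random segment.
        label_indices = [rng.randrange(n_labels) for _ in range(batch_size)]
    if len(label_indices) != batch_size:
        raise ValueError("label_indices must have one entry per sample")

    samples = []
    for label in label_indices:
        if not 0 <= label < n_labels:
            raise ValueError(f"label {label} outside [0, {n_labels})")
        # Position along the spiral in [0, 1); label k covers
        # [k / n_labels, (k + 1) / n_labels).
        u = (label + rng.random()) / n_labels
        r = 3.0 * math.sqrt(u)                # radius grows along the spiral
        theta = 4.0 * math.pi * math.sqrt(u)  # angle proportional to radius (Archimedean spiral)
        samples.append((r * math.cos(theta), r * math.sin(theta)))
    return samples
```

For example, swiss_roll_sketch(4, label_indices=[0, 0, 9, 9]) returns two points near the centre of the spiral and two near its outer end. Wrapping the result in torch.tensor(..., dtype=torch.float32) would give a (batch_size, 2) tensor comparable in shape to the documented return value.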

Parameters:
  • batch_size (int) – Number of samples to generate

  • n_dim (Literal[2], optional) – Dimension of output samples; only 2 is supported, defaults to 2

  • n_labels (int, optional) – Number of distinct label segments in the spiral, defaults to 10

  • label_indices (list[int], optional) – Label index for each sample in the batch (0 <= index < n_labels); assigned randomly if None (the default)

Returns:

Tensor of shape (batch_size, n_dim) with Swiss roll samples

Return type:

torch.FloatTensor