sakura.utils.distributions.swiss_roll
- sakura.utils.distributions.swiss_roll(batch_size, n_dim=2, n_labels=10, label_indices=None)
Generates samples from a Swiss roll manifold with optional label conditioning.
Here the Swiss roll is a spiral-shaped distribution in 2D space. The spiral is divided into n_labels segments, and each sample's label determines the segment it is drawn from.
- Parameters:
batch_size (int) – Number of samples to generate
n_dim (Literal[2], optional) – Dimension of output samples, must be 2
n_labels (int, optional) – Number of distinct label segments in the spiral, defaults to 10
label_indices (list[int], optional) – Label indices (0 <= index < n_labels), one per batch sample; if None, labels are assigned randomly
- Returns:
Tensor of shape (batch_size, n_dim) with Swiss roll samples
- Return type:
torch.FloatTensor
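The following is a minimal pure-Python sketch of label-conditioned spiral sampling, included only to illustrate the idea described above. It is not sakura's implementation: the name `swiss_roll_sketch`, the `rng` parameter, the radius scale of 3 and the 4π angular range are all assumptions, and it returns a list of (x, y) tuples rather than a torch.FloatTensor. Each label k is mapped to the slice [k/n_labels, (k+1)/n_labels) of a position parameter along the spiral, so samples that share a label fall on the same segment.

```python
import math
import random


def swiss_roll_sketch(batch_size, n_labels=10, label_indices=None, rng=None):
    """Hypothetical sketch of label-conditioned Swiss roll sampling in 2D.

    Not the sakura implementation. Returns a list of (x, y) tuples.
    The radius scale (3.0) and angular range (4*pi) are assumed values.
    """
    if rng is None:
        rng = random.Random()
    if label_indices is None:
        # No labels given: pick a segment uniformly at random for each sample.
        label_indices = [rng.randrange(n_labels) for _ in range(batch_size)]
    if len(label_indices) != batch_size:
        raise ValueError("label_indices must have one entry per sample")

    samples = []
    for label in label_indices:
        if not 0 <= label < n_labels:
            raise ValueError(f"label {label} is outside [0, {n_labels})")
        # Position along the spiral, restricted to this label's slice
        # [label / n_labels, (label + 1) / n_labels).
        u = (label + rng.random()) / n_labels
        # Radius and angle both grow with sqrt(u), giving an Archimedean
        # spiral with roughly even sample density along the arc (assumed choice).
        radius = 3.0 * math.sqrt(u)
        angle = 4.0 * math.pi * math.sqrt(u)
        samples.append((radius * math.cos(angle), radius * math.sin(angle)))
    return samples
```

To get a tensor matching the documented return type, the list can be converted with `torch.tensor(samples, dtype=torch.float32)`, which has shape (batch_size, 2).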