sakura.utils.gradient_reverse.ReverseLayerF.forward

static ReverseLayerF.forward(ctx, input_, alpha_=1.0)

Forward pass of the gradient-reversal autograd function. Returns input_ unchanged and keeps alpha_ on ctx so that the backward pass can scale the gradient by it.

Parameters:
  • ctx (torch.autograd.function.FunctionCtx) – Context object used to store the information needed by the backward computation

  • input_ (torch.Tensor) – Input tensor of shape (N, *), where * means any number of additional dimensions

  • alpha_ (float, optional) – Gradient scaling factor, defaults to 1.0 (no scaling)

Returns:

Output tensor identical to input_ (shape preserved)

Return type:

torch.Tensor
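The sketch below shows how a forward pass like this is commonly paired with a backward pass, assuming the usual gradient-reversal pattern: identity in the forward pass, and the gradient multiplied by -alpha_ in the backward pass. GradReverseSketch is a hypothetical name. This is not the sakura source, and the sign flip in backward is an assumption based on the module name gradient_reverse.

```python
import torch
from torch.autograd import Function


class GradReverseSketch(Function):
    """Hypothetical stand-in for ReverseLayerF; not the sakura implementation."""

    @staticmethod
    def forward(ctx, input_, alpha_=1.0):
        # alpha_ is a plain float, so it is stored as a ctx attribute
        # (save_for_backward is only meant for tensors).
        ctx.alpha_ = alpha_
        # Identity: view_as returns a tensor with the same values and shape.
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # ASSUMPTION: the gradient is reversed (negated) and scaled by alpha_.
        # One gradient per forward input: input_ gets it, alpha_ gets None.
        return grad_output.neg() * ctx.alpha_, None


# Usage: a (N, *) input with N=2 and two extra dimensions.
x = torch.randn(2, 3, 5, requires_grad=True)
y = GradReverseSketch.apply(x, 0.5)  # arguments passed positionally
y.sum().backward()
print(y.shape)    # same shape as x
print(x.grad[0])  # every entry is -0.5 under the reversal assumption
```

Custom autograd functions are invoked through the class method apply rather than by calling forward directly, which is why the usage passes alpha_ as a positional argument to apply.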