sakura.utils.gradient_reverse.ReverseLayerF.backward

static ReverseLayerF.backward(ctx, grad_output)

Backward pass of the gradient reversal operation: it multiplies the upstream gradient by -alpha_ before passing it back to the input.
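Assuming the forward pass is the identity map, as in a standard gradient reversal layer (this page documents only the backward pass), the operation can be summarized as:

```latex
\text{forward: } y = x, \qquad
\text{backward: } \frac{\partial \mathcal{L}}{\partial x} = -\alpha \, \frac{\partial \mathcal{L}}{\partial y}
```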

Parameters:
  • ctx (torch.autograd.function.FunctionCtx) – Context object carrying state saved during the forward pass, such as the scaling factor alpha_

  • grad_output (torch.Tensor) – Upstream gradient with respect to the forward output, of shape (N, *), matching the shape of the forward input

Returns:

  • grad_input: Gradient with respect to input_, equal to grad_output multiplied by -alpha_ (same shape as grad_output)

  • None: Placeholder for the gradient with respect to alpha_, which is not computed

Return type:

tuple (torch.Tensor, None)
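The following is a minimal sketch of how such a function could be written with torch.autograd.Function. It is not taken from the sakura source: the forward signature forward(ctx, input_, alpha_) and storing alpha_ as a plain attribute on ctx are assumptions inferred from the parameter names above.

```python
import torch
from torch.autograd import Function


class ReverseLayerF(Function):
    """Hypothetical sketch of a gradient reversal layer (not the sakura implementation)."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        # Assumption: alpha_ is a Python float stored directly on ctx.
        ctx.alpha_ = alpha_
        # Identity forward; view_as returns a new tensor so autograd records the op.
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the upstream gradient; the shape is preserved.
        grad_input = grad_output.neg() * ctx.alpha_
        # One return value per forward input: alpha_ receives no gradient.
        return grad_input, None


x = torch.ones(2, 3, requires_grad=True)
y = ReverseLayerF.apply(x, 0.5)
y.sum().backward()
print(x.grad)  # every entry is -0.5
```

Because the forward pass is the identity, y equals x numerically; only the gradient flowing back into x is reversed and scaled, which is what lets a feature extractor be trained adversarially against a downstream classifier.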