sakura.utils.gradient_reverse.ReverseLayerF.backward
- static ReverseLayerF.backward(ctx, grad_output)
Backward pass of the gradient reversal operation: negates the upstream gradient and scales it by alpha.
- Parameters:
ctx (torch.autograd.function.FunctionCtx) – Context object carrying the state saved during the forward pass, including the scaling factor alpha_
grad_output (torch.Tensor) – Upstream gradient of shape (N, *), matching the forward input dimensions
- Returns:
grad_input: Gradient with respect to input_, computed as grad_output scaled by -alpha_ (same shape as grad_output)
None: Placeholder for the gradient with respect to alpha_, which is not computed
- Return type:
tuple (torch.Tensor, None)
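The behavior described above can be illustrated with a minimal sketch. This is not the sakura source: the class name GradReverseSketch is hypothetical, and the sketch assumes that the forward pass is an identity mapping and that alpha_ is a plain Python float stored as an attribute on ctx.

```python
import torch
from torch.autograd import Function


class GradReverseSketch(Function):
    """Hypothetical stand-in for ReverseLayerF; not the sakura source."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        # Assumption: alpha_ is a plain float kept on ctx for the backward pass.
        ctx.alpha_ = alpha_
        # Assumption: the forward direction is an identity mapping.
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the upstream gradient and scale it by alpha_; shape is unchanged.
        grad_input = grad_output.neg() * ctx.alpha_
        # One return value per forward argument: None stands in for alpha_.
        return grad_input, None


x = torch.ones(2, 3, requires_grad=True)
y = GradReverseSketch.apply(x, 0.5)
y.sum().backward()  # upstream gradient is all ones
print(x.grad)       # every entry equals -0.5 under the assumptions above
```

Returning None in the second position follows the torch.autograd.Function contract that backward returns one value per argument of forward; None marks an input for which no gradient is produced.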