sakura.utils.gradient_reverse.NeutralizeLayerF.backward
- static NeutralizeLayerF.backward(ctx, grad_output)
Backward pass that discards the upstream gradient and returns a zero-valued gradient in its place.
- Parameters:
ctx (torch.autograd.function.FunctionCtx) – Context object containing tensors saved during the forward pass
grad_output (torch.Tensor) – Upstream gradient of shape (N, *), matching the forward input dimensions
- Returns:
grad_input: Zero-valued gradient tensor with the same shape as grad_output
None: No gradient for the second forward argument
- Return type:
tuple (torch.Tensor, None)
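The behaviour described above can be illustrated with a minimal sketch of a custom `torch.autograd.Function`. This is not the `sakura` source: the class name `NeutralizeSketch`, the identity forward pass, and the second forward argument `coeff` are assumptions made for illustration, chosen so that `backward` returns the documented `(torch.Tensor, None)` pair.

```python
import torch
from torch.autograd import Function


class NeutralizeSketch(Function):
    """Hypothetical stand-in for NeutralizeLayerF: identity forward, zero backward."""

    @staticmethod
    def forward(ctx, x, coeff):
        # Assumption: forward is the identity on x. `coeff` is a hypothetical
        # non-tensor second argument that never receives a gradient.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Discard the upstream gradient: return zeros shaped like grad_output
        # for x, and None for the second forward argument.
        return torch.zeros_like(grad_output), None


x = torch.randn(4, 3, requires_grad=True)
y = NeutralizeSketch.apply(x, 1.0)
y.sum().backward()
print(torch.count_nonzero(x.grad).item())  # 0
```

Unlike calling `detach()`, a layer built this way stays in the autograd graph, so tensors upstream of it still receive a `.grad` tensor; that gradient is simply filled with zeros.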