sakura.utils.gradient_reverse.NeutralizeLayerF.backward

static NeutralizeLayerF.backward(ctx, grad_output)

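Backward pass that nullifies upstream gradients: it returns a zero-valued gradient for the forward input, so no gradient signal propagates through this function to earlier layers.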

Parameters:
  • ctx (torch.autograd.function.FunctionCtx) – Context object holding any tensors saved during the forward pass

  • grad_output (torch.Tensor) – Upstream gradient of shape (N, *), matching the forward input dimensions

Returns:

  • grad_input: Zero-valued gradient tensor

  • None: Placeholder for the forward argument that receives no gradient (autograd expects one return value per forward input)

Return type:

tuple (torch.Tensor, None)
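
Example:

The following is a minimal sketch of the behaviour described above, not the library's actual implementation. The class name NeutralizeSketch, the identity forward pass, and the second forward argument alpha are assumptions made for illustration; only the (zero gradient, None) return shape of backward comes from this page.

```python
import torch
from torch.autograd import Function


class NeutralizeSketch(Function):
    """Hypothetical stand-in for a gradient-neutralizing layer."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Assumption: the forward pass is an identity on x, and alpha is
        # an unused non-tensor argument (hence the None in backward).
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Zero-valued gradient for x, None for the non-differentiable alpha.
        return torch.zeros_like(grad_output), None


x = torch.ones(4, 3, requires_grad=True)
w = torch.ones(4, 3, requires_grad=True)

# w sits downstream of the layer, x sits upstream of it.
y = NeutralizeSketch.apply(x, 1.0) * w
y.sum().backward()

x_grad_is_zero = bool(torch.all(x.grad == 0))  # gradient blocked at the layer
w_grad_is_ones = bool(torch.all(w.grad == 1))  # downstream gradient unaffected
```

In this sketch, parameters downstream of the layer (w) still receive their usual gradients, while everything upstream (x) sees only zeros.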