sakura.model_controllers.extractor_controller.ExtractorController.regularize

ExtractorController.regularize(tensor, regularization_config: dict, supervision=None)

Apply regularization to the given tensor according to the specified configuration.

Parameters:
  • tensor (torch.Tensor) – Tensor to regularize, typically of shape (N_batch, …)

  • regularization_config (dict) – Regularization configuration, specifying the regularization type and the parameters specific to the chosen method

  • supervision (list[int], optional) – Label indices (0 <= index < n_labels) of the batch samples, used for supervised regularization
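The `supervision` argument carries a documented constraint (0 <= index < n_labels). The sketch below shows one way such labels could be validated and grouped before a supervised regularizer is applied. It assumes exactly one label per batch sample; the helpers `validate_supervision` and `group_by_label` are hypothetical and are not part of sakura's API.

```python
from typing import Dict, Optional, Sequence

import torch


def validate_supervision(
    supervision: Optional[Sequence[int]], batch_size: int, n_labels: int
) -> Optional[torch.Tensor]:
    """Return `supervision` as a 1-D LongTensor, or None if no labels were given.

    Hypothetical helper: assumes one label per batch sample.
    """
    if supervision is None:
        return None
    labels = torch.as_tensor(supervision, dtype=torch.long)
    # Assumption: exactly one label per sample in the batch.
    if labels.dim() != 1 or labels.numel() != batch_size:
        raise ValueError(f"expected {batch_size} labels, got shape {tuple(labels.shape)}")
    # Documented constraint: 0 <= index < n_labels.
    if labels.numel() > 0 and (int(labels.min()) < 0 or int(labels.max()) >= n_labels):
        raise ValueError(f"label indices must lie in [0, {n_labels})")
    return labels


def group_by_label(labels: torch.Tensor) -> Dict[int, torch.Tensor]:
    """Map each label present in the batch to the row indices that carry it."""
    return {
        int(k): torch.nonzero(labels == k, as_tuple=True)[0]
        for k in torch.unique(labels)
    }
```

With the groups in hand, `tensor[idx]` selects the rows belonging to one label, so a supervised regularizer could compute a separate distance for each label group.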

Returns:

Distance between the encoded samples and samples drawn at random from the target distribution.
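The return value is described as a distance between the encoded samples and random samples from a distribution. As an illustration only, the sketch below computes one such distance, the sliced Wasserstein-2 distance, against a standard normal N(0, I). Both the choice of prior and the choice of distance are assumptions; this page does not say which distances `regularize` supports, and `sliced_wasserstein_to_normal` is a hypothetical name.

```python
from typing import Optional

import torch


def sliced_wasserstein_to_normal(
    encoded: torch.Tensor,
    n_projections: int = 50,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sliced Wasserstein-2 distance between rows of `encoded` and N(0, I) samples.

    Illustrative sketch; the prior and the distance are assumptions.
    """
    n, d = encoded.shape
    # Draw as many standard-normal "prior" samples as there are encoded samples.
    prior = torch.randn(n, d, generator=generator, dtype=encoded.dtype).to(encoded.device)
    # Random directions on the unit sphere, used to project both point clouds to 1-D.
    directions = torch.randn(d, n_projections, generator=generator, dtype=encoded.dtype)
    directions = directions.to(encoded.device)
    directions = directions / directions.norm(dim=0, keepdim=True)
    # Sorting each 1-D projection pairs the i-th smallest values, which is the
    # optimal transport plan between two equal-size empirical distributions in 1-D.
    proj_encoded, _ = torch.sort(encoded @ directions, dim=0)
    proj_prior, _ = torch.sort(prior @ directions, dim=0)
    # The mean squared gap over samples and projections approximates SW_2^2.
    return ((proj_encoded - proj_prior) ** 2).mean().sqrt()
```

Because sorting and matrix products are differentiable, the result can be added to a training loss. Passing a seeded `torch.Generator` makes the prior draw reproducible; omitting it draws fresh prior samples on every call.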