sakura.model_controllers.extractor_controller.ExtractorController.regularize
- ExtractorController.regularize(tensor, regularization_config: dict, supervision=None)
Apply regularization to the given tensor according to the specified configuration.
- Parameters:
tensor (torch.Tensor) – Tensor to regularize (usually of shape (N_batch, …))
regularization_config (dict) – A dict containing regularization configuration, including type and parameters specific to the chosen method
supervision (list[int], optional) – List of label indices (0 <= index < n_labels) used for supervised regularization of the batch samples
- Returns:
Distance between the encoded samples and samples randomly drawn from the distribution function.
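
To illustrate the kind of value described under Returns, the sketch below computes a distance between a batch of encoded samples and samples randomly drawn from a distribution. It is not sakura's implementation. The function name `regularize_sketch`, the config keys `"type"` and `"sigma"`, the standard normal target distribution, and the use of maximum mean discrepancy (MMD) as the distance are all assumptions made for illustration. Plain Python lists stand in for `torch.Tensor` so the example has no dependencies, and the `supervision` argument is not modeled.

```python
import math
import random


def rbf_kernel(x, y, sigma):
    """Gaussian (RBF) kernel between two equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    """Average kernel value over all pairs (x, y) with x in xs and y in ys."""
    total = sum(rbf_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def regularize_sketch(encoded, regularization_config, rng=None):
    """Hypothetical stand-in for a regularize-style method.

    encoded: list of N_batch vectors (lists of floats).
    regularization_config: dict with assumed keys "type" and "sigma".
    Returns a biased estimate of the squared MMD between the encoded
    samples and samples drawn from a standard normal distribution.
    """
    if rng is None:
        rng = random.Random(0)

    # Assumed config layout: only an "mmd" type is handled in this sketch.
    if regularization_config.get("type") != "mmd":
        raise ValueError("unsupported regularization type in this sketch")
    sigma = regularization_config.get("sigma", 1.0)

    # Draw one target sample per encoded sample (assumed target: N(0, I)).
    dim = len(encoded[0])
    drawn = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in encoded]

    # Biased (V-statistic) estimate of MMD^2 with an RBF kernel.
    return (
        mean_kernel(encoded, encoded, sigma)
        + mean_kernel(drawn, drawn, sigma)
        - 2.0 * mean_kernel(encoded, drawn, sigma)
    )


if __name__ == "__main__":
    rng = random.Random(1)
    batch = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(32)]
    config = {"type": "mmd", "sigma": 1.0}  # hypothetical configuration
    print(regularize_sketch(batch, config))
```

Under these assumptions, a batch whose encodings already resemble the target distribution yields a small distance, while a batch shifted far from it yields a larger one, which is what makes such a value usable as a regularization penalty.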