sakura.model_controllers.extractor_controller.ExtractorController.loss

ExtractorController.loss(batch, expr_key='all', forward_pheno=True, selected_pheno=None, forward_signature=True, selected_signature=None, forward_reconstruction=True, forward_main_latent=True, dump_forward_results=False, detach=False, detach_from='', save_raw_loss=False)

Calculate the composite loss across all active components.

Parameters:
  • batch (dict) –

    A dictionary containing a batch of data for loss calculation; it should include:

    • expr: Expression matrices

    • pheno: Phenotype labels

  • expr_key (str, optional) – The key of expression group to use as input, defaults to ‘all’

  • forward_pheno (bool, optional) – Whether to calculate phenotype-related losses, defaults to True

  • selected_pheno (dict, optional) –

    Phenotype selection for loss calculation. Either None (selects all phenotypes together with their losses and regularizations) or a dictionary of the form:

    {‘pheno_name’: {‘loss’: [list of loss names] or ‘*’ (selecting all), ‘regularization’: [list of regularization keys, could be Null] or ‘*’ or None (no regularization)}}

  • forward_signature (bool, optional) – Whether to calculate signature-related losses, defaults to True

  • selected_signature (dict, optional) – Signature selection for loss calculation. Either None (selects all signatures together with their losses and regularizations) or a dictionary in the same format as selected_pheno

  • forward_reconstruction (bool, optional) – Whether to calculate the main decoder reconstruction losses, defaults to True (see Note)

  • forward_main_latent (bool, optional) – Whether to calculate main latent regularization losses, defaults to True

  • dump_forward_results (bool, optional) – Whether to include the forwarded tensors in the returned dict, defaults to False

  • detach (bool, optional) – Whether to detach the loss at an intermediate point of the network, specified by <detach_from>, defaults to False

  • detach_from (Literal['pre_encoder', 'encoder'] or str, optional) – Specific component from which the loss should be detached if <detach> is True

  • save_raw_loss (bool, optional) – Whether to record the unweighted (raw) losses in addition to the weighted losses, defaults to False
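The nested format accepted by selected_pheno (and, in the same shape, by selected_signature) can be sanity-checked before calling loss. The helper below is a minimal sketch and is not part of sakura: validate_selection is a hypothetical name, the phenotype names 'cell_type' and 'age' and the loss name 'mse' are placeholders, and a missing 'regularization' key is assumed to behave like None.

```python
from typing import Any, Dict, Optional


def validate_selection(selection: Optional[Dict[str, Dict[str, Any]]]) -> None:
    """Raise if `selection` does not follow the documented selection format.

    Hypothetical helper, not part of sakura. None means "select everything".
    """
    if selection is None:
        return
    if not isinstance(selection, dict):
        raise TypeError("selection must be None or a dict")
    for name, spec in selection.items():
        if not isinstance(spec, dict):
            raise TypeError(f"{name!r}: expected a dict with 'loss'/'regularization' keys")
        # 'loss' must be a list of loss names or the wildcard '*'.
        loss = spec.get("loss")
        if not (loss == "*" or isinstance(loss, list)):
            raise ValueError(f"{name!r}: 'loss' must be a list of loss names or '*'")
        # 'regularization' may be a list of keys, '*' or None.
        # Assumption: a missing key is treated the same as None (no regularization).
        reg = spec.get("regularization")
        if not (reg is None or reg == "*" or isinstance(reg, list)):
            raise ValueError(f"{name!r}: 'regularization' must be a list, '*' or None")


# Placeholder phenotype and loss names, for illustration only.
selected_pheno = {
    "cell_type": {"loss": "*", "regularization": None},
    "age": {"loss": ["mse"], "regularization": "*"},
}
validate_selection(selected_pheno)  # passes silently
```

A selection that passes this check would then be given to the method as, for example, controller.loss(batch, selected_pheno=selected_pheno), where controller is an ExtractorController instance.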

Note

<forward_reconstruction>: When enabled, the losses of all latents are always calculated.

Returns:

A dictionary containing the computed losses:

  • ‘main_latent_loss’: Main latent loss details

  • ‘pheno_loss’: Phenotype loss details

  • ‘signature_loss’: Signature loss details

  • ‘fwd_res’ (optional): Forwarded result tensors, if <dump_forward_results> is True.

Return type:

dict
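Because 'fwd_res' appears only when dump_forward_results is True, calling code should not assume it is present. The sketch below separates the loss entries from the optional forward results. split_loss_output is a hypothetical helper, the empty dicts are placeholders for the real loss details (whose internal structure is not documented here), and skipping absent loss keys is a defensive assumption rather than documented behavior.

```python
from typing import Any, Dict, Optional, Tuple

# Top-level keys documented for the dict returned by ExtractorController.loss.
LOSS_KEYS = ("main_latent_loss", "pheno_loss", "signature_loss")


def split_loss_output(out: Dict[str, Any]) -> Tuple[Dict[str, Any], Optional[Any]]:
    """Return (loss entries, forwarded results or None). Hypothetical helper."""
    # Defensive: keep only the loss keys that are actually present.
    losses = {key: out[key] for key in LOSS_KEYS if key in out}
    # 'fwd_res' exists only when dump_forward_results=True.
    fwd_res = out.get("fwd_res")
    return losses, fwd_res


# Placeholder return value; real entries hold the loss details computed by sakura.
out = {"main_latent_loss": {}, "pheno_loss": {}, "signature_loss": {}}
losses, fwd_res = split_loss_output(out)
print(sorted(losses))  # ['main_latent_loss', 'pheno_loss', 'signature_loss']
print(fwd_res)  # None
```

With a real controller, out would come from controller.loss(batch, dump_forward_results=True), in which case fwd_res holds the forwarded tensors instead of None.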