sakura.model_controllers.extractor_controller.ExtractorController.loss
- ExtractorController.loss(batch, expr_key='all', forward_pheno=True, selected_pheno=None, forward_signature=True, selected_signature=None, forward_reconstruction=True, forward_main_latent=True, dump_forward_results=False, detach=False, detach_from='', save_raw_loss=False)
Calculate composite loss across all active components.
- Parameters:
batch (dict) –
A dictionary containing the batched data used to calculate the loss. It should include:
expr: Expression matrices
pheno: Phenotype labels
expr_key (str, optional) – The key of the expression group to use as input, defaults to ‘all’
forward_pheno (bool, optional) – Whether to calculate phenotype-related losses, defaults to True
selected_pheno (dict, optional) –
Phenotype selection for loss calculation. Should be None (selects all phenotypes, together with their related losses and regularizations) or a dictionary of the form:
{‘pheno_name’: {‘loss’: [list of loss names] or ‘*’ (selecting all), ‘regularization’: [list of regularization keys, could be Null] or ‘*’ or None (no regularization)}}
forward_signature (bool, optional) – Whether to calculate signature-related losses, defaults to True
selected_signature (dict, optional) – Signature selection for loss calculation. Should be None (selects all signatures, together with their related losses and regularizations) or a dictionary in the same format as selected_pheno
forward_reconstruction (bool, optional) – Whether to calculate the main decoder reconstruction losses, defaults to True
forward_main_latent (bool, optional) – Whether to calculate main latent regularization losses, defaults to True
dump_forward_results (bool, optional) – Whether to preserve forwarded tensors in the return dict, defaults to False
detach (bool, optional) – Whether the loss should be detached partway through the network, at the component specified by <detach_from>, defaults to False
detach_from (Literal['pre_encoder', 'encoder'] or str, optional) – The component from which the loss is detached when <detach> is True, defaults to ''
save_raw_loss (bool, optional) – Whether to record the unweighted, raw losses in addition to the weighted losses, defaults to False
Note
<forward_reconstruction>: When enabled, the losses of all latents are forced to be calculated.
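The selection format described for selected_pheno and selected_signature can be sketched as follows. This is a minimal illustration rather than part of the library: the helper make_selection, the phenotype names ‘cell_type’ and ‘donor’, and the loss name ‘classification’ are all hypothetical, and the final call is left as a comment because it needs a constructed controller and a real batch.

```python
from typing import Dict, List, Optional, Union

# One selection entry: {'name': {'loss': ..., 'regularization': ...}}
Selection = Dict[str, Dict[str, Union[List[str], str, None]]]


def make_selection(
    name: str,
    losses: Union[List[str], str] = "*",
    regularizations: Optional[Union[List[str], str]] = None,
) -> Selection:
    """Build one entry in the documented selection format (hypothetical helper)."""
    # A bare string is only meaningful as the '*' wildcard (select all).
    if isinstance(losses, str) and losses != "*":
        raise ValueError("losses must be a list of names or '*'")
    if isinstance(regularizations, str) and regularizations != "*":
        raise ValueError("regularizations must be a list of keys, '*' or None")
    return {name: {"loss": losses, "regularization": regularizations}}


# Hypothetical phenotype and loss names, for illustration only.
selected_pheno: Selection = {}
selected_pheno.update(make_selection("cell_type", losses=["classification"]))
selected_pheno.update(make_selection("donor", losses="*", regularizations="*"))

# Keyword arguments for a phenotype-only pass. Reconstruction is switched off
# because, per the note above, it forces the losses of all latents.
loss_kwargs = dict(
    expr_key="all",
    forward_pheno=True,
    selected_pheno=selected_pheno,
    forward_signature=False,
    forward_reconstruction=False,
    forward_main_latent=False,
)

# With an existing controller and batch (not constructed in this sketch):
# losses = controller.loss(batch, **loss_kwargs)
```

Passing selected_pheno=None instead selects every phenotype with all of its losses and regularizations, so an explicit dictionary is only needed to narrow the calculation.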
- Returns:
A dictionary containing the computed losses:
‘main_latent_loss’: Main latent loss details
‘pheno_loss’: Phenotype loss details
‘signature_loss’: Signature loss details
‘fwd_res’ (optional): Forwarded result tensors, if <dump_forward_results> is True.
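A caller may want to separate the loss entries from the optional forwarded tensors. The sketch below assumes only the top-level keys listed above. The helper split_loss_result and the mock dictionary are hypothetical, and the inner layout of each "loss details" value is not specified here, so it is passed through untouched.

```python
from typing import Any, Dict, Optional, Tuple

# Top-level loss keys as listed in the Returns section.
LOSS_KEYS = ("main_latent_loss", "pheno_loss", "signature_loss")


def split_loss_result(result: Dict[str, Any]) -> Tuple[Dict[str, Any], Optional[Any]]:
    """Split a returned dict into loss entries and forwarded tensors (hypothetical helper)."""
    # Keep only the loss keys that are actually present; whether a key is
    # omitted when its forward_* flag is False is not stated, so no
    # assumption is made either way.
    losses = {key: result[key] for key in LOSS_KEYS if key in result}
    # 'fwd_res' is present only when dump_forward_results is True.
    fwd_res = result.get("fwd_res")
    return losses, fwd_res


# Mock return value standing in for controller.loss(...); the empty dicts
# are placeholders, not the real structure of the loss details.
mock_result = {"main_latent_loss": {}, "pheno_loss": {}, "signature_loss": {}}
losses, fwd_res = split_loss_result(mock_result)
```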
- Return type: