sakura.model_controllers.extractor_controller.ExtractorController.eval
- ExtractorController.eval(batch, forward_pheno=False, selected_pheno=None, forward_signature=False, selected_signature=None, forward_reconstruction=False, forward_main_latent=False, dump_latent=False, save_raw_loss=False)
Evaluate the model using the specified batch of data.
This function performs an evaluation step by computing the losses associated with reconstruction, main latent regularization, phenotype, and signature components.
- Parameters:
batch (dict) – A dictionary containing a batch of data for evaluation
forward_pheno (bool, optional) – Whether to calculate phenotype-related losses, defaults to False
selected_pheno (dict, optional) – Phenotype selection for loss calculation. Should be None (selecting all phenotypes, with their related losses and regularizations) or a dictionary of the form: {'pheno_name': {'loss': [list of loss names] or '*' (selecting all), 'regularization': [list of regularization keys] or '*' (selecting all) or None (no regularization)}}
forward_signature (bool, optional) – Whether to calculate signature-related losses, defaults to False
selected_signature (dict, optional) – Signature selection for loss calculation. Should be None (selecting all signatures, with their related losses and regularizations) or a dictionary in the same format as selected_pheno
forward_reconstruction (bool, optional) – Whether to calculate the main decoder reconstruction losses, defaults to False
forward_main_latent (bool, optional) – Whether to calculate main latent regularization losses, defaults to False
dump_latent (bool, optional) – Whether to preserve forwarded latents in the return dict, defaults to False
save_raw_loss (bool, optional) – Whether to record the unweighted, raw losses in addition to the weighted losses, defaults to False
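The selection dictionaries described above can be easier to follow with a concrete instance. The sketch below builds a selected_pheno dictionary and a set of keyword arguments matching the documented signature. The phenotype names ("cell_type", "batch"), the loss name ("cross_entropy"), and the check_selection helper are hypothetical and used only for illustration; they are not part of the library.

```python
# Hypothetical phenotype and loss names, for illustration only.
selected_pheno = {
    "cell_type": {
        "loss": ["cross_entropy"],   # explicit list of loss names
        "regularization": None,      # no regularization for this phenotype
    },
    "batch": {
        "loss": "*",                 # select all losses of this phenotype
        "regularization": "*",       # select all regularizations
    },
}


def check_selection(selection):
    """Hypothetical helper: check a selection dict against the documented shape."""
    if selection is None:            # None selects everything
        return True
    for name, spec in selection.items():
        loss = spec["loss"]
        reg = spec.get("regularization")
        if not (loss == "*" or isinstance(loss, list)):
            raise ValueError(f"{name}: 'loss' must be a list or '*'")
        if not (reg is None or reg == "*" or isinstance(reg, list)):
            raise ValueError(f"{name}: 'regularization' must be a list, '*' or None")
    return True


check_selection(selected_pheno)

# Keyword arguments mirroring the documented signature of eval().
eval_kwargs = dict(
    forward_pheno=True,
    selected_pheno=selected_pheno,
    forward_signature=False,
    selected_signature=None,
    forward_reconstruction=True,
    forward_main_latent=True,
    dump_latent=False,
    save_raw_loss=True,
)

# Assuming an already constructed ExtractorController named `controller`
# and a batch dict from your data loader, the call would look like:
# losses = controller.eval(batch, **eval_kwargs)
```

Because every forward_* flag defaults to False in the signature, each loss component that should be evaluated has to be switched on explicitly, as in eval_kwargs above.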
- Returns:
A dictionary containing the losses computed during this evaluation, including total loss and individual component losses
- Return type: