sakura.model_controllers.extractor_controller.ExtractorController.eval

ExtractorController.eval(batch, forward_pheno=False, selected_pheno=None, forward_signature=False, selected_signature=None, forward_reconstruction=False, forward_main_latent=False, dump_latent=False, save_raw_loss=False)

Evaluate the model using the specified batch of data.

This method performs one evaluation step. It computes the losses for the decoder reconstruction, main latent regularization, phenotype, and signature components. Each group is computed only when its corresponding forward_* flag is enabled.
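The signature above gives every forward_* flag a default of False, so a bare eval(batch) call presumably computes none of the four loss groups. The sketch below uses a hypothetical wrapper, evaluate_all_components, that enables all four groups explicitly. It assumes controller is an already-constructed ExtractorController and batch is a dict in the format the controller expects. It only passes the documented keyword arguments and invents no additional sakura API.

```python
from typing import Any, Dict


def evaluate_all_components(controller: Any, batch: Dict[str, Any], **extra: Any) -> Dict[str, Any]:
    """Hypothetical helper: call ExtractorController.eval with every loss group enabled.

    Keyword arguments in ``extra`` (e.g. dump_latent, save_raw_loss, selected_pheno)
    are passed through and may override the flags set here.
    """
    kwargs = dict(
        forward_reconstruction=True,  # decoder reconstruction losses
        forward_main_latent=True,     # main latent regularization losses
        forward_pheno=True,           # phenotype losses (all phenotypes when selected_pheno is None)
        forward_signature=True,       # signature losses (all signatures when selected_signature is None)
    )
    kwargs.update(extra)
    return controller.eval(batch, **kwargs)
```

Usage would look like evaluate_all_components(controller, batch, save_raw_loss=True).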

Parameters:
  • batch (dict) – A dictionary containing a batch of data for evaluation

  • forward_pheno (bool, optional) – Whether to calculate phenotype-related losses, defaults to False

  • selected_pheno (dict, optional) – Phenotype selection for loss calculation. Either None (select all phenotypes together with all of their losses and regularizations), or a dictionary of the form: {'pheno_name': {'loss': [list of loss names] or '*' (select all losses), 'regularization': [list of regularization keys] or '*' (select all regularizations) or None (no regularization)}}, defaults to None

  • forward_signature (bool, optional) – Whether to calculate signature-related losses, defaults to False

  • selected_signature (dict, optional) – Signature selection for loss calculation. Either None (select all signatures together with all of their losses and regularizations), or a dictionary with the same structure as selected_pheno, keyed by signature name, defaults to None

  • forward_reconstruction (bool, optional) – Whether to calculate the main decoder reconstruction losses, defaults to False

  • forward_main_latent (bool, optional) – Whether to calculate the main latent regularization losses, defaults to False

  • dump_latent (bool, optional) – Whether to include the forwarded latent representations in the returned dict, defaults to False

  • save_raw_loss (bool, optional) – Whether to also record the raw (unweighted) losses alongside the weighted losses, defaults to False
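The selected_pheno format described above is easier to read as a concrete value. In the sketch below, the phenotype names ('cell_type', 'age'), the loss name ('mse') and the regularization key ('l2') are hypothetical placeholders rather than names defined by sakura, so substitute the ones from your own configuration. The helper validate_selection is also hypothetical. It checks only the documented shape, not whether the names exist in a given model.

```python
from typing import Any, Dict, Optional

# Hypothetical selection: the names below are placeholders, not sakura defaults.
selected_pheno = {
    "cell_type": {"loss": "*", "regularization": None},  # all losses, no regularization
    "age": {"loss": ["mse"], "regularization": ["l2"]},  # one named loss, one regularization
}


def _is_str_list(value: Any) -> bool:
    return isinstance(value, list) and all(isinstance(x, str) for x in value)


def validate_selection(selection: Optional[Dict[str, Any]]) -> None:
    """Hypothetical shape check for a selected_pheno / selected_signature value."""
    if selection is None:
        return  # None selects every entry with all of its losses and regularizations
    if not isinstance(selection, dict):
        raise TypeError("selection must be None or a dict")
    for name, spec in selection.items():
        if not isinstance(spec, dict):
            raise TypeError(f"{name!r}: spec must be a dict")
        loss = spec.get("loss")
        if loss != "*" and not _is_str_list(loss):
            raise ValueError(f"{name!r}: 'loss' must be '*' or a list of loss names")
        reg = spec.get("regularization")  # a missing key is treated like None here
        if reg is not None and reg != "*" and not _is_str_list(reg):
            raise ValueError(f"{name!r}: 'regularization' must be '*', None, or a list of keys")


validate_selection(selected_pheno)  # passes silently
```

The same shape applies to selected_signature, keyed by signature name. A validated value would then be passed as, for example, controller.eval(batch, forward_pheno=True, selected_pheno=selected_pheno).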

Returns:

A dictionary containing the losses computed during this evaluation, including the total loss and the individual component losses.

Return type:

dict
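Because eval returns one loss dict per batch, evaluating a whole dataset usually means averaging these dicts over batches. The sketch below uses a hypothetical helper, average_losses. It assumes each loss value is either a Python number or a scalar tensor-like object that exposes .item(), as PyTorch tensors do. The key names in the usage line ('total', 'reconstruction') are hypothetical, since the exact keys are not documented here. Non-scalar entries, such as latents returned when dump_latent=True or nested dicts, are skipped rather than averaged.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, Optional


def _to_float(value: Any) -> Optional[float]:
    """Return value as a float if it is a scalar number or scalar tensor-like, else None."""
    if isinstance(value, bool):
        return None  # booleans are flags, not losses
    if isinstance(value, (int, float)):
        return float(value)
    item = getattr(value, "item", None)
    if callable(item):
        try:
            return float(item())
        except (ValueError, RuntimeError, TypeError):
            return None  # multi-element tensor/array: not a scalar loss
    return None


def average_losses(loss_dicts: Iterable[Dict[str, Any]]) -> Dict[str, float]:
    """Hypothetical helper: average the scalar entries of per-batch eval() results, key by key."""
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for losses in loss_dicts:
        for key, value in losses.items():
            x = _to_float(value)
            if x is None:
                continue  # skip latents, nested dicts and other non-scalar entries
            sums[key] += x
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


# Hypothetical per-batch results; real key names depend on the model configuration.
results = [{"total": 1.0, "reconstruction": 0.5}, {"total": 3.0, "reconstruction": 1.5}]
print(average_losses(results))  # {'total': 2.0, 'reconstruction': 1.0}
```

Each key is divided by its own count, so a loss that appears only in some batches (for example, because a selection changed) is averaged over the batches that actually reported it.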