sakura.model_controllers.extractor_controller.ExtractorController.train

ExtractorController.train(batch, backward_reconstruction_loss=True, backward_main_latent_regularization=True, backward_pheno_loss=True, selected_pheno: dict | None = None, backward_signature_loss=True, selected_signature: dict | None = None, suppress_backward=False, detach=False, detach_from='', save_raw_loss=False)

Train the model using the specified batch of data.

This function performs a training step by computing the losses associated with reconstruction, main latent regularization, phenotype, and signature components. It then performs backpropagation to update the model weights based on the computed total loss.

Parameters:
  • batch (dict) – A dictionary containing one batch of training data, typically obtained from an rna_count dataset

  • backward_reconstruction_loss (bool, optional) – Whether to backpropagate and optimize the reconstruction loss, defaults to True

  • backward_main_latent_regularization (bool, optional) – Whether to backpropagate and optimize the regularization of the main latent space, defaults to True

  • backward_pheno_loss (bool, optional) – Whether to backpropagate and optimize the phenotype-related losses, defaults to True

  • selected_pheno (dict, optional) – Phenotypes to include in backpropagation. Either None (select all phenotypes, with all of their losses and regularizations) or a dictionary of the form {'pheno_name': {'loss': [list of loss names] or '*' (select all), 'regularization': [list of regularization keys] or '*' (select all) or None (no regularization)}}; defaults to None

  • backward_signature_loss (bool, optional) – Whether to backpropagate and optimize the signature-related losses, defaults to True

  • selected_signature (dict, optional) – Signatures to include in backpropagation. Either None (select all signatures, with all of their losses and regularizations) or a dictionary with the same structure as selected_pheno, keyed by signature name; defaults to None

  • suppress_backward (bool, optional) – Whether to skip the backward pass on the summed loss (useful when an external training agent takes over control of backpropagation), defaults to False

  • detach (bool, optional) – Whether to detach losses from the computation graph at the point given by detach_from, defaults to False

  • detach_from (str, optional) – Starting point in the model from which the losses are detached; used together with detach, defaults to ''

  • save_raw_loss (bool, optional) – Whether to record the raw, unweighted losses in addition to the weighted losses, defaults to False
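
The selection format accepted by selected_pheno and selected_signature can be illustrated with a plain dictionary. The sketch below is only an illustration under stated assumptions: the phenotype names (cell_type, batch), the loss name (cross_entropy), and the regularization key (uniform_shape) are hypothetical, and the helper check_selection is not part of sakura. It only checks that a dictionary follows the shape described above, and it assumes that every entry carries a 'loss' key.

```python
from typing import Optional


def _is_str_list(value) -> bool:
    """True if `value` is a list whose items are all strings."""
    return isinstance(value, list) and all(isinstance(v, str) for v in value)


def check_selection(selection: Optional[dict]) -> bool:
    """Hypothetical helper (not part of sakura): check the documented shape."""
    if selection is None:  # None selects everything
        return True
    if not isinstance(selection, dict):
        return False
    for name, spec in selection.items():
        if not isinstance(name, str) or not isinstance(spec, dict):
            return False
        loss = spec.get('loss')
        # 'loss' is either '*' (all losses) or a list of loss names
        if not (loss == '*' or _is_str_list(loss)):
            return False
        reg = spec.get('regularization')
        # 'regularization' is None, '*' (all), or a list of keys
        if not (reg is None or reg == '*' or _is_str_list(reg)):
            return False
    return True


# Hypothetical phenotype, loss, and regularization names, for illustration only.
selected_pheno = {
    'cell_type': {'loss': '*', 'regularization': None},
    'batch': {'loss': ['cross_entropy'], 'regularization': ['uniform_shape']},
}
```

Here cell_type would contribute all of its losses without regularization, while batch would contribute one named loss and one named regularization.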

Returns:

A dictionary containing the losses computed during this training tick, including total loss and individual component losses

Return type:

dict
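
As a rough usage sketch, the keyword arguments in the signature can be combined to restrict one training step to a single loss component. The example below builds the arguments for a step that backpropagates only the phenotype loss. The phenotype name cell_type is hypothetical, and the objects controller (a configured ExtractorController) and batch are assumed to exist already, so the call itself is shown only as a comment.

```python
# Keyword arguments for a step that backpropagates only the phenotype loss.
pheno_only_kwargs = dict(
    backward_reconstruction_loss=False,
    backward_main_latent_regularization=False,
    backward_pheno_loss=True,
    # 'cell_type' is a hypothetical phenotype name
    selected_pheno={'cell_type': {'loss': '*', 'regularization': None}},
    backward_signature_loss=False,
)

# With an existing controller and batch (both assumed), the call would be:
# losses = controller.train(batch, **pheno_only_kwargs)
```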