sakura.sakuraAE.sakuraAE.train

sakuraAE.train(split_id, train_main=True, train_pheno=True, selected_pheno=None, train_signature=True, selected_signature=None, epoch=50, batch_size=100, tick_controller_epoch=True, make_logs=True, log_prefix='train', log_loss_groups=['loss', 'regularization'], save_raw_loss=False, test_every_epoch=False, test_on_segment=False, test_segment=2000, tests=None, checkpoint_on_segment=False, checkpoint_segment=2000, checkpoint_prefix='', checkpoint_save_arch=False, resume=False, resume_dict=None, detach=False, detach_from='')

Train the model in batches for at least one epoch.

Parameters:
  • split_id (str) – Split id to be used in this train

  • train_main (bool, optional) – Whether to forward the main latent space part during training, defaults to True

  • train_pheno (bool, optional) – Whether to forward phenotype side task(s) during training, defaults to True

  • selected_pheno* – Phenotype id(s) to use for phenotype side tasks during training, defaults to None (see Note)

  • train_signature (bool, optional) – Whether to forward gene signature side tasks during training, defaults to True

  • selected_signature* – Signature id(s) to use for gene signature side tasks during training, analogous to selected_pheno, defaults to None (see Note)

  • epoch (int, optional) – Number of epochs to be trained in this round of training, defaults to 50

  • batch_size (int, optional) – Batch size to be used in this round of training, defaults to 100

  • tick_controller_epoch (bool, optional) – Whether to tick the controller epoch, defaults to True

  • make_logs (bool, optional) – Whether to log information, including losses, defaults to True

  • log_prefix (str, optional) – Prefix for training logs (prepended to loss item names in TensorBoard and to the filenames of latent embeddings), defaults to ‘train’

  • log_loss_groups (list[str], optional) – Loss group(s) to be logged, defaults to [‘loss’, ‘regularization’]

  • save_raw_loss (bool, optional) – Whether to record raw losses, defaults to False

  • test_every_epoch (bool, optional) – Whether to run testing/evaluation after each epoch finishes, defaults to False

  • test_on_segment (bool, optional) – Whether to run tests at a fixed tick interval (see test_segment), defaults to False

  • test_segment (int, optional) – Tick interval of test segment, defaults to 2000

  • tests (list[dict[str,Any]], optional) – A list of test configuration dictionaries, each of which should contain the keys ‘on_split’, ‘make_logs’, ‘dump_latent’ and ‘latent_prefix’, defaults to None

  • checkpoint_on_segment (bool, optional) – Should model be checkpointed after a certain tick interval, defaults to False

  • checkpoint_segment (int, optional) – Tick interval of checkpoint segment, defaults to 2000

  • checkpoint_prefix (str, optional) – Prefix of checkpoint files, defaults to an empty string

  • checkpoint_save_arch (bool, optional) – Whether to save the model architecture in checkpoints, defaults to False

  • resume (bool, optional) – Whether to resume from saved training session, defaults to False

  • resume_dict (dict[str, Any], optional) – Session state dictionary used for resuming previous training, defaults to None

  • detach (bool, optional) – Whether to detach losses from the computation graph at the point specified by detach_from, defaults to False

  • detach_from (str, optional) – Starting point in the model from which the loss should be detached, defaults to an empty string
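
The parameters above can be combined into a single call. The following is a minimal sketch that only assembles the keyword arguments, assuming hypothetical split ids (‘train’, ‘test’) and prefixes. Only the four keys of each test dictionary are documented here, so the value types shown (str and bool) are guesses from the key names. Constructing the sakuraAE instance is not covered on this page, so the final call is left as a comment.

```python
# Keyword parameters of sakuraAE.train, copied from the signature above.
DOCUMENTED_PARAMS = {
    "split_id", "train_main", "train_pheno", "selected_pheno",
    "train_signature", "selected_signature", "epoch", "batch_size",
    "tick_controller_epoch", "make_logs", "log_prefix", "log_loss_groups",
    "save_raw_loss", "test_every_epoch", "test_on_segment", "test_segment",
    "tests", "checkpoint_on_segment", "checkpoint_segment",
    "checkpoint_prefix", "checkpoint_save_arch", "resume", "resume_dict",
    "detach", "detach_from",
}

# One test configuration. The keys are the documented ones; the value
# types are assumptions inferred from the key names.
tests = [
    {
        "on_split": "test",         # hypothetical split id
        "make_logs": True,
        "dump_latent": False,
        "latent_prefix": "test_",   # hypothetical filename prefix
    },
]

train_kwargs = dict(
    split_id="train",               # hypothetical split id
    epoch=10,
    batch_size=64,
    test_on_segment=True,           # run the tests at a tick interval...
    test_segment=500,               # ...of 500 ticks
    tests=tests,
    checkpoint_on_segment=True,     # checkpoint at a tick interval...
    checkpoint_segment=1000,        # ...of 1000 ticks
    checkpoint_prefix="run1_",      # hypothetical checkpoint prefix
)

# Guard against typos: every keyword must appear in the documented signature.
unknown = set(train_kwargs) - DOCUMENTED_PARAMS
if unknown:
    raise TypeError(f"Unknown train() arguments: {sorted(unknown)}")

# With an already constructed sakuraAE instance named `model`:
# model.train(**train_kwargs)
```

Checking the keyword names before the call catches misspelled arguments early instead of partway through a long training run.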

Note

The phenotypes (signatures) named in selected_pheno (selected_signature) should already be configured and stored in self.selected_pheno (self.selected_signature). If the argument is None, self.selected_pheno (self.selected_signature) is used as the default, so all configured phenotypes (signatures) are trained. This feature is designed for complex training scenarios where the neural network (NN) is only partially forwarded.
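
The fallback described in this note can be illustrated with a small stand-alone sketch. It is not the library's code: resolve_selection is a hypothetical helper, the phenotype ids are invented, and the configured selection is assumed to be a list of ids, which this page does not confirm.

```python
from typing import List, Optional


def resolve_selection(selected: Optional[List[str]],
                      configured: List[str]) -> List[str]:
    """Return the ids to train on.

    Hypothetical helper mirroring the note: an explicit argument wins,
    and None falls back to the configured selection.
    """
    return configured if selected is None else selected


# Assumed stand-in for self.selected_pheno (ids are invented).
configured_pheno = ["cell_type", "batch"]

# None: every configured phenotype is trained.
all_pheno = resolve_selection(None, configured_pheno)

# Explicit list: only the named phenotype side task is forwarded.
partial_pheno = resolve_selection(["cell_type"], configured_pheno)
```

Under these assumptions, all_pheno holds both configured ids while partial_pheno holds only the one passed explicitly, which corresponds to the partially forwarded case the note mentions.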

Returns:

None