sakura.sakuraAE.sakuraAE.train_hybrid

sakuraAE.train_hybrid(split_configs: dict, ticks=50000, hybrid_mode='interleave', prog_loss_weight_mode='epoch_end', make_logs=True, log_prefix='', log_loss_groups=['loss', 'regularization'], save_raw_loss=False, perform_test=False, test_segment=2000, tests: dict | None = None, perform_checkpoint=False, checkpoint_segment=2000, checkpoint_prefix='', checkpoint_save_arch=False, loss_prog_on_test: dict | None = None, resume=False, resume_dict=None)

Train the model in hybrid mode, in which each module split is trained with its own configuration (data split, batch size, and which components to train), as specified in split_configs.
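With the default hybrid_mode=‘interleave’, module splits are trained in round-robin fashion. The sketch below shows what such a schedule looks like. It assumes that exactly one module split is trained per tick. The function interleave_schedule and the split names are hypothetical and are not part of sakura, and the library's actual scheduling may differ.

```python
from itertools import cycle, islice


def interleave_schedule(split_names: list[str], ticks: int) -> list[str]:
    """Hypothetical helper: assign one module split to each tick, round-robin.

    Assumes one split is trained per tick; not sakura's implementation.
    """
    # cycle() repeats the split names forever; islice() keeps the first `ticks`.
    return list(islice(cycle(split_names), ticks))


schedule = interleave_schedule(['split_a', 'split_b', 'split_c'], ticks=7)
print(schedule)
# ['split_a', 'split_b', 'split_c', 'split_a', 'split_b', 'split_c', 'split_a']
```

Under this assumption, with N module splits each split is trained on roughly ticks / N batches over the whole run.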

Parameters:
  • split_configs (dict[str, str or int]) – A dictionary of module split configurations used for training; the configuration for each module split should contain the keys ‘use_split’, ‘batch_size’, ‘train_main_latent’, ‘train_pheno’ and ‘train_signature’

  • ticks (int, optional) – The total number of training iterations, where each tick corresponds to training on one batch of data, defaults to 50000

  • hybrid_mode (Literal['interleave', 'pattern', 'sum'], optional) – Defines how the module splits are trained, defaults to ‘interleave’, in which the module splits are trained in round-robin fashion

  • prog_loss_weight_mode (Literal['on_test', 'epoch_end'], optional) – The mode for progressive loss weighting, defaults to ‘epoch_end’, in which loss weights progress at the end of each epoch

  • make_logs (bool, optional) – Whether to log training information, including losses, defaults to True

  • log_prefix (str, optional) – Prefix for training logs (this prefix is prepended to loss item names in TensorBoard and to the filenames of latent embeddings), defaults to ‘’

  • log_loss_groups (list[str], optional) – The loss groups to be logged, defaults to [‘loss’, ‘regularization’]

  • save_raw_loss (bool, optional) – Whether to record raw losses, defaults to False

  • perform_test (bool, optional) – Whether to perform testing during training every test_segment ticks, defaults to False

  • test_segment (int, optional) – Tick interval at which testing is performed, defaults to 2000

  • tests (list[dict[str, Any]], optional) – A list of test configuration dictionaries, each of which should contain the keys ‘on_split’, ‘make_logs’, ‘dump_latent’ and ‘latent_prefix’

  • perform_checkpoint (bool, optional) – Whether to checkpoint the model every checkpoint_segment ticks, defaults to False

  • checkpoint_segment (int, optional) – Tick interval between model checkpoints, defaults to 2000

  • checkpoint_prefix (str, optional) – Prefix of checkpoint files, defaults to ‘’

  • checkpoint_save_arch (bool, optional) – Whether to also save the model architecture when checkpointing, defaults to False

  • loss_prog_on_test (dict[str, Any], optional) – A dictionary specifying the progressive loss weights to use during testing when prog_loss_weight_mode is ‘on_test’; it should contain the keys ‘prog_main’, ‘train_pheno’, ‘selected_pheno’, ‘train_signature’ and ‘selected_signature’

  • resume (bool, optional) – Whether to resume from a saved training session, defaults to False

  • resume_dict (dict[str, Any], optional) – Session state dictionary used to resume a previous training session
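The dictionary-valued parameters above can be assembled as plain Python dictionaries. The sketch below builds split_configs, tests (as a list, following its parameter description) and loss_prog_on_test using the key lists given above, then checks that no required key is missing. The helper missing_keys, the module split names and all values are hypothetical placeholders. Consult sakura for the values each key actually accepts.

```python
from typing import Any

# Required keys, as listed in the parameter descriptions.
SPLIT_CONFIG_KEYS = {'use_split', 'batch_size', 'train_main_latent',
                     'train_pheno', 'train_signature'}
TEST_CONFIG_KEYS = {'on_split', 'make_logs', 'dump_latent', 'latent_prefix'}
LOSS_PROG_ON_TEST_KEYS = {'prog_main', 'train_pheno', 'selected_pheno',
                          'train_signature', 'selected_signature'}


def missing_keys(config: dict[str, Any], required: set[str]) -> set[str]:
    """Hypothetical helper: return the required keys absent from config."""
    return required - set(config)


# Hypothetical module split names; all values are placeholders.
split_configs = {
    'split_a': {'use_split': 'train', 'batch_size': 128,
                'train_main_latent': True, 'train_pheno': False,
                'train_signature': False},
    'split_b': {'use_split': 'train', 'batch_size': 64,
                'train_main_latent': False, 'train_pheno': True,
                'train_signature': True},
}

# One dictionary per test; values are placeholders.
tests = [
    {'on_split': 'test', 'make_logs': True, 'dump_latent': False,
     'latent_prefix': 'test_'},
]

# Keys only: None stands in for values whose expected types are not documented here.
loss_prog_on_test = dict.fromkeys(LOSS_PROG_ON_TEST_KEYS)

# Collect the missing keys of every configuration dictionary.
checks = {name: missing_keys(cfg, SPLIT_CONFIG_KEYS)
          for name, cfg in split_configs.items()}
checks.update({f'tests[{i}]': missing_keys(cfg, TEST_CONFIG_KEYS)
               for i, cfg in enumerate(tests)})
checks['loss_prog_on_test'] = missing_keys(loss_prog_on_test,
                                           LOSS_PROG_ON_TEST_KEYS)

incomplete = {name: keys for name, keys in checks.items() if keys}
print(incomplete)  # {} when every configuration has all required keys
```

Checking for missing keys before calling train_hybrid surfaces configuration mistakes immediately rather than partway through a long run.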

Note

When epoch-end loss progression is enabled (prog_loss_weight_mode=‘epoch_end’), loss weights progress only for the selected losses, and only when an epoch ends (i.e., when the ticks reach the end of an epoch).
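The tick-based schedule implied by test_segment, checkpoint_segment and epoch-end loss progression can be sketched as follows. This is a minimal illustration under stated assumptions: ticks are 1-indexed, an event fires when the tick is a multiple of its interval, and an epoch spans a fixed number of batches. The function events_at, the parameter batches_per_epoch and the event names are hypothetical and are not part of sakura.

```python
def events_at(tick: int, *, perform_test: bool, test_segment: int,
              perform_checkpoint: bool, checkpoint_segment: int,
              batches_per_epoch: int) -> list[str]:
    """Hypothetical helper: list the periodic events assumed to fire after a tick.

    Assumes 1-indexed ticks and that an event fires on multiples of its interval.
    """
    events = []
    # 'epoch_end' mode (assumed): selected loss weights progress at each epoch end.
    if batches_per_epoch and tick % batches_per_epoch == 0:
        events.append('progress_selected_loss_weights')
    if perform_test and tick % test_segment == 0:
        events.append('test')
    if perform_checkpoint and tick % checkpoint_segment == 0:
        events.append('checkpoint')
    return events


# Record which ticks trigger events over a short, illustrative run.
fired = {}
for tick in range(1, 6001):
    events = events_at(tick, perform_test=True, test_segment=2000,
                       perform_checkpoint=True, checkpoint_segment=3000,
                       batches_per_epoch=1500)
    if events:
        fired[tick] = events

print(sorted(fired))  # [1500, 2000, 3000, 4000, 4500, 6000]
```

Keeping the interval checks independent means several events can fire on the same tick (here tick 6000 triggers progression, testing and checkpointing).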

Returns:

None