sakura.sakuraAE.sakuraAE

class sakura.sakuraAE.sakuraAE(config_json_path, verbose=False, suppress_train=False, suppress_tensorboardX=False)

Bases: object

A comprehensive class for the SAKURA pipeline.

This class manages the overall workflow of SAKURA, including model initialization, training, testing, and model inference or external model merging, based on the configuration and argument settings.

Parameters:
  • config_json_path (str) – Path to the configuration JSON file, which contains all the necessary settings for the class

  • verbose (bool, optional) – Whether to enable verbose console logging, defaults to False

  • suppress_train (bool, optional) – Whether to suppress model training and only set up the dataset and model, defaults to False

  • suppress_tensorboardX (bool, optional) – Whether to prevent the Logger from initiating tensorboardX (to avoid flushing logs), defaults to False
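A minimal construction sketch, assuming SAKURA is installed and importable as `sakura`. It uses only the documented constructor arguments; the helper name `build_sakura_setup_only` and the pre-flight file check are hypothetical additions, not part of SAKURA. Passing `suppress_train=True` only sets up the dataset and model, and `suppress_tensorboardX=True` keeps the Logger from initiating tensorboardX.

```python
import os


def build_sakura_setup_only(config_json_path: str, verbose: bool = False):
    """Hypothetical helper: construct sakuraAE without running training."""
    # Fail early with a clear error if the configuration file is missing
    if not os.path.isfile(config_json_path):
        raise FileNotFoundError(f"Config JSON not found: {config_json_path}")

    # Import lazily so this helper can be defined without SAKURA installed
    from sakura.sakuraAE import sakuraAE

    # suppress_train=True: only set up the dataset and model
    # suppress_tensorboardX=True: the Logger does not initiate tensorboardX
    return sakuraAE(
        config_json_path,
        verbose=verbose,
        suppress_train=True,
        suppress_tensorboardX=True,
    )
```

For example, `build_sakura_setup_only("path/to/config.json", verbose=True)` would return an instance with the dataset and model set up but untrained; the path is a placeholder.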

Methods

execute_inference

Perform inference on the given tasks represented as a list of stories.

generate_splits

Generate dataset split masks for model training and testing.

insert_external_module

Insert an external module and merge it with the SAKURA model.

integrity_check

Perform integrity check on selected phenotypes/signatures against the input dataset.

load_checkpoint

Load a checkpoint file and resume the model's state, including parameters, random states, training progress, etc.

save_checkpoint

Save the current state of the model and training process as a checkpoint.

setup_dataset

Set up the dataset for the SAKURA model.

test

Test all latents of the model, with options to evaluate specific loss groups and dump selected latents.

train

Train the model in batches for at least one epoch.

train_hybrid

Train the model in hybrid mode, where model module splits are trained with flexibility.

train_hybrid_fastload

Perform hybrid-mode training using a multithreaded dataloader, where model module splits are trained with flexibility.

train_story

Train the model on the given sets of tasks represented as a list of storylines.
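`generate_splits` is documented only as producing split masks for model training and testing. As a generic illustration of what such masks are (not SAKURA's implementation; the function name `make_split_masks`, its signature, and the `test_fraction` parameter are hypothetical), the sketch below builds complementary boolean train/test masks over `n_samples` samples using a seeded shuffle, so the split is reproducible.

```python
import random
from typing import List, Tuple


def make_split_masks(
    n_samples: int, test_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[bool], List[bool]]:
    """Hypothetical sketch: complementary train/test boolean masks."""
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError("test_fraction must be in [0, 1]")
    # Local RNG, so the global random state is left untouched
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    # The first n_test shuffled indices form the test set
    n_test = int(round(n_samples * test_fraction))
    test_set = set(indices[:n_test])
    test_mask = [i in test_set for i in range(n_samples)]
    # Training mask is the exact complement of the test mask
    train_mask = [not flag for flag in test_mask]
    return train_mask, test_mask
```

For instance, `make_split_masks(10, test_fraction=0.3, seed=42)` yields 7 training and 3 test entries that never overlap. Boolean masks, rather than index lists, stay aligned with the original sample order.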
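`save_checkpoint` and `load_checkpoint` are described as capturing the model's parameters, random states, and training progress. The sketch below shows that general pattern with the standard library only; it is not SAKURA's checkpoint format, and `TrainState`, `save_state`, and `load_state` are hypothetical names. A deep-learning pipeline would typically also store framework-specific RNG and optimizer states.

```python
import pickle
import random
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class TrainState:
    """Hypothetical container for what a checkpoint needs to resume training."""
    params: Dict[str, Any] = field(default_factory=dict)
    epoch: int = 0
    rng_state: Any = None


def save_state(path: str, state: TrainState) -> None:
    # Record the global RNG state (mutates `state`) so a resumed run
    # continues the same random stream
    state.rng_state = random.getstate()
    with open(path, "wb") as fh:
        pickle.dump(state, fh)


def load_state(path: str) -> TrainState:
    with open(path, "rb") as fh:
        state = pickle.load(fh)
    # Restore the RNG state captured at save time
    random.setstate(state.rng_state)
    return state
```

Restoring the RNG state together with the parameters and epoch counter is what lets a resumed run draw the same random numbers it would have drawn without interruption.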