sakura.models.extractor.Extractor.forward

Extractor.forward(batch, forward_signature=True, selected_signature=None, forward_pheno=True, selected_pheno=None, forward_main_latent=True, forward_reconstruction=True, detach=False, detach_from='')

Run a forward pass through the extractor, with control over which computation branches are executed.

Orchestrates data flow through the assembled modular architecture, enabling selective activation of task branches and gradient flow control.

Parameters:
  • batch (torch.Tensor) – Gene expression tensor of shape (N, M), where N is the number of cells and M is the number of genes

  • forward_signature (bool, optional) – Whether to forward the signature supervision part, defaults to True

  • selected_signature (list[str], optional) – List of signatures to forward; None forwards all signatures, defaults to None

  • forward_pheno (bool, optional) – Whether to forward the phenotype supervision part, defaults to True

  • selected_pheno (list[str], optional) – List of phenotypes to forward; None forwards all phenotypes, defaults to None

  • forward_main_latent (bool, optional) – Whether to forward the main latent part, defaults to True

  • forward_reconstruction (bool, optional) – Whether to forward the decoder reconstruction part (see Note), defaults to True

  • detach (bool, optional) – Whether to block gradient flow partway through the network, at the component given by <detach_from>, defaults to False

  • detach_from (Literal['pre_encoder', 'encoder'] or str, optional) – Component at which the gradient is blocked when <detach> is True, defaults to ''

Note

<forward_reconstruction>: The decoder reconstruction part can only be forwarded when all latent dimensions are forwarded.
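The branch-selection rules in the parameter list and the reconstruction constraint above can be illustrated with a small pure-Python sketch. It is not Sakura's implementation: the helpers `resolve_selection` and `plan_forward` are hypothetical, reading "all latent dimensions" as the main latent plus every phenotype and signature branch is an assumption, and so is raising `ValueError` on a partial latent (the real method may skip the decoder or fail differently).

```python
from typing import Dict, List, Optional, Sequence


def resolve_selection(available: Sequence[str], enabled: bool,
                      selected: Optional[Sequence[str]]) -> List[str]:
    """Return the names of the branches that would run (hypothetical helper)."""
    if not enabled:
        return []
    if selected is None:
        # None means "forward every branch of this kind"
        return list(available)
    unknown = set(selected) - set(available)
    if unknown:
        raise KeyError(f"unknown branches: {sorted(unknown)}")
    # Keep the model's own branch order rather than the caller's order
    return [name for name in available if name in selected]


def plan_forward(signatures: Sequence[str], phenos: Sequence[str], *,
                 forward_signature: bool = True,
                 selected_signature: Optional[Sequence[str]] = None,
                 forward_pheno: bool = True,
                 selected_pheno: Optional[Sequence[str]] = None,
                 forward_main_latent: bool = True,
                 forward_reconstruction: bool = True) -> Dict[str, object]:
    """Decide which branches run for one call (hypothetical sketch)."""
    sig = resolve_selection(signatures, forward_signature, selected_signature)
    phe = resolve_selection(phenos, forward_pheno, selected_pheno)
    # Assumption: the full latent = main latent + every pheno/signature branch
    full_latent = (forward_main_latent
                   and len(sig) == len(signatures)
                   and len(phe) == len(phenos))
    if forward_reconstruction and not full_latent:
        raise ValueError("reconstruction requires all latent dimensions to be forwarded")
    return {"main_latent": forward_main_latent, "signature": sig,
            "pheno": phe, "reconstruction": forward_reconstruction}
```

Under these assumptions, forwarding only a subset of signatures (for example selected_signature=["s2"]) is consistent only when forward_reconstruction is set to False.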

<detach_from> options:
  • ‘pre_encoder’ (lat_pre is detached, so the pre-encoder is not trained);

  • ‘encoder’ (main_lat, pheno_lat and signature_lat are detached, so neither the pre-encoder nor the encoder is trained).
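The effect of each <detach_from> option on which components are trained can be summarized with a small pure-Python model. This is a sketch, not Sakura's code: the stage names and the `trainable_stages` helper are hypothetical, treating the network as a simple chain is a simplification, and rejecting unrecognized <detach_from> values is an assumption, since the signature accepts any str.

```python
from typing import List

# Hypothetical linear view of the extractor: each stage feeds the next.
# "downstream" stands in for the supervision heads and decoder that consume the latents.
PIPELINE = ("pre_encoder", "encoder", "downstream")


def trainable_stages(detach: bool, detach_from: str = "") -> List[str]:
    """Stages that still receive gradients after detaching (sketch)."""
    if not detach:
        return list(PIPELINE)
    # Detaching a stage's output cuts the graph there, so that stage and
    # every stage before it receive no gradient.
    # tuple.index raises ValueError for names outside PIPELINE.
    cut = PIPELINE.index(detach_from)
    return list(PIPELINE[cut + 1:])
```

In PyTorch terms, each cut presumably corresponds to calling Tensor.detach() on lat_pre (for 'pre_encoder') or on main_lat, pheno_lat and signature_lat (for 'encoder') before they are passed on.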

Computations for the gradient reversal layer and the gradient neutralization layer are not performed here; they are handled in model_controllers.extractor_controller.

Returns:

A dictionary of hierarchical outputs produced by the forward pass, keyed by output name.

Return type:

dict[str, torch.Tensor]