sakura.utils.size_estimator.SizeEstimator

class sakura.utils.size_estimator.SizeEstimator(model, input_size=(1, 1, 32, 32), bits=32)

Bases: object

Estimates the memory consumption of PyTorch models.

Calculates:
  • Parameter storage requirements

  • Activation memory for forward pass

  • Gradient memory for backward pass

  • Input tensor memory
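The four components above can be combined with plain arithmetic. The sketch below is not the SizeEstimator implementation. It is a minimal, library-free illustration that assumes every value is stored at the same bit width and that the backward pass stores one gradient value per activation. The helper name estimate_total_bits and the example layer shapes are hypothetical.

```python
from math import prod

def estimate_total_bits(param_shapes, output_shapes, input_size, bits=32):
    """Sum the four memory components in bits (illustrative accounting only)."""
    # Parameter storage: one value per parameter element.
    param_bits = sum(prod(shape) for shape in param_shapes) * bits
    # Activation memory: one value per element of each layer output.
    activation_bits = sum(prod(shape) for shape in output_shapes) * bits
    # Assumption: gradients occupy the same space as the activations.
    gradient_bits = activation_bits
    # Input tensor memory.
    input_bits = prod(input_size) * bits
    return param_bits + activation_bits + gradient_bits + input_bits

# Hypothetical one-layer model: Conv2d(1, 8, kernel_size=3, padding=1)
param_shapes = [(8, 1, 3, 3), (8,)]   # weight and bias
output_shapes = [(1, 8, 32, 32)]      # single activation map
total_bits = estimate_total_bits(param_shapes, output_shapes, (1, 1, 32, 32))
total_megabytes = total_bits / 8 / 1024 ** 2
print(total_bits, round(total_megabytes, 4))
```

For this toy layer the activations and their gradients dominate the total, while the 80 parameters contribute very little.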

Parameters:
  • model (torch.nn.Module) – Model to analyze

  • input_size (tuple, optional) – Input dimensions (batch, channels, height, width), defaults to (1, 1, 32, 32)

  • bits (int, optional) – Bit precision for memory calculations, defaults to 32
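To show how input_size and bits interact, the following sketch computes the storage for an input tensor alone. It assumes the cost is simply the element count times the bit width. The helper input_tensor_bits is hypothetical and is not part of the library.

```python
from math import prod

def input_tensor_bits(input_size=(1, 1, 32, 32), bits=32):
    """Bits needed to hold one input tensor of the given shape (sketch)."""
    return prod(input_size) * bits

default_bits = input_tensor_bits()               # documented defaults
half_precision_bits = input_tensor_bits(bits=16)  # same shape at 16 bits
# Hypothetical larger batch: 64 RGB images of 224 x 224 pixels.
batched_bits = input_tensor_bits(input_size=(64, 3, 224, 224))
batched_mebibytes = batched_bits / 8 / 1024 ** 2
print(default_bits, half_precision_bits, batched_mebibytes)
```

Halving bits halves the result, and the batch dimension scales it linearly, so both parameters are worth setting to match the intended training configuration.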

Methods

  • calc_forward_backward_bits – Calculate bits needed for activation storage during forward/backward passes.

  • calc_input_bits – Calculate bits required for input tensor storage.

  • calc_param_bits – Calculate total bits required for parameter storage.

  • estimate_size – Calculate total memory requirements.

  • get_output_sizes – Determine output dimensions for each layer by running a sample input through the model.

  • get_parameter_sizes – Collect dimensions of all parameters in the model.