
Freeze a module during training

  • module: instance of torch.nn.Module
  • return: no return


Un-freeze a module for training

  • module: instance of torch.nn.Module
  • return: no return
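
The source does not show the signatures of the freeze and un-freeze helpers. The sketch below assumes that "freezing" means turning off `requires_grad` on every parameter of the module. The names `freeze_module` and `unfreeze_module` are hypothetical.

```python
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Hypothetical freeze helper: stop gradient updates for all parameters of `module`."""
    for p in module.parameters():
        p.requires_grad_(False)


def unfreeze_module(module: nn.Module) -> None:
    """Hypothetical un-freeze helper: re-enable gradient updates for all parameters."""
    for p in module.parameters():
        p.requires_grad_(True)


if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    freeze_module(net[0])  # freeze only the first Linear layer
    print([p.requires_grad for p in net.parameters()])  # [False, False, True, True]
```

Turning off `requires_grad` only stops gradient updates. Layers such as `nn.BatchNorm2d` still update their running statistics in train mode unless the frozen module is also switched to `eval()`.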


Retrieve only the trainable parameters, e.g. to pass to an optimizer

get_trainable_parameters(module, with_name=False)
  • module: instance of torch.nn.Module
  • with_name: if True, yield (name, tensor) pairs; otherwise yield only the tensors
  • return: generator of trainable parameters
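
Here is a minimal sketch of `get_trainable_parameters`. It assumes that "trainable" means `requires_grad=True`, and it is not necessarily the library's own implementation.

```python
from typing import Iterator, Tuple, Union

import torch
import torch.nn as nn


def get_trainable_parameters(
    module: nn.Module, with_name: bool = False
) -> Iterator[Union[nn.Parameter, Tuple[str, nn.Parameter]]]:
    """Yield only the parameters of `module` that have requires_grad=True."""
    for name, param in module.named_parameters():
        if param.requires_grad:
            yield (name, param) if with_name else param


if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    net[0].weight.requires_grad_(False)  # freeze a single tensor by hand
    # The generator can be passed straight to an optimizer.
    opt = torch.optim.SGD(get_trainable_parameters(net), lr=0.1)
    print([n for n, _ in get_trainable_parameters(net, with_name=True)])
    # ['0.bias', '1.weight', '1.bias']
```

Because the result is a generator, it is exhausted after one pass. Call the function again, or wrap it in `list(...)`, if the parameters are needed more than once.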


Set a tensor's value from a numpy array

set_value(t, v)
  • t: tensor
  • v: numpy array
  • return: no return
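
Here is a minimal sketch of `set_value`. It assumes the array is copied into the existing tensor in place, so the tensor keeps its dtype, device and identity. This matters for parameters already registered with an optimizer.

```python
import numpy as np
import torch


def set_value(t: torch.Tensor, v: np.ndarray) -> None:
    """Copy the contents of numpy array `v` into tensor `t` in place."""
    # no_grad allows in-place writes to leaf tensors that require grad (e.g. nn.Parameter).
    with torch.no_grad():
        # copy_ casts to t's dtype and moves data to t's device; shapes must be broadcastable.
        t.copy_(torch.from_numpy(v))
```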


Retrieve the device of a tensor or module

  • x: tensor or instance of nn.Module
  • return: torch.device
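
The source does not give this helper's name. The sketch below uses a hypothetical `get_device` and assumes that all of a module's tensors live on one device. It also assumes that a module with no parameters or buffers reports the CPU.

```python
import itertools
from typing import Union

import torch
import torch.nn as nn


def get_device(x: Union[torch.Tensor, nn.Module]) -> torch.device:
    """Hypothetical name: return the device a tensor or module lives on."""
    if isinstance(x, torch.Tensor):
        return x.device
    if isinstance(x, nn.Module):
        # Take the first parameter or buffer; assumes the module sits on a single device.
        for t in itertools.chain(x.parameters(), x.buffers()):
            return t.device
        # Assumption: a module with no tensors (e.g. nn.ReLU) is reported as CPU.
        return torch.device("cpu")
    raise TypeError(f"expected a Tensor or nn.Module, got {type(x).__name__}")
```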


Run a function safely against CUDA out-of-memory (OOM) errors; any other exception is re-raised

torch_safe_run(fn, inputs)
  • fn: function to run
  • inputs: dict of keyword arguments passed to fn
  • return: tuple (status, result), where status = 0 if no exception occurred and status = 1 if a CUDA OOM exception occurred; result is the return value of fn(**inputs)
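
Here is a minimal sketch of `torch_safe_run`. It detects OOM by looking for "out of memory" in the message of a `RuntimeError`. It also assumes that `result` is `None` when OOM occurs, since the source does not say what is returned in that case.

```python
from typing import Any, Callable, Dict, Tuple


def torch_safe_run(fn: Callable[..., Any], inputs: Dict[str, Any]) -> Tuple[int, Any]:
    """Return (0, fn(**inputs)), or (1, None) if fn raises a CUDA out-of-memory error."""
    try:
        return 0, fn(**inputs)
    except RuntimeError as e:
        # Recent PyTorch raises torch.cuda.OutOfMemoryError, a RuntimeError subclass;
        # older versions raise a plain RuntimeError. Both mention "out of memory".
        if "out of memory" in str(e):
            return 1, None  # assumption: no result is available after OOM
        raise  # any other RuntimeError is re-raised unchanged
```

A common pattern on `status == 1` is to call `torch.cuda.empty_cache()` and retry with a smaller batch. Exceptions that are not `RuntimeError`s propagate untouched.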


Pickle with gzip compression enabled.

.dump(data, filename, compresslevel=9, protocol=4)

Dump data and save to file.

  • data: data to be dumped to file
  • filename: file path
  • compresslevel: gzip compression level, default = 9.
  • protocol: protocol version of pickle, default = 4.

Load dumped data from file

  • filename: file to be loaded
  • return: data unpickled
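
Here is a minimal sketch of the file-based pair. It assumes the file is a gzip stream that wraps a single pickle. The load signature is not shown in the source, so `load(filename)` is a guess.

```python
import gzip
import pickle
from typing import Any


def dump(data: Any, filename: str, compresslevel: int = 9, protocol: int = 4) -> None:
    """Pickle `data` and write it to `filename` through a gzip stream."""
    with gzip.open(filename, "wb", compresslevel=compresslevel) as f:
        pickle.dump(data, f, protocol=protocol)


def load(filename: str) -> Any:
    """Hypothetical signature: decompress `filename` and unpickle its contents."""
    with gzip.open(filename, "rb") as f:
        return pickle.load(f)
```

Because the output is standard gzip, such a file can also be inspected with ordinary gzip tools. As with any pickle, only load files from trusted sources.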

.dumps(data, compresslevel=9, protocol=4)

Dump data into bytes

  • return: data pickled & compressed into bytes

Load dumped data from bytes

  • return: data unpickled
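
The in-memory variants can be sketched the same way with `gzip.compress` and `gzip.decompress`. The signature `loads(b)` is hypothetical because the source does not show it.

```python
import gzip
import pickle
from typing import Any


def dumps(data: Any, compresslevel: int = 9, protocol: int = 4) -> bytes:
    """Pickle `data` and gzip-compress the result into bytes."""
    return gzip.compress(pickle.dumps(data, protocol=protocol), compresslevel=compresslevel)


def loads(b: bytes) -> Any:
    """Hypothetical signature: decompress `b` and unpickle it."""
    return pickle.loads(gzip.decompress(b))
```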


Print with verbose-level filtering

class verbose_print(level=0, prefix=None)
  • level: predefined verbose level. An instance of verbose_print behaves like Python's built-in print() but accepts an additional keyword argument l (default = 0). When l < level, the output is suppressed, so only content with verbose level >= level is actually printed.
  • prefix: if given, each print will be preceded by this fixed prefix.


vprint = verbose_print(level=2, prefix='LMExp')
vprint('this line will be actually printed', l=3)
vprint('this line will NOT be printed by verbose level filtering', l=0)
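
The class might be implemented roughly as follows. This is a minimal sketch that assumes the prefix is emitted as the first positional argument to print(), which means print's default separator (a space) appears between the prefix and the message.

```python
from typing import Any, Optional


class verbose_print:
    """print() replacement that drops messages whose level l is below self.level."""

    def __init__(self, level: int = 0, prefix: Optional[str] = None) -> None:
        self.level = level
        self.prefix = prefix

    def __call__(self, *args: Any, l: int = 0, **kwargs: Any) -> None:
        if l < self.level:
            return  # suppressed by verbose-level filtering
        if self.prefix is not None:
            args = (self.prefix,) + args  # assumption: prefix separated by print's sep
        print(*args, **kwargs)  # other kwargs (sep, end, file, flush) pass through
```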