Haiku Merge Params

Merge pre-trained parameters into a new Haiku model.

2020

import haiku as hk

def merge_pretrained_params(
    new_params: hk.Params, pre_params: hk.Params
) -> hk.Params:
    """Merge pre-trained ``pre_params`` into freshly initialised ``new_params``.

    Every parameter in ``pre_params`` whose (module name, parameter name) pair
    also exists in ``new_params`` replaces the corresponding value from
    ``new_params``. Pre-trained parameters with no counterpart in
    ``new_params`` are dropped, so the result has exactly the same
    module/parameter keys as ``new_params``.

    For modules to line up, their names must match across the two models.
    Either choose the names intentionally, or call the reused modules before
    any new modules so that they end up with the same names.

    This function does not compare array shapes. A pre-trained value whose
    shape differs from the new one still replaces it.

    Args:
      new_params: Parameters of the new model (e.g. from its ``init``).
      pre_params: Parameters of a pre-trained model.

    Returns:
      Parameters with the keys of ``new_params``, where shared entries take
      their values from ``pre_params``.
    """
    # Keep only pre-trained parameters that the new model actually has.
    # The optimizer state is built from the structure of `new_params`.
    # Introducing extra entries (a whole module, or just one extra parameter
    # inside a shared module) would break the optimizer update during SGD.
    def is_used(module_name, name, value):
        del value  # Only the keys decide whether a parameter is kept.
        return module_name in new_params and name in new_params[module_name]

    used_pre_params = hk.data_structures.filter(is_used, pre_params)
    # In `merge`, later structures take precedence. The pre-trained values
    # therefore overwrite the untrained ones.
    return hk.data_structures.merge(new_params, used_pre_params)

View gist

Charles Lovering © 2026