import haiku as hk
def merge_pretrained_params(
    new_params: hk.Params, pre_params: hk.Params
) -> hk.Params:
    """Merge pre-trained parameters ``pre_params`` into new parameters ``new_params``.

    Entries are matched by (module name, parameter name). For reused modules to
    match, they must end up with the same names in both models: either
    (a) the names were chosen intentionally, or (b) the reused modules are
    called before any new modules, so that they receive the same names.

    Args:
        new_params: Parameters of the new model. The optimizer expects exactly
            this structure.
        pre_params: Parameters of the pre-trained model.

    Returns:
        A structure with the same (module, parameter) entries as ``new_params``.
        Every entry that also exists in ``pre_params`` holds the pre-trained
        value. Pre-trained entries absent from ``new_params`` are dropped.
    """

    # Keep only pre-trained entries whose (module, parameter) pair already exists
    # in new_params. The optimizer expects the structure of new_params, so adding
    # new values to the flat mapping (whether an unknown module, or an unknown
    # parameter inside a shared module) would cause errors during SGD.
    # Filtering on the module name alone let the second case through.
    def _is_used(module_name, name, value):
        del value  # Only the location of the parameter matters, not its value.
        return module_name in new_params and name in new_params[module_name]

    used_pre_params = hk.data_structures.filter(_is_used, pre_params)
    # Replace the (untrained) parameters in new_params with the pre-trained ones:
    # later arguments to merge take precedence for duplicate entries.
    return hk.data_structures.merge(new_params, used_pre_params)


# Haiku Merge Params
# Merge pre-trained parameters into a new Haiku model.
# 2020