# I wrote down some of the weight names to help me think. The Haiku
# weights are in a nested structure and are named only after their
# neural-network module type, so matching them will require reviewing
# more than their names — perhaps their order of construction and use
# in Google's source compared to the Hugging Face source.
def haiku2torch(haiku_params):
    """Convert a Haiku parameter mapping into a PyTorch-style state dict.

    Only the mappings that are confirmed so far are applied; the remaining
    Hugging Face target keys are listed below as TODOs until their Haiku-side
    source names are verified (by comparing order of construction and use in
    Google's source against the Hugging Face source).

    Args:
        haiku_params: mapping of Haiku parameter names to weight arrays.
            A shallow copy is made first, so the caller's mapping is not
            mutated by the ``pop`` calls below.

    Returns:
        dict mapping Hugging Face Perceiver state-dict keys to the
        corresponding weight arrays (partial — see TODOs).
    """
    haiku_params = {**haiku_params}  # shallow copy so pop() doesn't mutate caller's dict
    state_dict = {}

    # Confirmed one-to-one mappings.
    state_dict['perceiver.input_preprocessor.embeddings.weight'] = (
        haiku_params.pop('embed'))
    state_dict['perceiver.input_preprocessor.position_embeddings.weight'] = (
        haiku_params.pop('trainable_position_encoding'))

    # TODO(review): Haiku source unconfirmed — candidate is
    # haiku_params['perceiver_encoder/~/cross_attention/attention/linear']['w'],
    # but the nested Haiku names are ambiguous. Target keys still to fill in:
    #   perceiver.encoder.cross_attention.attention.self.layernorm1.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.layernorm2.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.query.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.key.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.self.value.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.output.dense.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.layernorm.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.mlp.dense1.{weight,bias}
    #   perceiver.encoder.cross_attention.attention.mlp.dense2.{weight,bias}

    # TODO(review): Haiku source for the latent array also unconfirmed:
    #   perceiver.embeddings.latents

    return state_dict