def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` that keeps only its first encoder layers.

    The layers are read from ``model.bert.encoder.layer``, so the full
    BERT-style wrapper model must be passed in (not the bare encoder).
    The input model is left untouched, and the returned copy owns its own
    layer modules, so the two models do not share layer objects.

    Args:
        model: Object exposing ``bert.encoder.layer`` as an indexable
            sequence of layer modules.
        num_layers_to_keep: Number of leading layers to keep, as an ``int``.
            For backward compatibility a sized collection is also accepted,
            in which case its length is used as the count.

    Returns:
        A deep copy of ``model`` whose ``bert.encoder.layer`` is a new
        ``nn.ModuleList`` holding the first ``num_layers_to_keep`` layers.

    Raises:
        ValueError: If the requested count is negative.
        IndexError: If more layers are requested than the model has.

    Note:
        Only ``bert.encoder.layer`` is replaced; any configuration object
        attached to the model is copied as-is and not adjusted here.
    """
    if isinstance(num_layers_to_keep, int):
        keep_count = num_layers_to_keep
    else:
        # Older call sites may pass a collection; its length is the count.
        keep_count = len(num_layers_to_keep)

    if keep_count < 0:
        raise ValueError(
            f"num_layers_to_keep must be non-negative, got {keep_count}"
        )

    available = len(model.bert.encoder.layer)
    if keep_count > available:
        raise IndexError(
            f"cannot keep {keep_count} layers; model only has {available}"
        )

    # Copy first, then truncate the copy's own layers. Building the new list
    # from the original model's layers would make the returned model share
    # layer modules (and their parameters) with the input model.
    pruned_model = copy.deepcopy(model)
    copied_layers = pruned_model.bert.encoder.layer
    pruned_model.bert.encoder.layer = nn.ModuleList(
        [copied_layers[index] for index in range(keep_count)]
    )
    return pruned_model