将预训练模型中的bert部分取出来加载上去
base_model = BaseModel(config)
base_model_dict = base_model.state_dict()
加载训练好的模型
pre_state_dict = torch.load(Bert_path)
new_state_dict = {k: v for k, v in pre_state_dict.items() if k in base_model_dict}
base_model_dict.update(new_state_dict)
base_model.load_state_dict(base_model_dict)
class BaseModel(nn.Module):
def __init__(self, config):
super(BaseModel, self).__init__()
self.bert_model = BertModel.from_pretrained(config.bert_uncased_path, output_hidden_states=True,
output_attentions=True)
for p in self.bert_model.parameters():
p.requires_grad = False # 预训练模型加载进来后全部设置为不更新参数,然后再后面加层
def forward(self, input_ids, attention_mask):
last_hidden_state = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
return last_hidden_state
新的方法
- 可以把训练好的模型参数加载给新的模型。新的模型和老模型有部分相同。
model_orig = Bert_Classify(class_num)
model_orig.load_state_dict(torch.load(classifier_name))
model = NewModel(one_t, class_num)
model.load_state_dict(model_orig.state_dict(), strict=False)