背景
最近在疯狂搭模型,突然有人来问我要我模型的中间层输出的embedding,于是,我就研究了一下怎么获取模型的中间层输出。
代码实现
1
2
3
4
5
6
| # 载入模型
full_model = get_model()
full_model.load_weights("xxxx")
# 查看模型各层
full_model.layers
|
[....]
1
2
3
4
5
6
7
| # 抽取中间层输出,组建新模型
model = Model(full_model.input, full_model.layers[3].output)
# 构建数据,获取中间层输出
data, label = next(data_iterator)
bert_output = model.predict([[data[0]], [data[1]]])
bert_output.shape
|
(1, 768)