背景

最近在疯狂搭模型,突然有人来问我要我模型的中间层输出的embedding,于是,我就研究了一下怎么获取模型的中间层输出。

代码实现

1
2
3
4
5
6
# 载入模型
full_model = get_model()
full_model.load_weights("xxxx")

# 查看模型各层
full_model.layers
[....]
1
2
3
4
5
6
7
# 抽取中间层输出,组建新模型
model = Model(full_model.input, full_model.layers[3].output)

# 构建数据,获取中间层输出
data, label = next(data_iterator)
bert_output = model.predict([[data[0]], [data[1]]])
bert_output.shape
(1, 768)