cookielau


Model

Keras

Get the output of a middle layer

Unlike eager execution in TF2 or the dynamic graph in PyTorch, a Keras functional model is a static graph of layers that is fully defined before it runs. To get the output of a middle layer, we therefore build a new model on top of that graph. This is easier than in PyTorch: the new model is a subgraph of the original Keras model that reuses its layers and weights, and the original model is not modified.

> [layer.name for layer in tf_model.layers]
# ['Input-Token',
#  'Input-Segment',
#  'Embedding-Token',
#  'Embedding-Segment',
#  .......
#  'Transformer-11-FeedForward-Dropout',
#  'Transformer-11-FeedForward-Add',
#  'Transformer-11-FeedForward-Norm',
#  ***************************************************
#  ** layer above: the output before the last layer **
#  ***************************************************
#  'dense']

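> # sub-model that reuses tf_model's layers and weights, ending at the layer before 'dense'
> layer_model = keras.models.Model(inputs=tf_model.input, outputs=tf_model.get_layer('Transformer-11-FeedForward-Norm').output)
> layer_output = layer_model(x_input)  # x_input: one array per Input layer, i.e. [token ids, segment ids]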
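A minimal, self-contained sketch of the same idea, assuming TensorFlow 2.x with `tf.keras`. The toy model and its layer names (`features`, `hidden`, `hidden-norm`) are hypothetical stand-ins for `tf_model`; passing a list to `outputs` returns several intermediate layers at once.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy stand-in for tf_model: features -> hidden -> hidden-norm -> dense
inputs = keras.Input(shape=(8,), name="features")
hidden = keras.layers.Dense(4, activation="relu", name="hidden")(inputs)
norm = keras.layers.LayerNormalization(name="hidden-norm")(hidden)
outputs = keras.layers.Dense(3, activation="softmax", name="dense")(norm)
model = keras.Model(inputs=inputs, outputs=outputs, name="toy_model")

# Sub-model ending at the layer just before the final 'dense' layer.
# It reuses model's layer objects (and therefore its weights); model itself is unchanged.
layer_model = keras.Model(inputs=model.input,
                          outputs=model.get_layer("hidden-norm").output)

# Several intermediate outputs at once: pass a list of tensors as `outputs`.
multi_model = keras.Model(
    inputs=model.input,
    outputs=[model.get_layer(name).output for name in ("hidden", "hidden-norm")],
)

x = np.random.rand(2, 8).astype("float32")
layer_output = np.asarray(layer_model(x))
hidden_out, norm_out = [np.asarray(t) for t in multi_model(x)]
print(layer_output.shape)  # (2, 4)
```

Because the sub-models share the original layer objects, any later training of `model` also changes what they return.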

Print the attributes of a Model

> type(tf_model)
# keras.engine.functional.Functional


> tf_model.get_config()
# {'name': 'model_1',
# 'layers': [{'class_name': 'InputLayer',
# 'config': {'batch_input_shape': (None, None),
#     'dtype': 'float32',
#     'sparse': False,
#     'ragged': False,
#     'name': 'Input-Token'},
# 'name': 'Input-Token',
# 'inbound_nodes': []},
# {'class_name': 'InputLayer',
# 'config': {'batch_input_shape': (None, None),
#     'dtype': 'float32',
#     'sparse': False,
#     'ragged': False,
#     'name': 'Input-Segment'},
# 'name': 'Input-Segment',
# 'inbound_nodes': []},
# ...
# {'class_name': 'Dense',
# 'config': {'name': 'dense',
#  'trainable': True,
#  'dtype': 'float32',
#  'units': 22,
#  'activation': 'softmax',
#  'use_bias': True,
#  'kernel_initializer': {'class_name': 'GlorotUniform',
#   'config': {'seed': None}},
#  'bias_initializer': {'class_name': 'Zeros', 'config': {}},
#  'kernel_regularizer': None,
#  'bias_regularizer': None,
#  'activity_regularizer': None,
#  'kernel_constraint': None,
#  'bias_constraint': None},
# 'name': 'dense',
# 'inbound_nodes': [[['Transformer-11-FeedForward-Norm', 0, 0, {}]]]}],
#  'input_layers': [['Input-Token', 0, 0], ['Input-Segment', 0, 0]],
#  'output_layers': [['dense', 0, 0]]}
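The dict returned by `get_config()` can be walked programmatically and fed back to `from_config` to rebuild the architecture. A sketch on a hypothetical one-layer model, assuming `tf.keras`; the rebuilt model gets freshly initialized weights, not the trained ones.

```python
from tensorflow import keras

# Hypothetical one-layer functional model
inputs = keras.Input(shape=(8,), name="features")
outputs = keras.layers.Dense(3, activation="softmax", name="dense")(inputs)
model = keras.Model(inputs=inputs, outputs=outputs, name="toy_model")

config = model.get_config()
# Each entry of config["layers"] records the layer class, its name, its own config and its inbound nodes.
for layer_cfg in config["layers"]:
    print(layer_cfg["class_name"], layer_cfg["name"], layer_cfg["config"].get("units"))

# Rebuild the same architecture from the dict; weights are freshly initialized, not copied.
clone = keras.Model.from_config(config)
```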


> tf_model.summary()
# Model: "model_1"
# __________________________________________________________________________________________________
# Layer (type)                   Output Shape         Param #     Connected to                     
# ==================================================================================================
# Input-Token (InputLayer)       [(None, None)]       0           []                               
# Input-Segment (InputLayer)     [(None, None)]       0           []                               
# Embedding-Token (Embedding)    (None, None, 768)    384885504   ['Input-Token[0][0]']            
# Embedding-Segment (Embedding)  (None, None, 768)    1536        ['Input-Segment[0][0]']     
# ................
# dense (Dense)                  (None, None, 22)     16918       ['Transformer-11-FeedForward-Norm[0][0]']      
# ==================================================================================================
# Total params: 470,353,174
# Trainable params: 470,353,174
# Non-trainable params: 0
# __________________________________________________________________________________________________     
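The `Param #` column can be checked by hand: a `Dense` layer has `input_dim * units + units` parameters, so 768 * 22 + 22 = 16918 for `dense`, and an `Embedding` has `input_dim * output_dim`, so 1536 = 2 * 768 for `Embedding-Segment` (assuming two segment ids). A small sketch that rebuilds just these two layers with `tf.keras` and compares against `count_params()`:

```python
from tensorflow import keras

hidden_size, num_classes = 768, 22

# Dense: one weight per (input feature, unit) pair plus one bias per unit.
dense = keras.layers.Dense(num_classes)
dense.build((None, None, hidden_size))
print(dense.count_params())  # 16918 == 768 * 22 + 22

# Embedding: one hidden_size vector per vocabulary entry.
# Assumption: 2 segment ids, consistent with the 1536 shown in the summary.
segment_embedding = keras.layers.Embedding(input_dim=2, output_dim=hidden_size)
segment_embedding.build((None, None))
print(segment_embedding.count_params())  # 1536 == 2 * 768
```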


> tf_model.layers
# [<keras.engine.input_layer.InputLayer at 0x7ff158495370>,
#  <keras.engine.input_layer.InputLayer at 0x7fef2a285fd0>,
#  <bert4keras.layers.Embedding at 0x7fef2a2b35e0>,
#  <bert4keras.layers.Embedding at 0x7fef2a2b3df0>,
#  <keras.layers.merge.Add at 0x7fef2a2b3820>,
#  <bert4keras.layers.PositionEmbedding at 0x7fef289d2be0>,
#  <bert4keras.layers.LayerNormalization at 0x7fe
#  ................
#  <keras.layers.core.dense.Dense at 0x7ff158495520>]


> [layer.name for layer in tf_model.layers]
# ['Input-Token',
#  'Input-Segment',
#  'Embedding-Token',
#  'Embedding-Segment',
#  'Embedding-Token-Segment',
#  'Embedding-Position',
#  'Embedding-Norm',
#  ................
#  'dense']


> type(tf_model.get_layer('dense'))
# keras.layers.core.dense.Dense
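Layers can be looked up by name or by position. A sketch on a hypothetical toy model (the layer names `features`, `hidden` and `dense` are assumptions), assuming `tf.keras`:

```python
from tensorflow import keras

# Hypothetical toy model with named layers
inputs = keras.Input(shape=(8,), name="features")
hidden = keras.layers.Dense(4, name="hidden")(inputs)
outputs = keras.layers.Dense(3, name="dense")(hidden)
model = keras.Model(inputs=inputs, outputs=outputs)

# Name -> layer lookup table, equivalent to repeated get_layer(name) calls
by_name = {layer.name: layer for layer in model.layers}
names = [layer.name for layer in model.layers]

for layer in model.layers:
    print(layer.name, type(layer).__name__)

last = model.get_layer(index=len(model.layers) - 1)  # lookup by position instead of name
print(last is model.get_layer("dense"))               # True
```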


> tf_model.get_layer('dense').weights
# [<tf.Variable 'dense/kernel:0' shape=(768, 22) dtype=float32, numpy=
#  array([[-0.00010732, -0.0852981 , -0.04779567, ..., -0.04220317,
#          ...,
#          [-0.05151889, -0.01136582,  0.03139671, ...,  0.00703724,
#  <tf.Variable 'dense/bias:0' shape=(22,) dtype=float32, numpy=
#  array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
#          0., 0., 0., 0., 0.], dtype=float32)>]
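`weights` holds the variables themselves; `get_weights()` returns NumPy copies, and `set_weights()` writes new values back. A sketch on a standalone `Dense(22)` layer built for 768-dimensional inputs to mirror the shapes above (assuming `tf.keras`); the all-zero bias comes from the default `Zeros` bias initializer, matching the printout.

```python
import numpy as np
from tensorflow import keras

dense = keras.layers.Dense(22, name="dense")
dense.build((None, 768))

kernel, bias = dense.get_weights()  # NumPy copies of the kernel and bias variables
print(kernel.shape, bias.shape)     # (768, 22) (22,)
print(np.count_nonzero(bias))       # 0

# Changing the NumPy copy does not affect the layer until set_weights is called.
new_bias = bias + 1.0
dense.set_weights([kernel, new_bias])
print(dense.bias.numpy()[:3])       # [1. 1. 1.]
```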


> tf_model.get_layer('dense').variables
# (Returns the same variable objects as .weights here; only the two returned lists are distinct objects, so their id() differs)
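A quick check of the `weights` / `variables` relationship, assuming `tf.keras` and a freshly built `Dense` layer: both properties hold the same variable objects, while the two returned lists are separate objects.

```python
from tensorflow import keras

dense = keras.layers.Dense(22)
dense.build((None, 768))

w, v = dense.weights, dense.variables
print(len(w), len(v))                     # 2 2
print(all(a is b for a, b in zip(w, v)))  # True: same kernel and bias objects
print(w is v)                             # False: two distinct list objects
```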

Last updated 3 years ago
