keras使用预训练的模型
本文最后更新于 1637 天前,其中的信息可能已经有所发展或是发生改变。
# 多输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

Kera 的应用模块 Application 提供了带有预训练权重的 Keras 模型,这些模型可以用来进行预测、特征提取和 finetune。模型存储路径是 ~/.keras/models/
英文官方文档( 建议 ):VGG19模型
中文手册路径:Keras Application应用


1. VGG19

0. 导入库与预定义

将可视化的函数定义在此处

from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

%matplotlib inline


# 用于可视化的函数预定义
def visualize_pred(features, top=10):
    ''' 对分类的结果可视化成条形图
    Args:
        features: list,预测的结果,如[('n02123394', 'Persian_cat', 0.24321416),]
        top: int,可视化的数量,默认画 10 组
    '''
    # 从预测结果里提取出自信度以及标签
    pred = decode_predictions(features, top=top)
    y_lis = []  # 存储各类别对应的预测自信度
    label = []  # 存储各类别对应的标签
    for result in pred[0]:
        y_lis.append(result[2])
        label.append(result[1])

    # 生成横纵坐标
    x = np.arange(len(y_lis))
    y = np.array(y_lis)  # list 转 array

    # 生成颜色
    map_vir = cm.get_cmap(name='coolwarm')  # colormap
    norm = plt.Normalize(y.min(), y.max())  # normalize y
    color = map_vir(norm(y))  # trans normalized y to RGB

    # 画条形图
    ax = plt.bar(x, y, color=color)

    # 添加图中标注与修改横坐标标注
    for x0, y0, co in zip(x, y, color):
        plt.text(x0, y0 + 0.001, '%0.3f' %
                 y0, ha='center', va='bottom', color=co)
    plt.xticks(x, label)
    plt.xticks(rotation=75)

1.1 加载模型

如果~/.keras/models/下没有权值文件 h5,那就会自动下载。

model_vgg19 = VGG19(weights='imagenet', include_top=True)

1.2 加载图片数据

img_path = 'E:\ssd_keras-master\examples\cat1.jpg'

img = image.load_img(img_path, target_size=(224, 224))
img
x = image.img_to_array(img)
# image.img_to_array  # 输入必须是 PIL.Image 类型
x.shape

x = np.expand_dims(x, axis=0)  # 增加一列,shape: (n, h, w, c)
x = preprocess_input(x)  # 归一化

png

(224, 224, 3)

1.3 预测

features = model_vgg19.predict(x)

# 从包含 Imagenet 所有标签的 json中解析获得预测的标签
decode_predictions(features)
[[('n02123394', 'Persian_cat', 0.24321416),
  ('n02124075', 'Egyptian_cat', 0.062243998),
  ('n03793489', 'mouse', 0.056839462),
  ('n02112137', 'chow', 0.043545388),
  ('n03642806', 'laptop', 0.042695798)]]

1.4 预测结果可视化

可视化的例子又见plt绘图颜色渐变以及colormap

visualize_pred(features)

png

2. InceptionV3

过程同上。此处将代码整理到一起。

# load img
img_path = 'E:\ssd_keras-master\examples\cat1.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)

# preprocess
x = np.expand_dims(x, axis=0)  # 增加一列,shape: (n, h, w, c)
x = preprocess_input(x)  # 归一化

# predict
base_model_icv3 = InceptionV3(weights='imagenet', include_top=True)
features = base_model_icv3.predict(x)

# visualizing
visualize_pred(features)

png

官方文档里有关于准确度的表格。InceptionV3 的准确略高于 VGG19

评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇