Written by Ryusei

【VGG16 転移学習】で呪術廻戦のキャラクターの画像分類をやってみた

programming

今回は、今流行りの呪術廻戦で機械学習やってみました。
画像検索で拾った画像を約90%の精度で見分けることができました。
画像分類の例として、アウトプットさせていただきます。

VGG16を転移学習させたモデルで予測した画像の一覧

ちなみにキャラクターは、

Kerasで実装(データセット準備)

今回は、Google Colaboratoryを使って誰でも簡単に実装できるようにしています。

スクリプトを実装する前にデータセットを用意します。
GitHubからデータセットはダウンロードできます。こちら→GitHub

今回は、jyujyutu_VGGというフォルダの中にdisplay , test , train , validationを作っています。


画像は train = 250枚、validation = 50枚、test = 50枚の内訳で入っています。

displayは冒頭のような画像を表示させるときに使う用で、中身はtestと同じです。
displayは用意しなくても大丈夫です。

このデータセットは、Googleの画像検索で自作しました。

ライブラリのインポートとモデル構築

jyujyutu_VGG.zipをダウンロードしたら、Google Driveにアップロードしておきましょう。

# Googleドライブをマウント
from google.colab import drive
drive.mount('/content/drive')

.zipのままGoogle Driveにアップロードしてください。

# Google ColaboratoryでZipファイルを解凍
from zipfile import ZipFile
file_name = '/content/drive/My Drive/jyujyutu_VGG.zip'

with ZipFile(file_name, 'r') as zip:
zip.extractall()

zipファイルのままアップロードして、ここで解凍した方が時間短縮になります。

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import CSVLogger

必要なライブラリをimportします。

n_categories=5
batch_size=32
train_dir='/content/jyujyutu_VGG/train'
validation_dir='/content/jyujyutu_VGG/validation'
file_name='vgg16_jyujyutu_file'

base_model=VGG16(weights='imagenet',include_top=False,
                 input_tensor=Input(shape=(224,224,3)))

今回は、5人のキャラクターを分類するので、n_categories = 5とします。

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation='relu')(x)
prediction=Dense(n_categories,activation='softmax')(x)
model=Model(inputs=base_model.input,outputs=prediction)
#fix weights before VGG16 14layers
for layer in base_model.layers[:15]:
    layer.trainable=False

model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

今回の実装では、15層以降のみを学習させています。
最終的な出力は1次元(COVID-19かどうか確率)なので、新たな全結合層を追加しています。

転移学習、VGG16について詳しくはVGG16を転移学習させて「まどか☆マギカ」のキャラを見分けるを見てください。

#save model
json_string=model.to_json()
open(file_name+'.json','w').write(json_string)

画像の前処理をして学習

train_datagen=ImageDataGenerator(
    rescale=1.0/255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

validation_datagen=ImageDataGenerator(rescale=1.0/255)

train_generator=train_datagen.flow_from_directory(
    train_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

validation_generator=validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

hist=model.fit_generator(train_generator,
                         epochs=100,
                         verbose=1,
                         validation_data=validation_generator,
                         callbacks=[CSVLogger(file_name+'.csv')])

#save weights
model.save(file_name+'.h5')

ImageDataGeneratorは画像を整形したり、水増しするのに便利です。読み込むフォルダを与えれば、自動的にそのフォルダの名前をラベルにしてくれます。

ImageDataGeneratorは、デフォルトの機能だけでなく、関数として定義してあげれば他のData Augmentationも出来るのでめっちゃ便利です!

結果

テスト画像でモデルの評価を行います。

from keras.models import model_from_json
import matplotlib.pyplot as plt
import numpy as np
import os,random
from keras.preprocessing.image import img_to_array, load_img
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD

batch_size=32
file_name='vgg16_jyujyutu_file'
test_dir='jyujyutu_VGG/test'
display_dir='jyujyutu_VGG/display'
label=['gojo','hushiguro','itadori','nobara','todo']

#load model and weights
json_string=open(file_name+'.json').read()
model=model_from_json(json_string)
model.load_weights(file_name+'.h5')

model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

#data generate
test_datagen=ImageDataGenerator(rescale=1.0/255)

test_generator=test_datagen.flow_from_directory(
    test_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

#evaluate model
score=model.evaluate_generator(test_generator)
print('\n test loss:',score[0])
print('\n test_acc:',score[1])

validation_accuracyは最終的に約90%になりました。
test accuracyが約96%になりました。

訓練画像250枚(各キャラ50枚ずつ)で、Google ColaboratoryのGPUを使えばかなり短時間で学習できました。

今後の検討

今回は、VGG16を転移学習させて「まどか☆マギカ」のキャラを見分けるの記事を参考に、簡単に好きなデータセットで実装できることを試し、さらに今の流行りにも乗ってみました。

今後、Mixup や Random Erasing , Cutmixなども試して、精度が向上するかどうかやってみたいですね!

今回は、以上になります。こんな感じでアウトプットもしていきます!