ImageAI: Custom Prediction Model Training
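ImageAI provides 4 different algorithms and model types for custom prediction model training, and you can train a custom model in a few simple steps. The 4 algorithms available for custom training are SqueezeNet, ResNet, InceptionV3 and DenseNet. You can load any one of them into the imageai.Prediction.Custom.ModelTraining class, which lets you train your own model on any set of images of objects or people. The training process generates a JSON file that maps the object types in your image dataset, and saves many models along the way. You can then pick the model with the highest accuracy and use it, together with the generated JSON file, to perform custom image prediction.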
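Picking the model with the highest accuracy can be done by hand, but it is easy to script. This is a minimal sketch, not part of ImageAI: it assumes each saved checkpoint carries its accuracy in the file name, in the form model_ex-<epoch>_acc-<accuracy>.h5 (the naming used by ImageAI 2.0.x as far as we know; check your own models folder), and the helper names pick_best_model and pick_best_model_in_dir are hypothetical.

```python
import re
from pathlib import Path

# Assumed checkpoint naming, e.g. "model_ex-012_acc-0.875000.h5".
# Adjust the regex if the files in your models folder are named differently.
CHECKPOINT_PATTERN = re.compile(r"model_ex-(\d+)_acc-([0-9.]+)\.h5$")


def pick_best_model(model_names):
    """Return (name, epoch, accuracy) for the highest-accuracy checkpoint, or None."""
    best = None
    for name in model_names:
        match = CHECKPOINT_PATTERN.search(name)
        if match is None:
            continue  # skip files that do not follow the assumed naming scheme
        epoch, accuracy = int(match.group(1)), float(match.group(2))
        if best is None or accuracy > best[2]:
            best = (name, epoch, accuracy)
    return best


def pick_best_model_in_dir(models_dir):
    """Scan a directory (hypothetically "pets/models") for .h5 checkpoints."""
    names = [p.name for p in Path(models_dir).glob("*.h5")]
    return pick_best_model(names)


if __name__ == "__main__":
    sample = [
        "model_ex-001_acc-0.250000.h5",
        "model_ex-037_acc-0.812500.h5",
        "model_ex-052_acc-0.787500.h5",
    ]
    print(pick_best_model(sample))  # ('model_ex-037_acc-0.812500.h5', 37, 0.8125)
```

On ties the earliest checkpoint in the list wins, because only a strictly higher accuracy replaces the current best.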
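Because model training is a compute-intensive task, we strongly recommend running this experiment on a computer with an NVIDIA GPU and the GPU version of TensorFlow installed. Training on a CPU can take hours or even days, whereas on a computer with an NVIDIA GPU it may take only a few hours. You can also use Google Colab for this experiment, since it provides an NVIDIA K80 GPU.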
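Before starting a long run, it helps to confirm that TensorFlow actually sees the GPU. The sketch below is an illustration, not an ImageAI feature: tensorflow_sees_gpu is a hypothetical helper that uses tf.config.list_physical_devices where it exists (TensorFlow 2.1 and later), falls back to the deprecated tf.test.is_gpu_available on older releases, and returns False if TensorFlow is not installed.

```python
def tensorflow_sees_gpu():
    """Return True if the installed TensorFlow build reports at least one GPU."""
    try:
        import tensorflow as tf
    except ImportError:
        return False  # TensorFlow is not installed at all
    config = getattr(tf, "config", None)
    list_devices = getattr(config, "list_physical_devices", None)
    if list_devices is not None:
        # TensorFlow 2.1 and later
        return len(list_devices("GPU")) > 0
    # Older TensorFlow 1.x releases
    return bool(tf.test.is_gpu_available())


if __name__ == "__main__":
    print("GPU visible to TensorFlow:", tensorflow_sees_gpu())
```

If this prints False on a machine with an NVIDIA GPU, the usual cause is that the CPU-only TensorFlow package is installed.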
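To perform custom prediction model training, you need to prepare the images you will train on. Arrange them as follows:
- Create a dataset folder and give it a name (e.g. pets)
- Inside the dataset folder, create a subfolder named train
- Inside the dataset folder, create a subfolder named test
- In the train folder, create one folder for each object you want to train on, named after that object (e.g. dog, cat, squirrel, snake)
- In the test folder, create one folder for each object you want to train on, using the same names (e.g. dog, cat, squirrel, snake)
- Put the training images for each object in the matching subfolder of the train folder. These are the images the model learns from; to train an accurate model, we recommend collecting about 500 or more images per object.
- Put the test images for each object in the matching subfolder of the test folder; we recommend about 100 to 200 test images per object. These images are used to evaluate the model while it trains.
- Once you have completed these steps, the structure of your image dataset folder should look like this: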
pets >
    train >> dog >>> dog_train_images
          >> cat >>> cat_train_images
          >> squirrel >>> squirrel_train_images
          >> snake >>> snake_train_images
    test >> dog >>> dog_test_images
         >> cat >>> cat_test_images
         >> squirrel >>> squirrel_test_images
         >> snake >>> snake_test_images
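The layout above can also be created and checked with a short script. This is a minimal sketch using only the Python standard library; the helper names, the accepted image extensions and the default thresholds (500 training and 100 test images per class, taken from the recommendations above) are assumptions to adapt to your own data.

```python
from pathlib import Path

# Example class names from this guide; replace them with your own objects.
CLASSES = ("dog", "cat", "squirrel", "snake")
# Assumption: these are the image formats present in the dataset.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def create_dataset_skeleton(root, classes=CLASSES):
    """Create root/train/<class> and root/test/<class>; existing folders are kept."""
    for split in ("train", "test"):
        for name in classes:
            (Path(root) / split / name).mkdir(parents=True, exist_ok=True)


def count_images(root):
    """Return {split: {class_name: image_count}}; both split folders must exist."""
    counts = {}
    for split in ("train", "test"):
        split_dir = Path(root) / split
        counts[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return counts


def check_dataset(root, min_train=500, min_test=100):
    """List problems: class folders that differ between splits, or too few images."""
    counts = count_images(root)
    problems = []
    if set(counts["train"]) != set(counts["test"]):
        problems.append("train and test contain different class folders")
    for split, minimum in (("train", min_train), ("test", min_test)):
        for name, n in counts[split].items():
            if n < minimum:
                problems.append(f"{split}/{name}: {n} images (recommended >= {minimum})")
    return problems
```

For example, create_dataset_skeleton("pets") builds the empty folders; after copying your images in, printing each entry of check_dataset("pets") lists any class that falls short of the recommended counts or that appears in only one of train and test.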
Your training code then looks like this:
from imageai.Prediction.Custom import ModelTraining
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")
model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
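That's it! With just 5 lines of code, you can train a custom model on your own dataset using any of the 4 supported deep learning algorithms. Now let's look at how the code above works: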
from imageai.Prediction.Custom import ModelTraining
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")
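In the code above, the first line imports ImageAI's ModelTraining class, the second line creates a new instance of the ModelTraining class, the third line sets the model type to ResNet, and the fourth line sets the path to the dataset we want to train on.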
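Since the algorithm is chosen by calling one setter method before training, switching between the four model types can be wrapped in a small helper. This is a hypothetical convenience function, not part of ImageAI; the setter names below (setModelTypeAsSqueezeNet, setModelTypeAsResNet, setModelTypeAsInceptionV3, setModelTypeAsDenseNet) are assumed from ImageAI 2.0.x and should be checked against your installed version.

```python
# Short names mapped to ModelTraining setter methods. The method names are
# assumed from ImageAI 2.0.x; confirm them against the version you installed.
MODEL_SETTERS = {
    "squeezenet": "setModelTypeAsSqueezeNet",
    "resnet": "setModelTypeAsResNet",
    "inceptionv3": "setModelTypeAsInceptionV3",
    "densenet": "setModelTypeAsDenseNet",
}


def set_model_type(trainer, name):
    """Call the setter matching `name` on `trainer` and return the method name used."""
    try:
        method_name = MODEL_SETTERS[name.lower()]
    except KeyError:
        raise ValueError(
            f"unknown model type {name!r}; choose one of {sorted(MODEL_SETTERS)}"
        ) from None
    getattr(trainer, method_name)()
    return method_name
```

With the ModelTraining instance from the code above, set_model_type(model_trainer, "DenseNet") would take the place of model_trainer.setModelTypeAsResNet(); the rest of the training code stays the same.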
model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
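In the code above, we start model training with the following parameters:
- num_objects: the number of object types (classes) in the image dataset; here 4, one per class folder
- num_experiments: the number of times the model is trained over the images, also known as epochs
- enhance_data (optional): whether to generate modified copies of the training images for better performance
- batch_size: the number of images the model processes at once. Because of memory limits, training proceeds batch by batch until every image in the training set has been processed.
- show_network_summary: whether to print the structure of the chosen network type in the console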
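The relationship between batch_size, num_experiments and the amount of work done can be estimated with simple arithmetic. This is a rough sketch that assumes each training step processes one batch and that one epoch covers every training image once; how ImageAI computes its steps internally is not stated here, and the helper names are hypothetical.

```python
import math


def steps_per_epoch(num_images, batch_size):
    """Batches needed to see every image once (the last batch may be smaller)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(num_images / batch_size)


def total_steps(num_images, batch_size, num_experiments):
    """Weight updates over the whole run: steps per epoch times number of epochs."""
    return steps_per_epoch(num_images, batch_size) * num_experiments


# Example: 4000 training images, batch_size=32, num_experiments=100
print(steps_per_epoch(4000, 32))   # 125
print(total_steps(4000, 32, 100))  # 12500
```

Halving batch_size doubles the number of steps per epoch but lowers the memory needed per step, which is why a smaller batch_size is the usual remedy when training runs out of GPU memory.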
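When training starts, you should see output similar to the following in the console. This is sample output: the exact layers and parameter counts depend on the chosen model type and the number of classes.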
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_2 (InputLayer) (None, 224, 224, 3) 0
____________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 112, 112, 64) 9472 input_2[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 112, 112, 64) 256 conv2d_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 batch_normalization_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 55, 55, 64) 256 conv2d_3[0][0]
____________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 batch_normalization_3[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 55, 55, 64) 256 conv2d_4[0][0]
____________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 batch_normalization_4[0][0]
____________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 55, 55, 256) 1024 conv2d_5[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 55, 55, 256) 1024 conv2d_2[0][0]
____________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 batch_normalization_5[0][0]
batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
____________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 55, 55, 64) 256 conv2d_6[0][0]
____________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 batch_normalization_6[0][0]
____________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 55, 55, 64) 256 conv2d_7[0][0]
____________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 batch_normalization_7[0][0]
____________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 55, 55, 256) 1024 conv2d_8[0][0]
____________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 batch_normalization_8[0][0]
activation_4[0][0]
____________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
____________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
____________________________________________________________________________________________________
batch_normalization_9 (BatchNorm (None, 55, 55, 64) 256 conv2d_9[0][0]
____________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 batch_normalization_9[0][0]
____________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
____________________________________________________________________________________________________
batch_normalization_10 (BatchNor (None, 55, 55, 64) 256 conv2d_10[0][0]
____________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 batch_normalization_10[0][0]
____________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
____________________________________________________________________________________________________
batch_normalization_11 (BatchNor (None, 55, 55, 256) 1024 conv2d_11[0][0]
____________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 batch_normalization_11[0][0]
activation_7[0][0]
____________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
____________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_13 (BatchNor (None, 28, 28, 128) 512 conv2d_13[0][0]
____________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 batch_normalization_13[0][0]
____________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
____________________________________________________________________________________________________
batch_normalization_14 (BatchNor (None, 28, 28, 128) 512 conv2d_14[0][0]
____________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 batch_normalization_14[0][0]
____________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
____________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_15 (BatchNor (None, 28, 28, 512) 2048 conv2d_15[0][0]
____________________________________________________________________________________________________
batch_normalization_12 (BatchNor (None, 28, 28, 512) 2048 conv2d_12[0][0]
____________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 batch_normalization_15[0][0]
batch_normalization_12[0][0]
____________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
____________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
____________________________________________________________________________________________________
batch_normalization_16 (BatchNor (None, 28, 28, 128) 512 conv2d_16[0][0]
____________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 batch_normalization_16[0][0]
____________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
____________________________________________________________________________________________________
batch_normalization_17 (BatchNor (None, 28, 28, 128) 512 conv2d_17[0][0]
____________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 batch_normalization_17[0][0]
____________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
____________________________________________________________________________________________________
batch_normalization_18 (BatchNor (None, 28, 28, 512) 2048 conv2d_18[0][0]
____________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 batch_normalization_18[0][0]
activation_13[0][0]
____________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
____________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
____________________________________________________________________________________________________
batch_normalization_19 (BatchNor (None, 28, 28, 128) 512 conv2d_19[0][0]
____________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 batch_normalization_19[0][0]
____________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
____________________________________________________________________________________________________
batch_normalization_20 (BatchNor (None, 28, 28, 128) 512 conv2d_20[0][0]
____________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 batch_normalization_20[0][0]
____________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
____________________________________________________________________________________________________
batch_normalization_21 (BatchNor (None, 28, 28, 512) 2048 conv2d_21[0][0]
____________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 batch_normalization_21[0][0]
activation_16[0][0]
____________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
____________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
____________________________________________________________________________________________________
batch_normalization_22 (BatchNor (None, 28, 28, 128) 512 conv2d_22[0][0]
____________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 batch_normalization_22[0][0]
____________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
____________________________________________________________________________________________________
batch_normalization_23 (BatchNor (None, 28, 28, 128) 512 conv2d_23[0][0]
____________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 batch_normalization_23[0][0]
____________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
____________________________________________________________________________________________________
batch_normalization_24 (BatchNor (None, 28, 28, 512) 2048 conv2d_24[0][0]
____________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 batch_normalization_24[0][0]
activation_19[0][0]
____________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
____________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_26 (BatchNor (None, 14, 14, 256) 1024 conv2d_26[0][0]
____________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 batch_normalization_26[0][0]
____________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
____________________________________________________________________________________________________
batch_normalization_27 (BatchNor (None, 14, 14, 256) 1024 conv2d_27[0][0]
____________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 batch_normalization_27[0][0]
____________________________________________________________________________________________________
conv2d_28 (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
____________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_28 (BatchNor (None, 14, 14, 1024) 4096 conv2d_28[0][0]
____________________________________________________________________________________________________
batch_normalization_25 (BatchNor (None, 14, 14, 1024) 4096 conv2d_25[0][0]
____________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 batch_normalization_28[0][0]
batch_normalization_25[0][0]
____________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
____________________________________________________________________________________________________
conv2d_29 (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
____________________________________________________________________________________________________
batch_normalization_29 (BatchNor (None, 14, 14, 256) 1024 conv2d_29[0][0]
____________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 batch_normalization_29[0][0]
____________________________________________________________________________________________________
conv2d_30 (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
____________________________________________________________________________________________________
batch_normalization_30 (BatchNor (None, 14, 14, 256) 1024 conv2d_30[0][0]
____________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 batch_normalization_30[0][0]
____________________________________________________________________________________________________
conv2d_31 (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
____________________________________________________________________________________________________
batch_normalization_31 (BatchNor (None, 14, 14, 1024) 4096 conv2d_31[0][0]
____________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 batch_normalization_31[0][0]
activation_25[0][0]
____________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
____________________________________________________________________________________________________
conv2d_32 (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
____________________________________________________________________________________________________
batch_normalization_32 (BatchNor (None, 14, 14, 256) 1024 conv2d_32[0][0]
____________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 batch_normalization_32[0][0]
____________________________________________________________________________________________________
conv2d_33 (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
____________________________________________________________________________________________________
batch_normalization_33 (BatchNor (None, 14, 14, 256) 1024 conv2d_33[0][0]
____________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 batch_normalization_33[0][0]
____________________________________________________________________________________________________
conv2d_34 (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
____________________________________________________________________________________________________
batch_normalization_34 (BatchNor (None, 14, 14, 1024) 4096 conv2d_34[0][0]
____________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 batch_normalization_34[0][0]
activation_28[0][0]
____________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
____________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
____________________________________________________________________________________________________
batch_normalization_35 (BatchNor (None, 14, 14, 256) 1024 conv2d_35[0][0]
____________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 batch_normalization_35[0][0]
____________________________________________________________________________________________________
conv2d_36 (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
____________________________________________________________________________________________________
batch_normalization_36 (BatchNor (None, 14, 14, 256) 1024 conv2d_36[0][0]
____________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 batch_normalization_36[0][0]
____________________________________________________________________________________________________
conv2d_37 (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
____________________________________________________________________________________________________
batch_normalization_37 (BatchNor (None, 14, 14, 1024) 4096 conv2d_37[0][0]
____________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 batch_normalization_37[0][0]
activation_31[0][0]
____________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
____________________________________________________________________________________________________
conv2d_38 (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
____________________________________________________________________________________________________
batch_normalization_38 (BatchNor (None, 14, 14, 256) 1024 conv2d_38[0][0]
____________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 batch_normalization_38[0][0]
____________________________________________________________________________________________________
conv2d_39 (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
____________________________________________________________________________________________________
batch_normalization_39 (BatchNor (None, 14, 14, 256) 1024 conv2d_39[0][0]
____________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 batch_normalization_39[0][0]
____________________________________________________________________________________________________
conv2d_40 (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
____________________________________________________________________________________________________
batch_normalization_40 (BatchNor (None, 14, 14, 1024) 4096 conv2d_40[0][0]
____________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 batch_normalization_40[0][0]
activation_34[0][0]
____________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
____________________________________________________________________________________________________
conv2d_41 (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
____________________________________________________________________________________________________
batch_normalization_41 (BatchNor (None, 14, 14, 256) 1024 conv2d_41[0][0]
____________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 batch_normalization_41[0][0]
____________________________________________________________________________________________________
conv2d_42 (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
____________________________________________________________________________________________________
batch_normalization_42 (BatchNor (None, 14, 14, 256) 1024 conv2d_42[0][0]
____________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 batch_normalization_42[0][0]
____________________________________________________________________________________________________
conv2d_43 (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
____________________________________________________________________________________________________
batch_normalization_43 (BatchNor (None, 14, 14, 1024) 4096 conv2d_43[0][0]
____________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 batch_normalization_43[0][0]
activation_37[0][0]
____________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
____________________________________________________________________________________________________
conv2d_45 (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_45 (BatchNor (None, 7, 7, 512) 2048 conv2d_45[0][0]
____________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 batch_normalization_45[0][0]
____________________________________________________________________________________________________
conv2d_46 (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
____________________________________________________________________________________________________
batch_normalization_46 (BatchNor (None, 7, 7, 512) 2048 conv2d_46[0][0]
____________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 batch_normalization_46[0][0]
____________________________________________________________________________________________________
conv2d_47 (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
____________________________________________________________________________________________________
conv2d_44 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_47 (BatchNor (None, 7, 7, 2048) 8192 conv2d_47[0][0]
____________________________________________________________________________________________________
batch_normalization_44 (BatchNor (None, 7, 7, 2048) 8192 conv2d_44[0][0]
____________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 batch_normalization_47[0][0]
batch_normalization_44[0][0]
____________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
____________________________________________________________________________________________________
conv2d_48 (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
____________________________________________________________________________________________________
batch_normalization_48 (BatchNor (None, 7, 7, 512) 2048 conv2d_48[0][0]
____________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 batch_normalization_48[0][0]
____________________________________________________________________________________________________
conv2d_49 (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
____________________________________________________________________________________________________
batch_normalization_49 (BatchNor (None, 7, 7, 512) 2048 conv2d_49[0][0]
____________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 batch_normalization_49[0][0]
____________________________________________________________________________________________________
conv2d_50 (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
____________________________________________________________________________________________________
batch_normalization_50 (BatchNor (None, 7, 7, 2048) 8192 conv2d_50[0][0]
____________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 batch_normalization_50[0][0]
activation_43[0][0]
____________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
____________________________________________________________________________________________________
conv2d_51 (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
____________________________________________________________________________________________________
batch_normalization_51 (BatchNor (None, 7, 7, 512) 2048 conv2d_51[0][0]
____________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 batch_normalization_51[0][0]
____________________________________________________________________________________________________
conv2d_52 (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
____________________________________________________________________________________________________
batch_normalization_52 (BatchNor (None, 7, 7, 512) 2048 conv2d_52[0][0]
____________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 batch_normalization_52[0][0]
____________________________________________________________________________________________________
conv2d_53 (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
____________________________________________________________________________________________________
batch_normalization_53 (BatchNor (None, 7, 7, 2048) 8192 conv2d_53[0][0]
____________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 batch_normalization_53[0][0]
activation_46[0][0]
____________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
____________________________________________________________________________________________________
global_avg_pooling (GlobalAverag (None, 2048) 0 activation_49[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 10) 20490 global_avg_pooling[0][0]
____________________________________________________________________________________________________
activation_50 (Activation) (None, 10) 0 dense_1[0][0]
====================================================================================================
Total params: 23,608,202
Trainable params: 23,555,082
Non-trainable params: 53,120
____________________________________________________________________________________________________
Using Enhanced Data Generation
Found 4000 images belonging to 4 classes.
Found 800 images belonging to 4 classes.
JSON Mapping for the model classes saved to C:\Users\User\PycharmProjects\ImageAITest\pets\json\model_class.json
Number of experiments (Epochs) : 100
Once training starts, you will see output like the following in the console:
Epoch 1/100
1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500
2/25 [=>............................] - ETA: 41s - loss: 2.3027 - acc: 0.1250
3/25 [==>...........................] - ETA: 37s - loss: 2.2961 - acc: 0.1667
4/25 [===>..........................] - ETA: 36s - loss: 2.2980 - acc: 0.1250
5/25 [=====>........................] - ETA: 33s - loss: 2.3178 - acc: 0.1000
6/25 [======>.......................] - ETA: 31s - loss: 2.3214 - acc: 0.0833
7/25 [=======>......................] - ETA: 30s - loss: 2.3202 - acc: 0.0714
8/25 [========>.....................] - ETA: 29s - loss: 2.3207 - acc: 0.0625
9/25 [=========>....................] - ETA: 27s - loss: 2.3191 - acc: 0.0556
10/25 [===========>..................] - ETA: 25s - loss: 2.3167 - acc: 0.0750
11/25 [============>.................] - ETA: 23s - loss: 2.3162 - acc: 0.0682
12/25 [=============>................] - ETA: 21s - loss: 2.3143 - acc: 0.0833
13/25 [==============>...............] - ETA: 20s - loss: 2.3135 - acc: 0.0769
14/25 [===============>..............] - ETA: 18s - loss: 2.3132 - acc: 0.0714
15/25 [=================>............] - ETA: 16s - loss: 2.3128 - acc: 0.0667
16/25 [==================>...........] - ETA: 15s - loss: 2.3121 - acc: 0.0781
17/25 [===================>..........] - ETA: 13s - loss: 2.3116 - acc: 0.0735
18/25 [====================>.........] - ETA: 12s - loss: 2.3114 - acc: 0.0694
19/25 [=====================>........] - ETA: 10s - loss: 2.3112 - acc: 0.0658
20/25 [=======================>......] - ETA: 8s - loss: 2.3109 - acc: 0.0625
21/25 [========================>.....] - ETA: 7s - loss: 2.3107 - acc: 0.0595
22/25 [=========================>....] - ETA: 5s - loss: 2.3104 - acc: 0.0568
23/25 [==========================>...] - ETA: 3s - loss: 2.3101 - acc: 0.0543
24/25 [===========================>..] - ETA: 1s - loss: 2.3097 - acc: 0.0625
Epoch 00000: saving model to C:\Users\Moses\Documents\Moses\W7\AI\Custom Datasets\IDENPROF\idenprof-small-test\idenprof\models\model_ex-000_acc-0.100000.h5
25/25 [==============================] - 51s - loss: 2.3095 - acc: 0.0600 - val_loss: 2.3026 - val_acc: 0.1000
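Let's explain the details shown above:
- Epoch 1/100 means the network is running the 1st of the 100 training epochs (experiments) you asked for.
- 1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500 shows how many batches have been trained so far in the current experiment (1 of 25), the estimated time remaining for this experiment (ETA), and the loss and accuracy on the training data so far.
- Epoch 00000: saving model to C:\Users\Moses\Documents\Moses\W7\AI\Custom Datasets\IDENPROF\idenprof-small-test\idenprof\models\model_ex-000_acc-0.100000.h5 refers to the model file saved after the current experiment. ex-000 identifies the experiment at this stage, while acc-0.100000 and val_acc: 0.1000 give the model's accuracy on the test images after this experiment (the maximum accuracy is 1.0). These values help you pick the best-performing model to use for custom image prediction.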
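The saved checkpoint filenames follow the pattern visible in the log above, model_ex-<experiment>_acc-<accuracy>.h5. Assuming every saved model uses that pattern, a short sketch like the one below could scan the models folder and pick the file with the highest accuracy. The helpers parse_checkpoint_name and best_checkpoint are hypothetical and are not part of ImageAI.

```python
import re
from typing import Iterable, Optional, Tuple

# Assumed filename pattern, copied from the log output: model_ex-<experiment>_acc-<accuracy>.h5
CHECKPOINT_PATTERN = re.compile(r"model_ex-(\d+)_acc-(\d+\.\d+)\.h5$")


def parse_checkpoint_name(filename: str) -> Optional[Tuple[int, float]]:
    """Return (experiment index, accuracy) for a checkpoint filename, or None if it does not match."""
    match = CHECKPOINT_PATTERN.search(filename)
    if match is None:
        return None
    return int(match.group(1)), float(match.group(2))


def best_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the filename with the highest accuracy; ties go to the earlier experiment."""
    best_name = None
    best_key = None
    for name in filenames:
        parsed = parse_checkpoint_name(name)
        if parsed is None:
            continue  # skip files that are not checkpoints, e.g. model_class.json
        experiment, accuracy = parsed
        # Higher accuracy wins; for equal accuracy, the smaller experiment index wins
        key = (accuracy, -experiment)
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name
```

For example, best_checkpoint(os.listdir(os.path.join("pets", "models"))) would return the name of the most accurate saved model for the pets dataset, assuming the models folder sits inside the dataset folder as it does in the log output. Preferring the earlier experiment on ties is only a design choice for this sketch.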
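Once you have finished training your custom model, you can use the CustomImagePrediction class to run image prediction with it. The full example is available at the following link: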
https://github.com/OlafenwaMoses/ImageAI/blob/master/imageai/Prediction/CUSTOMPREDICTION.md
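Training on the IdenProf dataset
Sample images from the IdenProf dataset, used to train a model that predicts professionals.
Below is sample code for training on the IdenProf dataset, which contains images of 10 types of uniformed professionals: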
import os
import shutil
from zipfile import ZipFile

import requests
from imageai.Prediction.Custom import ModelTraining

execution_path = os.getcwd()

TRAIN_ZIP_ONE = os.path.join(execution_path, "idenprof-train1.zip")
TRAIN_ZIP_TWO = os.path.join(execution_path, "idenprof-train2.zip")
TEST_ZIP = os.path.join(execution_path, "idenprof-test.zip")

DATASET_DIR = os.path.join(execution_path, "idenprof")
DATASET_TRAIN_DIR = os.path.join(DATASET_DIR, "train")
DATASET_TEST_DIR = os.path.join(DATASET_DIR, "test")

RELEASE_URL = "https://github.com/OlafenwaMoses/IdenProf/releases/download/v1.0/"


def download(filename, destination):
    # Stream the release asset straight to disk instead of loading it into memory
    print("Downloading " + filename)
    with requests.get(RELEASE_URL + filename, stream=True) as response:
        response.raise_for_status()
        with open(destination, "wb") as file:
            shutil.copyfileobj(response.raw, file)


def extract(zip_path, target_dir):
    print("Extracting " + os.path.basename(zip_path))
    with ZipFile(zip_path) as archive:
        archive.extractall(target_dir)


# Create idenprof/, idenprof/train and idenprof/test if they do not exist yet
os.makedirs(DATASET_TRAIN_DIR, exist_ok=True)
os.makedirs(DATASET_TEST_DIR, exist_ok=True)

# IdenProf has 10 classes: fetch the training set unless all 10 class folders are present
if len(os.listdir(DATASET_TRAIN_DIR)) < 10:
    if not os.path.exists(TRAIN_ZIP_ONE):
        download("idenprof-train1.zip", TRAIN_ZIP_ONE)
    if not os.path.exists(TRAIN_ZIP_TWO):
        download("idenprof-train2.zip", TRAIN_ZIP_TWO)
    extract(TRAIN_ZIP_ONE, DATASET_TRAIN_DIR)
    extract(TRAIN_ZIP_TWO, DATASET_TRAIN_DIR)

# Same check for the test set
if len(os.listdir(DATASET_TEST_DIR)) < 10:
    if not os.path.exists(TEST_ZIP):
        download("idenprof-test.zip", TEST_ZIP)
    extract(TEST_ZIP, DATASET_TEST_DIR)

# Train a ResNet model on the 10 IdenProf classes
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory(DATASET_DIR)
model_trainer.trainModel(num_objects=10, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
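Because num_objects has to match the number of class folders in the dataset, it can help to check the folder layout before starting a long training run. The sketch below is a hypothetical helper (check_dataset is not part of ImageAI) that assumes the train/<class> and test/<class> layout described earlier: it confirms that train and test contain the same class folders, that their count equals num_objects, and then counts the files in each training class folder.

```python
import os
from typing import Dict, List


def class_folders(split_dir: str) -> List[str]:
    """Return the sorted names of the class subfolders inside a train/ or test/ folder."""
    return sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )


def check_dataset(dataset_dir: str, num_objects: int) -> Dict[str, int]:
    """Validate the train/test layout and return the number of files per training class.

    Raises ValueError if train/ and test/ list different classes, or if the
    number of classes differs from num_objects.
    """
    train_classes = class_folders(os.path.join(dataset_dir, "train"))
    test_classes = class_folders(os.path.join(dataset_dir, "test"))
    if train_classes != test_classes:
        raise ValueError(f"train/ and test/ class folders differ: {train_classes} vs {test_classes}")
    if len(train_classes) != num_objects:
        raise ValueError(f"found {len(train_classes)} classes but num_objects={num_objects}")
    counts = {}
    for name in train_classes:
        class_dir = os.path.join(dataset_dir, "train", name)
        # Every regular file is assumed to be an image
        counts[name] = sum(
            1 for entry in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, entry))
        )
    return counts
```

In the IdenProf script, calling check_dataset(DATASET_DIR, num_objects=10) just before trainModel would stop early with a clear error if a download or extraction left the folders incomplete.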
Documentation
imageai.Prediction.Custom.ModelTraining class
In any Python program, you can train a custom model by instantiating the ModelTraining class and calling the functions below:
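- setModelTypeAsSqueezeNet: call this function once if you choose to train a SqueezeNet model.
- setModelTypeAsResNet: call this function once if you choose to train a ResNet model.
- setModelTypeAsInceptionV3: call this function once if you choose to train an InceptionV3 model.
- setModelTypeAsDenseNet: call this function once if you choose to train a DenseNet model.
- setDataDirectory: sets the path to the dataset folder used for training (the folder that contains the train and test subfolders).
- trainModel: starts training the model. It accepts the following parameters:
  - num_objects: the number of object types (classes) in the image dataset.
  - num_experiments: the number of times the network is trained over all the training images, also known as epochs.
  - enhance_data (optional): whether the network should generate modified copies of the training images for better performance.
  - batch_size (optional, default 32): the number of images the network processes together in one batch. Because of memory limits, the images are trained in batches until every batch of the training set has been processed. You can adjust this value to suit your training set; it is commonly set to 16, 32, 64 or 128.
  - initial_learning_rate (optional): controls how strongly the model's weights are adjusted during training. Unless you have a solid understanding of this concept, it is best to keep the default value.
  - show_network_summary (optional, default False): whether to print the structure of the network in the console.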
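To make batch_size and initial_learning_rate more concrete, here is a small sketch of the arithmetic behind them. It assumes that one experiment (epoch) passes over every training image exactly once, and it uses plain gradient descent purely as an illustration of what a learning rate does; ImageAI's actual optimizer and step count may differ. Both functions are hypothetical helpers, not part of ImageAI.

```python
import math
from typing import List


def batches_per_epoch(num_images: int, batch_size: int) -> int:
    """Batches needed to cover every image once; the last batch may be smaller."""
    if num_images <= 0 or batch_size <= 0:
        raise ValueError("num_images and batch_size must be positive")
    return math.ceil(num_images / batch_size)


def sgd_step(weights: List[float], gradients: List[float], learning_rate: float) -> List[float]:
    """One plain gradient-descent update: w <- w - learning_rate * gradient.

    A larger learning_rate moves the weights further on each step.
    """
    return [w - learning_rate * g for w, g in zip(weights, gradients)]
```

For example, with 4000 training images, batches_per_epoch(4000, 32) gives 125 batches per experiment, and raising batch_size to 64 gives 63 (the last batch holds only 32 images). Larger batches mean fewer steps per experiment but more memory per step, which is why batch_size is usually lowered when training runs out of GPU memory.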
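Submitting custom models
We welcome everyone who uses this library to submit their trained models and JSON files for inclusion in this repository. Please send your trained model and its JSON file using the contact details below.
Contact the developer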
Moses Olafenwa
Email: guymodscientist@gmail.com
Website: https://moses.specpal.science
Twitter: @OlafenwaMoses
Medium: @guymodscientist
Facebook: moses.olafenwa