ImageAI: Custom Prediction Model Training


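ImageAI provides 4 different algorithms and model types for custom prediction model training, and a custom prediction model can be trained in just a few simple steps. The 4 algorithms available for custom prediction model training are SqueezeNet, ResNet, InceptionV3 and DenseNet. You can load one of these algorithms into the imageai.Prediction.Custom.ModelTraining class, which lets you train your own model on any set of images of objects or people. The training process generates a JSON file that maps the object types in your image dataset, and it also produces many models. You can then pick the model with the highest accuracy and use it together with the generated JSON file to perform high-accuracy custom image prediction.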
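The JSON mapping is what turns a predicted class index back into a name. The sketch below illustrates the idea under stated assumptions: one class per subfolder of train, indices assigned in alphanumeric folder order (which is how Keras' flow_from_directory assigns labels), and string keys such as "0". It is not ImageAI's own code, and the helper names and the file name model_class.json are hypothetical.

```python
import json
import os


def build_class_mapping(train_dir):
    """Map class indices to class names, one class per subfolder of train_dir.

    Assumption: indices follow the alphanumeric order of the folder names.
    """
    class_names = sorted(
        name for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))  # ignore stray files
    )
    # JSON object keys must be strings, so the indices are stored as strings.
    return {str(index): name for index, name in enumerate(class_names)}


def save_class_mapping(mapping, json_path):
    """Write the mapping to disk (hypothetical file name, e.g. model_class.json)."""
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(mapping, f, indent=2)


def label_for(mapping, predicted_index):
    """Translate a predicted class index back into a human-readable name."""
    return mapping[str(predicted_index)]
```

For the pets dataset described below, build_class_mapping("pets/train") would return {"0": "cat", "1": "dog", "2": "snake", "3": "squirrel"}.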

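Because model training is a very compute-intensive task, we recommend performing this experiment on a computer with an NVIDIA GPU and the GPU version of TensorFlow installed. Training on a CPU can take hours or even days, whereas on a computer with an NVIDIA GPU it may take only a few hours. You can also use Google Colab for this experiment, since it provides an NVIDIA K80 GPU.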

To train a custom prediction model, you first need to prepare the images to use for training. Provide them as follows:

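  1. Create a dataset folder and give it a name (e.g. pets).
  2. Inside the dataset folder, create a subfolder named train.
  3. Inside the dataset folder, create a subfolder named test.
  4. Inside the train folder, create one folder for each object you want to train on, named after that object (e.g. dog, cat, squirrel, snake).
  5. Inside the test folder, create one folder for each object you want to train on, named after that object (e.g. dog, cat, squirrel, snake).
  6. Put the images of each object into the matching subfolder under train. These are the images used to train the model; to get a high-accuracy model, I recommend collecting about 500 or more images per object.
  7. Put the test images of each object into the matching subfolder under test. These images are used to test the model as it trains; to get a high-accuracy model, I recommend 100 to 200 test images per object.
  8. Once you have completed the steps above, the dataset folder structure should look like this: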
pets >
train >> dog >>> dog_train_images
      >> cat >>> cat_train_images
      >> squirrel >>> squirrel_train_images
      >> snake >>> snake_train_images
test  >> dog >>> dog_test_images
      >> cat >>> cat_test_images
      >> squirrel >>> squirrel_test_images
      >> snake >>> snake_test_images
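The folder layout and image counts above can be set up and checked with a short script. This is a minimal sketch, not part of ImageAI: the helper names are hypothetical, the set of image extensions is an assumption, and the thresholds simply mirror the recommendations in steps 6 and 7.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}  # assumption: common formats


def create_dataset_skeleton(root, class_names):
    """Create root/train/<class> and root/test/<class> for every class."""
    root = Path(root)
    for split in ("train", "test"):
        for name in class_names:
            (root / split / name).mkdir(parents=True, exist_ok=True)
    return root


def count_images(root):
    """Return {split: {class: image_count}} for the train and test folders."""
    root = Path(root)
    counts = {}
    for split in ("train", "test"):
        counts[split] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for class_dir in sorted((root / split).iterdir())
            if class_dir.is_dir()
        }
    return counts


def check_dataset(counts, min_train=500, min_test=100):
    """List problems: mismatched class folders or too few images per class."""
    problems = []
    train, test = counts["train"], counts["test"]
    if set(train) != set(test):
        problems.append(f"train/test classes differ: {sorted(set(train) ^ set(test))}")
    for name, n in train.items():
        if n < min_train:
            problems.append(f"train/{name}: only {n} images (recommended {min_train}+)")
    for name, n in test.items():
        if n < min_test:
            problems.append(f"test/{name}: only {n} images (recommended {min_test}-200)")
    return problems
```

A typical run is create_dataset_skeleton("pets", ["dog", "cat", "squirrel", "snake"]), then copying the images in, then printing check_dataset(count_images("pets")) before starting training.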

Your training code then looks like this:

from imageai.Prediction.Custom import ModelTraining
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")
model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)

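That's it! With just 5 lines of code you can train a custom model on your own dataset using any of the 4 supported deep learning algorithms. Now let's look at how the code above works: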

from imageai.Prediction.Custom import ModelTraining

model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory("pets")

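In the code above, the first line imports ImageAI's ModelTraining class, the second line creates a new instance of ModelTraining, the third line sets the model type to ResNet, and the fourth line sets the path of the dataset we want to train on.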
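The third line is the only place where the algorithm is chosen. Below is a minimal sketch of switching among the four algorithms by name. It assumes the ImageAI 2.x ModelTraining class exposes the setter names listed in MODEL_SETTERS; only setModelTypeAsResNet appears in this article, so the other three names are assumptions to check against your installed version. The helper configure_trainer is hypothetical, not part of ImageAI.

```python
# Assumption: one setter method per algorithm on ImageAI 2.x ModelTraining.
MODEL_SETTERS = {
    "squeezenet": "setModelTypeAsSqueezeNet",
    "resnet": "setModelTypeAsResNet",
    "inceptionv3": "setModelTypeAsInceptionV3",
    "densenet": "setModelTypeAsDenseNet",
}


def configure_trainer(trainer, algorithm, data_dir):
    """Select the network and the dataset folder on a ModelTraining instance."""
    try:
        setter_name = MODEL_SETTERS[algorithm.lower()]
    except KeyError:
        raise ValueError(
            f"unknown algorithm {algorithm!r}; choose one of {sorted(MODEL_SETTERS)}"
        ) from None
    getattr(trainer, setter_name)()  # e.g. trainer.setModelTypeAsResNet()
    trainer.setDataDirectory(data_dir)
    return trainer


# Usage (requires ImageAI to be installed):
# from imageai.Prediction.Custom import ModelTraining
# trainer = configure_trainer(ModelTraining(), "DenseNet", "pets")
# trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True,
#                    batch_size=32, show_network_summary=True)
```

Keeping the name-to-method table in one dictionary means an unsupported name fails immediately with a clear error, before any training time is spent.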

model_trainer.trainModel(num_objects=4, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)

In the code above, we start the model training with the following parameters:

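  • num_objects: the number of object types (classes) in the image dataset, i.e. one per folder under train.
  • num_experiments: the number of times the network is trained over the images, also known as epochs.
  • enhance_data (optional): whether to generate modified copies of the training images for better performance.
  • batch_size: the number of images the network processes at once. Because of memory limits, the images are processed in batches until all of them have been used in each experiment.
  • show_network_summary: whether to display the structure of the training network in the console.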
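To make batch_size and num_experiments concrete, here is a rough estimate of how many batches one training run processes. It assumes each epoch passes over every training image once and keeps a final partial batch; a library may instead round down and drop it. training_steps is a hypothetical helper, not an ImageAI function.

```python
import math


def training_steps(images_per_class, batch_size=32, num_experiments=100):
    """Estimate (batches per epoch, total batches) for one training run.

    Assumption: each epoch visits every training image once, in batches of
    batch_size, with a final partial batch when the count doesn't divide evenly.
    """
    total_images = sum(images_per_class.values())
    steps_per_epoch = math.ceil(total_images / batch_size)
    return steps_per_epoch, steps_per_epoch * num_experiments
```

With the recommended 500 training images for each of the 4 pets classes (2,000 images) and batch_size=32, that is 63 batches per epoch and 6,300 batches over 100 experiments. Halving batch_size roughly doubles the number of batches while lowering the memory needed per batch.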

When you start training, you should see something similar to the following in the console:

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_2 (InputLayer)             (None, 224, 224, 3)   0
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 112, 112, 64)  9472        input_2[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 112, 112, 64)  256         conv2d_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 112, 112, 64)  0           batch_normalization_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)   (None, 55, 55, 64)    0           activation_1[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (None, 55, 55, 64)    4160        max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 55, 55, 64)    256         conv2d_3[0][0]
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 55, 55, 64)    0           batch_normalization_3[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D)                (None, 55, 55, 64)    36928       activation_2[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 55, 55, 64)    256         conv2d_4[0][0]
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 55, 55, 64)    0           batch_normalization_4[0][0]
____________________________________________________________________________________________________
conv2d_5 (Conv2D)                (None, 55, 55, 256)   16640       activation_3[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (None, 55, 55, 256)   16640       max_pooling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 55, 55, 256)   1024        conv2d_5[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 55, 55, 256)   1024        conv2d_2[0][0]
____________________________________________________________________________________________________
add_1 (Add)                      (None, 55, 55, 256)   0           batch_normalization_5[0][0]
                                                                   batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 55, 55, 256)   0           add_1[0][0]
____________________________________________________________________________________________________
conv2d_6 (Conv2D)                (None, 55, 55, 64)    16448       activation_4[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 55, 55, 64)    256         conv2d_6[0][0]
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 55, 55, 64)    0           batch_normalization_6[0][0]
____________________________________________________________________________________________________
conv2d_7 (Conv2D)                (None, 55, 55, 64)    36928       activation_5[0][0]
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 55, 55, 64)    256         conv2d_7[0][0]
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 55, 55, 64)    0           batch_normalization_7[0][0]
____________________________________________________________________________________________________
conv2d_8 (Conv2D)                (None, 55, 55, 256)   16640       activation_6[0][0]
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 55, 55, 256)   1024        conv2d_8[0][0]
____________________________________________________________________________________________________
add_2 (Add)                      (None, 55, 55, 256)   0           batch_normalization_8[0][0]
                                                                   activation_4[0][0]
____________________________________________________________________________________________________
activation_7 (Activation)        (None, 55, 55, 256)   0           add_2[0][0]
____________________________________________________________________________________________________
conv2d_9 (Conv2D)                (None, 55, 55, 64)    16448       activation_7[0][0]
____________________________________________________________________________________________________
batch_normalization_9 (BatchNorm (None, 55, 55, 64)    256         conv2d_9[0][0]
____________________________________________________________________________________________________
activation_8 (Activation)        (None, 55, 55, 64)    0           batch_normalization_9[0][0]
____________________________________________________________________________________________________
conv2d_10 (Conv2D)               (None, 55, 55, 64)    36928       activation_8[0][0]
____________________________________________________________________________________________________
batch_normalization_10 (BatchNor (None, 55, 55, 64)    256         conv2d_10[0][0]
____________________________________________________________________________________________________
activation_9 (Activation)        (None, 55, 55, 64)    0           batch_normalization_10[0][0]
____________________________________________________________________________________________________
conv2d_11 (Conv2D)               (None, 55, 55, 256)   16640       activation_9[0][0]
____________________________________________________________________________________________________
batch_normalization_11 (BatchNor (None, 55, 55, 256)   1024        conv2d_11[0][0]
____________________________________________________________________________________________________
add_3 (Add)                      (None, 55, 55, 256)   0           batch_normalization_11[0][0]
                                                                   activation_7[0][0]
____________________________________________________________________________________________________
activation_10 (Activation)       (None, 55, 55, 256)   0           add_3[0][0]
____________________________________________________________________________________________________
conv2d_13 (Conv2D)               (None, 28, 28, 128)   32896       activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_13 (BatchNor (None, 28, 28, 128)   512         conv2d_13[0][0]
____________________________________________________________________________________________________
activation_11 (Activation)       (None, 28, 28, 128)   0           batch_normalization_13[0][0]
____________________________________________________________________________________________________
conv2d_14 (Conv2D)               (None, 28, 28, 128)   147584      activation_11[0][0]
____________________________________________________________________________________________________
batch_normalization_14 (BatchNor (None, 28, 28, 128)   512         conv2d_14[0][0]
____________________________________________________________________________________________________
activation_12 (Activation)       (None, 28, 28, 128)   0           batch_normalization_14[0][0]
____________________________________________________________________________________________________
conv2d_15 (Conv2D)               (None, 28, 28, 512)   66048       activation_12[0][0]
____________________________________________________________________________________________________
conv2d_12 (Conv2D)               (None, 28, 28, 512)   131584      activation_10[0][0]
____________________________________________________________________________________________________
batch_normalization_15 (BatchNor (None, 28, 28, 512)   2048        conv2d_15[0][0]
____________________________________________________________________________________________________
batch_normalization_12 (BatchNor (None, 28, 28, 512)   2048        conv2d_12[0][0]
____________________________________________________________________________________________________
add_4 (Add)                      (None, 28, 28, 512)   0           batch_normalization_15[0][0]
                                                                   batch_normalization_12[0][0]
____________________________________________________________________________________________________
activation_13 (Activation)       (None, 28, 28, 512)   0           add_4[0][0]
____________________________________________________________________________________________________
conv2d_16 (Conv2D)               (None, 28, 28, 128)   65664       activation_13[0][0]
____________________________________________________________________________________________________
batch_normalization_16 (BatchNor (None, 28, 28, 128)   512         conv2d_16[0][0]
____________________________________________________________________________________________________
activation_14 (Activation)       (None, 28, 28, 128)   0           batch_normalization_16[0][0]
____________________________________________________________________________________________________
conv2d_17 (Conv2D)               (None, 28, 28, 128)   147584      activation_14[0][0]
____________________________________________________________________________________________________
batch_normalization_17 (BatchNor (None, 28, 28, 128)   512         conv2d_17[0][0]
____________________________________________________________________________________________________
activation_15 (Activation)       (None, 28, 28, 128)   0           batch_normalization_17[0][0]
____________________________________________________________________________________________________
conv2d_18 (Conv2D)               (None, 28, 28, 512)   66048       activation_15[0][0]
____________________________________________________________________________________________________
batch_normalization_18 (BatchNor (None, 28, 28, 512)   2048        conv2d_18[0][0]
____________________________________________________________________________________________________
add_5 (Add)                      (None, 28, 28, 512)   0           batch_normalization_18[0][0]
                                                                   activation_13[0][0]
____________________________________________________________________________________________________
activation_16 (Activation)       (None, 28, 28, 512)   0           add_5[0][0]
____________________________________________________________________________________________________
conv2d_19 (Conv2D)               (None, 28, 28, 128)   65664       activation_16[0][0]
____________________________________________________________________________________________________
batch_normalization_19 (BatchNor (None, 28, 28, 128)   512         conv2d_19[0][0]
____________________________________________________________________________________________________
activation_17 (Activation)       (None, 28, 28, 128)   0           batch_normalization_19[0][0]
____________________________________________________________________________________________________
conv2d_20 (Conv2D)               (None, 28, 28, 128)   147584      activation_17[0][0]
____________________________________________________________________________________________________
batch_normalization_20 (BatchNor (None, 28, 28, 128)   512         conv2d_20[0][0]
____________________________________________________________________________________________________
activation_18 (Activation)       (None, 28, 28, 128)   0           batch_normalization_20[0][0]
____________________________________________________________________________________________________
conv2d_21 (Conv2D)               (None, 28, 28, 512)   66048       activation_18[0][0]
____________________________________________________________________________________________________
batch_normalization_21 (BatchNor (None, 28, 28, 512)   2048        conv2d_21[0][0]
____________________________________________________________________________________________________
add_6 (Add)                      (None, 28, 28, 512)   0           batch_normalization_21[0][0]
                                                                   activation_16[0][0]
____________________________________________________________________________________________________
activation_19 (Activation)       (None, 28, 28, 512)   0           add_6[0][0]
____________________________________________________________________________________________________
conv2d_22 (Conv2D)               (None, 28, 28, 128)   65664       activation_19[0][0]
____________________________________________________________________________________________________
batch_normalization_22 (BatchNor (None, 28, 28, 128)   512         conv2d_22[0][0]
____________________________________________________________________________________________________
activation_20 (Activation)       (None, 28, 28, 128)   0           batch_normalization_22[0][0]
____________________________________________________________________________________________________
conv2d_23 (Conv2D)               (None, 28, 28, 128)   147584      activation_20[0][0]
____________________________________________________________________________________________________
batch_normalization_23 (BatchNor (None, 28, 28, 128)   512         conv2d_23[0][0]
____________________________________________________________________________________________________
activation_21 (Activation)       (None, 28, 28, 128)   0           batch_normalization_23[0][0]
____________________________________________________________________________________________________
conv2d_24 (Conv2D)               (None, 28, 28, 512)   66048       activation_21[0][0]
____________________________________________________________________________________________________
batch_normalization_24 (BatchNor (None, 28, 28, 512)   2048        conv2d_24[0][0]
____________________________________________________________________________________________________
add_7 (Add)                      (None, 28, 28, 512)   0           batch_normalization_24[0][0]
                                                                   activation_19[0][0]
____________________________________________________________________________________________________
activation_22 (Activation)       (None, 28, 28, 512)   0           add_7[0][0]
____________________________________________________________________________________________________
conv2d_26 (Conv2D)               (None, 14, 14, 256)   131328      activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_26 (BatchNor (None, 14, 14, 256)   1024        conv2d_26[0][0]
____________________________________________________________________________________________________
activation_23 (Activation)       (None, 14, 14, 256)   0           batch_normalization_26[0][0]
____________________________________________________________________________________________________
conv2d_27 (Conv2D)               (None, 14, 14, 256)   590080      activation_23[0][0]
____________________________________________________________________________________________________
batch_normalization_27 (BatchNor (None, 14, 14, 256)   1024        conv2d_27[0][0]
____________________________________________________________________________________________________
activation_24 (Activation)       (None, 14, 14, 256)   0           batch_normalization_27[0][0]
____________________________________________________________________________________________________
conv2d_28 (Conv2D)               (None, 14, 14, 1024)  263168      activation_24[0][0]
____________________________________________________________________________________________________
conv2d_25 (Conv2D)               (None, 14, 14, 1024)  525312      activation_22[0][0]
____________________________________________________________________________________________________
batch_normalization_28 (BatchNor (None, 14, 14, 1024)  4096        conv2d_28[0][0]
____________________________________________________________________________________________________
batch_normalization_25 (BatchNor (None, 14, 14, 1024)  4096        conv2d_25[0][0]
____________________________________________________________________________________________________
add_8 (Add)                      (None, 14, 14, 1024)  0           batch_normalization_28[0][0]
                                                                   batch_normalization_25[0][0]
____________________________________________________________________________________________________
activation_25 (Activation)       (None, 14, 14, 1024)  0           add_8[0][0]
____________________________________________________________________________________________________
conv2d_29 (Conv2D)               (None, 14, 14, 256)   262400      activation_25[0][0]
____________________________________________________________________________________________________
batch_normalization_29 (BatchNor (None, 14, 14, 256)   1024        conv2d_29[0][0]
____________________________________________________________________________________________________
activation_26 (Activation)       (None, 14, 14, 256)   0           batch_normalization_29[0][0]
____________________________________________________________________________________________________
conv2d_30 (Conv2D)               (None, 14, 14, 256)   590080      activation_26[0][0]
____________________________________________________________________________________________________
batch_normalization_30 (BatchNor (None, 14, 14, 256)   1024        conv2d_30[0][0]
____________________________________________________________________________________________________
activation_27 (Activation)       (None, 14, 14, 256)   0           batch_normalization_30[0][0]
____________________________________________________________________________________________________
conv2d_31 (Conv2D)               (None, 14, 14, 1024)  263168      activation_27[0][0]
____________________________________________________________________________________________________
batch_normalization_31 (BatchNor (None, 14, 14, 1024)  4096        conv2d_31[0][0]
____________________________________________________________________________________________________
add_9 (Add)                      (None, 14, 14, 1024)  0           batch_normalization_31[0][0]
                                                                   activation_25[0][0]
____________________________________________________________________________________________________
activation_28 (Activation)       (None, 14, 14, 1024)  0           add_9[0][0]
____________________________________________________________________________________________________
conv2d_32 (Conv2D)               (None, 14, 14, 256)   262400      activation_28[0][0]
____________________________________________________________________________________________________
batch_normalization_32 (BatchNor (None, 14, 14, 256)   1024        conv2d_32[0][0]
____________________________________________________________________________________________________
activation_29 (Activation)       (None, 14, 14, 256)   0           batch_normalization_32[0][0]
____________________________________________________________________________________________________
conv2d_33 (Conv2D)               (None, 14, 14, 256)   590080      activation_29[0][0]
____________________________________________________________________________________________________
batch_normalization_33 (BatchNor (None, 14, 14, 256)   1024        conv2d_33[0][0]
____________________________________________________________________________________________________
activation_30 (Activation)       (None, 14, 14, 256)   0           batch_normalization_33[0][0]
____________________________________________________________________________________________________
conv2d_34 (Conv2D)               (None, 14, 14, 1024)  263168      activation_30[0][0]
____________________________________________________________________________________________________
batch_normalization_34 (BatchNor (None, 14, 14, 1024)  4096        conv2d_34[0][0]
____________________________________________________________________________________________________
add_10 (Add)                     (None, 14, 14, 1024)  0           batch_normalization_34[0][0]
                                                                   activation_28[0][0]
____________________________________________________________________________________________________
activation_31 (Activation)       (None, 14, 14, 1024)  0           add_10[0][0]
____________________________________________________________________________________________________
conv2d_35 (Conv2D)               (None, 14, 14, 256)   262400      activation_31[0][0]
____________________________________________________________________________________________________
batch_normalization_35 (BatchNor (None, 14, 14, 256)   1024        conv2d_35[0][0]
____________________________________________________________________________________________________
activation_32 (Activation)       (None, 14, 14, 256)   0           batch_normalization_35[0][0]
____________________________________________________________________________________________________
conv2d_36 (Conv2D)               (None, 14, 14, 256)   590080      activation_32[0][0]
____________________________________________________________________________________________________
batch_normalization_36 (BatchNor (None, 14, 14, 256)   1024        conv2d_36[0][0]
____________________________________________________________________________________________________
activation_33 (Activation)       (None, 14, 14, 256)   0           batch_normalization_36[0][0]
____________________________________________________________________________________________________
conv2d_37 (Conv2D)               (None, 14, 14, 1024)  263168      activation_33[0][0]
____________________________________________________________________________________________________
batch_normalization_37 (BatchNor (None, 14, 14, 1024)  4096        conv2d_37[0][0]
____________________________________________________________________________________________________
add_11 (Add)                     (None, 14, 14, 1024)  0           batch_normalization_37[0][0]
                                                                   activation_31[0][0]
____________________________________________________________________________________________________
activation_34 (Activation)       (None, 14, 14, 1024)  0           add_11[0][0]
____________________________________________________________________________________________________
conv2d_38 (Conv2D)               (None, 14, 14, 256)   262400      activation_34[0][0]
____________________________________________________________________________________________________
batch_normalization_38 (BatchNor (None, 14, 14, 256)   1024        conv2d_38[0][0]
____________________________________________________________________________________________________
activation_35 (Activation)       (None, 14, 14, 256)   0           batch_normalization_38[0][0]
____________________________________________________________________________________________________
conv2d_39 (Conv2D)               (None, 14, 14, 256)   590080      activation_35[0][0]
____________________________________________________________________________________________________
batch_normalization_39 (BatchNor (None, 14, 14, 256)   1024        conv2d_39[0][0]
____________________________________________________________________________________________________
activation_36 (Activation)       (None, 14, 14, 256)   0           batch_normalization_39[0][0]
____________________________________________________________________________________________________
conv2d_40 (Conv2D)               (None, 14, 14, 1024)  263168      activation_36[0][0]
____________________________________________________________________________________________________
batch_normalization_40 (BatchNor (None, 14, 14, 1024)  4096        conv2d_40[0][0]
____________________________________________________________________________________________________
add_12 (Add)                     (None, 14, 14, 1024)  0           batch_normalization_40[0][0]
                                                                   activation_34[0][0]
____________________________________________________________________________________________________
activation_37 (Activation)       (None, 14, 14, 1024)  0           add_12[0][0]
____________________________________________________________________________________________________
conv2d_41 (Conv2D)               (None, 14, 14, 256)   262400      activation_37[0][0]
____________________________________________________________________________________________________
batch_normalization_41 (BatchNor (None, 14, 14, 256)   1024        conv2d_41[0][0]
____________________________________________________________________________________________________
activation_38 (Activation)       (None, 14, 14, 256)   0           batch_normalization_41[0][0]
____________________________________________________________________________________________________
conv2d_42 (Conv2D)               (None, 14, 14, 256)   590080      activation_38[0][0]
____________________________________________________________________________________________________
batch_normalization_42 (BatchNor (None, 14, 14, 256)   1024        conv2d_42[0][0]
____________________________________________________________________________________________________
activation_39 (Activation)       (None, 14, 14, 256)   0           batch_normalization_42[0][0]
____________________________________________________________________________________________________
conv2d_43 (Conv2D)               (None, 14, 14, 1024)  263168      activation_39[0][0]
____________________________________________________________________________________________________
batch_normalization_43 (BatchNor (None, 14, 14, 1024)  4096        conv2d_43[0][0]
____________________________________________________________________________________________________
add_13 (Add)                     (None, 14, 14, 1024)  0           batch_normalization_43[0][0]
                                                                   activation_37[0][0]
____________________________________________________________________________________________________
activation_40 (Activation)       (None, 14, 14, 1024)  0           add_13[0][0]
____________________________________________________________________________________________________
conv2d_45 (Conv2D)               (None, 7, 7, 512)     524800      activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_45 (BatchNor (None, 7, 7, 512)     2048        conv2d_45[0][0]
____________________________________________________________________________________________________
activation_41 (Activation)       (None, 7, 7, 512)     0           batch_normalization_45[0][0]
____________________________________________________________________________________________________
conv2d_46 (Conv2D)               (None, 7, 7, 512)     2359808     activation_41[0][0]
____________________________________________________________________________________________________
batch_normalization_46 (BatchNor (None, 7, 7, 512)     2048        conv2d_46[0][0]
____________________________________________________________________________________________________
activation_42 (Activation)       (None, 7, 7, 512)     0           batch_normalization_46[0][0]
____________________________________________________________________________________________________
conv2d_47 (Conv2D)               (None, 7, 7, 2048)    1050624     activation_42[0][0]
____________________________________________________________________________________________________
conv2d_44 (Conv2D)               (None, 7, 7, 2048)    2099200     activation_40[0][0]
____________________________________________________________________________________________________
batch_normalization_47 (BatchNor (None, 7, 7, 2048)    8192        conv2d_47[0][0]
____________________________________________________________________________________________________
batch_normalization_44 (BatchNor (None, 7, 7, 2048)    8192        conv2d_44[0][0]
____________________________________________________________________________________________________
add_14 (Add)                     (None, 7, 7, 2048)    0           batch_normalization_47[0][0]
                                                                   batch_normalization_44[0][0]
____________________________________________________________________________________________________
activation_43 (Activation)       (None, 7, 7, 2048)    0           add_14[0][0]
____________________________________________________________________________________________________
conv2d_48 (Conv2D)               (None, 7, 7, 512)     1049088     activation_43[0][0]
____________________________________________________________________________________________________
batch_normalization_48 (BatchNor (None, 7, 7, 512)     2048        conv2d_48[0][0]
____________________________________________________________________________________________________
activation_44 (Activation)       (None, 7, 7, 512)     0           batch_normalization_48[0][0]
____________________________________________________________________________________________________
conv2d_49 (Conv2D)               (None, 7, 7, 512)     2359808     activation_44[0][0]
____________________________________________________________________________________________________
batch_normalization_49 (BatchNor (None, 7, 7, 512)     2048        conv2d_49[0][0]
____________________________________________________________________________________________________
activation_45 (Activation)       (None, 7, 7, 512)     0           batch_normalization_49[0][0]
____________________________________________________________________________________________________
conv2d_50 (Conv2D)               (None, 7, 7, 2048)    1050624     activation_45[0][0]
____________________________________________________________________________________________________
batch_normalization_50 (BatchNor (None, 7, 7, 2048)    8192        conv2d_50[0][0]
____________________________________________________________________________________________________
add_15 (Add)                     (None, 7, 7, 2048)    0           batch_normalization_50[0][0]
                                                                   activation_43[0][0]
____________________________________________________________________________________________________
activation_46 (Activation)       (None, 7, 7, 2048)    0           add_15[0][0]
____________________________________________________________________________________________________
conv2d_51 (Conv2D)               (None, 7, 7, 512)     1049088     activation_46[0][0]
____________________________________________________________________________________________________
batch_normalization_51 (BatchNor (None, 7, 7, 512)     2048        conv2d_51[0][0]
____________________________________________________________________________________________________
activation_47 (Activation)       (None, 7, 7, 512)     0           batch_normalization_51[0][0]
____________________________________________________________________________________________________
conv2d_52 (Conv2D)               (None, 7, 7, 512)     2359808     activation_47[0][0]
____________________________________________________________________________________________________
batch_normalization_52 (BatchNor (None, 7, 7, 512)     2048        conv2d_52[0][0]
____________________________________________________________________________________________________
activation_48 (Activation)       (None, 7, 7, 512)     0           batch_normalization_52[0][0]
____________________________________________________________________________________________________
conv2d_53 (Conv2D)               (None, 7, 7, 2048)    1050624     activation_48[0][0]
____________________________________________________________________________________________________
batch_normalization_53 (BatchNor (None, 7, 7, 2048)    8192        conv2d_53[0][0]
____________________________________________________________________________________________________
add_16 (Add)                     (None, 7, 7, 2048)    0           batch_normalization_53[0][0]
                                                                   activation_46[0][0]
____________________________________________________________________________________________________
activation_49 (Activation)       (None, 7, 7, 2048)    0           add_16[0][0]
____________________________________________________________________________________________________
global_avg_pooling (GlobalAverag (None, 2048)          0           activation_49[0][0]
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 10)            20490       global_avg_pooling[0][0]
____________________________________________________________________________________________________
activation_50 (Activation)       (None, 10)            0           dense_1[0][0]
====================================================================================================
Total params: 23,608,202
Trainable params: 23,555,082
Non-trainable params: 53,120
____________________________________________________________________________________________________
Using Enhanced Data Generation
Found 4000 images belonging to 4 classes.
Found 800 images belonging to 4 classes.
JSON Mapping for the model classes saved to  C:\Users\User\PycharmProjects\ImageAITest\pets\json\model_class.json
Number of experiments (Epochs) :  100
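
Each Param # value in the summary above follows directly from the layer's shape. As a rough check, the sketch below applies the usual parameter-count formulas for Keras Conv2D (with bias), Dense and BatchNormalization layers to a few rows of the ResNet summary. The helper names conv2d_params, dense_params, batchnorm_params and batchnorm_non_trainable are hypothetical and are not part of ImageAI. BatchNormalization keeps four values per channel (gamma, beta, moving mean, moving variance). Only gamma and beta are trained, and the other two values account for the "Non-trainable params" line.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Weights of a square Conv2D kernel plus one bias per filter."""
    weights = kernel_size * kernel_size * in_channels * filters
    return weights + (filters if use_bias else 0)


def dense_params(in_features, units):
    """Fully connected layer: one weight per input/output pair plus one bias per unit."""
    return in_features * units + units


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance for every channel."""
    return 4 * channels


def batchnorm_non_trainable(channels):
    """Only the moving mean and moving variance are excluded from training."""
    return 2 * channels


# Rows taken from the summary above
print(conv2d_params(1, 1024, 512))  # conv2d_45: 1x1 conv, 1024 -> 512 channels: 524800
print(conv2d_params(3, 512, 512))   # conv2d_46: 3x3 conv, 512 -> 512 channels: 2359808
print(batchnorm_params(2048))       # batch_normalization_47: 8192
print(dense_params(2048, 10))       # dense_1: 2048 pooled features -> 10 classes: 20490
```

Summing batchnorm_non_trainable over every BatchNormalization layer of this ResNet gives the 53,120 non-trainable parameters reported above.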

Once training starts, you will see output like the following in the console:

Epoch 1/100
 1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500
 2/25 [=>............................] - ETA: 41s - loss: 2.3027 - acc: 0.1250
 3/25 [==>...........................] - ETA: 37s - loss: 2.2961 - acc: 0.1667
 4/25 [===>..........................] - ETA: 36s - loss: 2.2980 - acc: 0.1250
 5/25 [=====>........................] - ETA: 33s - loss: 2.3178 - acc: 0.1000
 6/25 [======>.......................] - ETA: 31s - loss: 2.3214 - acc: 0.0833
 7/25 [=======>......................] - ETA: 30s - loss: 2.3202 - acc: 0.0714
 8/25 [========>.....................] - ETA: 29s - loss: 2.3207 - acc: 0.0625
 9/25 [=========>....................] - ETA: 27s - loss: 2.3191 - acc: 0.0556
10/25 [===========>..................] - ETA: 25s - loss: 2.3167 - acc: 0.0750
11/25 [============>.................] - ETA: 23s - loss: 2.3162 - acc: 0.0682
12/25 [=============>................] - ETA: 21s - loss: 2.3143 - acc: 0.0833
13/25 [==============>...............] - ETA: 20s - loss: 2.3135 - acc: 0.0769
14/25 [===============>..............] - ETA: 18s - loss: 2.3132 - acc: 0.0714
15/25 [=================>............] - ETA: 16s - loss: 2.3128 - acc: 0.0667
16/25 [==================>...........] - ETA: 15s - loss: 2.3121 - acc: 0.0781
17/25 [===================>..........] - ETA: 13s - loss: 2.3116 - acc: 0.0735
18/25 [====================>.........] - ETA: 12s - loss: 2.3114 - acc: 0.0694
19/25 [=====================>........] - ETA: 10s - loss: 2.3112 - acc: 0.0658
20/25 [=======================>......] - ETA: 8s - loss: 2.3109 - acc: 0.0625
21/25 [========================>.....] - ETA: 7s - loss: 2.3107 - acc: 0.0595
22/25 [=========================>....] - ETA: 5s - loss: 2.3104 - acc: 0.0568
23/25 [==========================>...] - ETA: 3s - loss: 2.3101 - acc: 0.0543
24/25 [===========================>..] - ETA: 1s - loss: 2.3097 - acc: 0.0625Epoch 00000: saving model to C:\Users\Moses\Documents\Moses\W7\AI\Custom Datasets\IDENPROF\idenprof-small-test\idenprof\models\model_ex-000_acc-0.100000.h5

25/25 [==============================] - 51s - loss: 2.3095 - acc: 0.0600 - val_loss: 2.3026 - val_acc: 0.1000

Here is what the details above mean:

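  1. Epoch 1/100 means the network is running the 1st of the 100 target experiments (epochs).
  2. 1/25 [>.............................] - ETA: 52s - loss: 2.3026 - acc: 0.2500 shows progress through the batches of the current experiment: 1 of 25 batches has been trained so far, followed by the estimated time remaining (ETA) and the current loss and accuracy.
  3. Epoch 00000: saving model to C:\Users\Moses\Documents\Moses\W7\AI\Custom Datasets\IDENPROF\idenprof-small-test\idenprof\models\model_ex-000_acc-0.100000.h5 refers to the model file saved after this experiment. ex-000 is the experiment number, while acc-0.100000 and val_acc: 0.1000 give the model's accuracy on the test images after this experiment (the maximum accuracy is 1.0). These values help you identify the best model to use for custom image prediction.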
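
Because every saved file carries its experiment number and accuracy in its name, the most accurate model can be picked out of the models folder programmatically. A minimal sketch, assuming the file names follow the model_ex-XXX_acc-Y.h5 pattern seen in the log above; best_checkpoint is a hypothetical helper, not an ImageAI function.

```python
import re

# Naming scheme from the log above, e.g. model_ex-000_acc-0.100000.h5
CHECKPOINT_PATTERN = re.compile(r"^model_ex-(\d+)_acc-(\d+(?:\.\d+)?)\.h5$")


def best_checkpoint(filenames):
    """Return (filename, experiment, accuracy) for the most accurate checkpoint, or None."""
    best = None
    for name in filenames:
        match = CHECKPOINT_PATTERN.match(name)
        if match is None:
            continue  # skip files that do not follow the naming scheme (e.g. the JSON file)
        experiment = int(match.group(1))
        accuracy = float(match.group(2))
        # Higher accuracy wins; on a tie, prefer the later experiment
        if best is None or (accuracy, experiment) > (best[2], best[1]):
            best = (name, experiment, accuracy)
    return best
```

For example, passing os.listdir() of the models folder inside your dataset folder (idenprof\models in the log above) returns the file with the highest accuracy, which is the one you would normally load for prediction.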

Once your custom model has been trained, you can use the CustomImagePrediction class to run image prediction with it. Follow the link below for a complete example.

https://github.com/OlafenwaMoses/ImageAI/blob/master/imageai/Prediction/CUSTOMPREDICTION.md
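
The JSON file produced during training (model_class.json in the log above) is what allows a prediction to be reported as a class name instead of an output index. The sketch below does not use the CustomImagePrediction API; it only illustrates the idea, assuming the file maps each output index (stored as a string) to a class name and that the model returns one probability per class in index order. The function top_predictions and the example mapping are hypothetical.

```python
import json


def top_predictions(probabilities, class_map, count=3):
    """Pair each class name with its probability and return the `count` most likely."""
    if len(probabilities) != len(class_map):
        raise ValueError("expected one probability per class in the mapping")
    ranked = sorted(
        ((class_map[str(index)], prob) for index, prob in enumerate(probabilities)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:count]


# Assumed contents of model_class.json: output index (as a string) -> class name
class_map = json.loads('{"0": "cat", "1": "dog", "2": "snake", "3": "squirrel"}')
print(top_predictions([0.05, 0.80, 0.10, 0.05], class_map, count=2))
```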

Training on the IdenProf Dataset

A sample from the IdenProf dataset, used to train a model to predict professionals.

https://github.com/kangvcar/ImageAI/raw/master/images/idenprof.jpg

The sample code below downloads IdenProf, a dataset of images of uniformed professionals from 10 professions, and trains a model on it:

import os
import shutil
from zipfile import ZipFile

import requests
from imageai.Prediction.Custom import ModelTraining

execution_path = os.getcwd()

# Local paths for the three IdenProf archives
TRAIN_ZIP_ONE = os.path.join(execution_path, "idenprof-train1.zip")
TRAIN_ZIP_TWO = os.path.join(execution_path, "idenprof-train2.zip")
TEST_ZIP = os.path.join(execution_path, "idenprof-test.zip")

# Dataset layout expected by ModelTraining: idenprof/train/<class> and idenprof/test/<class>
DATASET_DIR = os.path.join(execution_path, "idenprof")
DATASET_TRAIN_DIR = os.path.join(DATASET_DIR, "train")
DATASET_TEST_DIR = os.path.join(DATASET_DIR, "test")

RELEASE_URL = "https://github.com/OlafenwaMoses/IdenProf/releases/download/v1.0/"


def download(file_name, destination):
    """Stream a release archive to disk unless it has already been downloaded."""
    if os.path.exists(destination):
        return
    print("Downloading " + file_name)
    response = requests.get(RELEASE_URL + file_name, stream=True)
    response.raise_for_status()  # stop instead of saving an error page as a .zip
    with open(destination, "wb") as file:
        shutil.copyfileobj(response.raw, file)


def extract(archive_path, target_dir):
    """Unpack a zip archive into target_dir."""
    print("Extracting " + os.path.basename(archive_path))
    with ZipFile(archive_path) as archive:
        archive.extractall(target_dir)


# Create the dataset folders if they do not exist yet
os.makedirs(DATASET_TRAIN_DIR, exist_ok=True)
os.makedirs(DATASET_TEST_DIR, exist_ok=True)

# IdenProf has 10 classes; fewer than 10 class folders means the data is not extracted yet
if len(os.listdir(DATASET_TRAIN_DIR)) < 10:
    download("idenprof-train1.zip", TRAIN_ZIP_ONE)
    download("idenprof-train2.zip", TRAIN_ZIP_TWO)
    extract(TRAIN_ZIP_ONE, DATASET_TRAIN_DIR)
    extract(TRAIN_ZIP_TWO, DATASET_TRAIN_DIR)

if len(os.listdir(DATASET_TEST_DIR)) < 10:
    download("idenprof-test.zip", TEST_ZIP)
    extract(TEST_ZIP, DATASET_TEST_DIR)

# Train a ResNet model on the 10 IdenProf classes for 100 experiments (epochs)
model_trainer = ModelTraining()
model_trainer.setModelTypeAsResNet()
model_trainer.setDataDirectory(DATASET_DIR)
model_trainer.trainModel(num_objects=10, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)
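
Before starting a long training run, it can be worth confirming that the extracted folders match what ModelTraining expects: the same class folders under train and test, and a reasonable number of images in each. This is a minimal sketch. summarize_split and check_dataset are hypothetical helpers, and the extension list is an assumption. Files in nested subfolders or with other extensions are not counted.

```python
import os

# Assumed set of image file extensions to count
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def summarize_split(split_dir):
    """Return {class_name: image_count} for one split folder (train or test)."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders
        counts[class_name] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


def check_dataset(dataset_dir):
    """Summarize train/ and test/ and fail if their class folders differ."""
    train = summarize_split(os.path.join(dataset_dir, "train"))
    test = summarize_split(os.path.join(dataset_dir, "test"))
    if set(train) != set(test):
        raise ValueError("train and test contain different class folders: "
                         + str(sorted(set(train) ^ set(test))))
    return train, test
```

For the code above, check_dataset(DATASET_DIR) should list 10 classes, len(train) is the value to pass as num_objects, and the totals can be compared with the "Found N images belonging to K classes" lines printed when training starts.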

Documentation

imageai.Prediction.Custom.ModelTraining class


You can train a custom model in any Python program by creating an instance of the ModelTraining class and calling the functions below:

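  • setModelTypeAsSqueezeNet: call this function once if you choose to train a SqueezeNet model.
  • setModelTypeAsResNet: call this function once if you choose to train a ResNet model.
  • setModelTypeAsInceptionV3: call this function once if you choose to train an InceptionV3 model.
  • setModelTypeAsDenseNet: call this function once if you choose to train a DenseNet model.
  • setDataDirectory: sets the path to the dataset folder used for training (the folder that contains the train and test subfolders).
  • trainModel: starts the model training. It accepts the following parameters:
    • num_objects: the number of object types (classes) in the image dataset, i.e. the number of class folders.
    • num_experiments: the number of times the network is trained on the full set of images, also known as epochs.
    • enhance_data (optional): whether to generate modified copies of the training images for better performance.
    • batch_size (optional, default 32): the number of images in each batch. Because of memory limits, the network is trained on one batch of images at a time until it has gone through the whole training set. You can adjust this value to suit your training setup; it is usually set to 16, 32, 64 or 128.
    • initial_learning_rate (optional): controls how much the model's weights are adjusted at each training step. If you are not familiar with this concept, keep the default value.
    • show_network_summary (optional, default False): whether to print a summary of the network's structure in the console.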
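
To make num_experiments, batch_size and initial_learning_rate more concrete, here is a toy mini-batch gradient descent loop in plain Python that fits y = w * x. It is not ImageAI's training code (ImageAI's actual training runs through Keras); it only shows the role each parameter plays: one experiment is one pass over all samples, each pass is split into batches of at most batch_size samples, and the learning rate scales every weight update. The function train_toy_model is hypothetical, and the learning rate is kept constant here, whereas the name initial_learning_rate suggests the real trainer may change it as training progresses.

```python
import random


def train_toy_model(xs, ys, num_experiments, batch_size, initial_learning_rate, seed=0):
    """Fit y = w * x by mini-batch gradient descent on the mean squared error."""
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    w = 0.0
    for _ in range(num_experiments):          # one experiment (epoch) = one pass over the data
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]  # at most batch_size samples
            # Gradient of mean((w * x - y) ** 2) over the batch with respect to w
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w -= initial_learning_rate * grad   # the learning rate sets the step size
    return w


xs = [i / 10 for i in range(1, 11)]
ys = [2.0 * x for x in xs]
print(round(train_toy_model(xs, ys, num_experiments=100, batch_size=4, initial_learning_rate=0.5), 4))
```

With 10 samples and batch_size=4, each experiment performs 3 updates (batches of 4, 4 and 2 samples); a much larger learning rate would make the updates overshoot instead of converging.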


Submitting Custom Models

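Everyone using this library is welcome to submit their trained models and the corresponding JSON files for inclusion in this repository. Please submit your trained model and its JSON file through the contact details below.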

Contact the Developer

Moses Olafenwa

Email: guymodscientist@gmail.com
Website: https://moses.specpal.science
Twitter:@OlafenwaMoses
Medium: @guymodscientist
Facebook: moses.olafenwa