Real-time Image Augmentation
Training an image classifier from scratch takes a truckload of training data, and data preparation gets really painful if you generate augmented copies statically, ahead of time. Transfer learning needs far less training data by comparison, but you still need a good amount of variation in it for the model to generalize well.
Fortunately, TensorFlow’s ImageDataGenerator makes real-time image augmentation really easy, and I particularly enjoy using it because of how much simpler it makes managing my data.
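import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# train_data_dir, IMG_HEIGHT, IMG_WIDTH and BATCH_SIZE are assumed to be defined elsewhere.

# Random transformations are applied on the fly, each time a batch is requested.
train_datagen = ImageDataGenerator(
    rescale=1./255,               # scale pixel values from [0, 255] to [0, 1]
    rotation_range=45,            # rotate by up to 45 degrees
    width_shift_range=0.2,        # shift horizontally by up to 20% of the width
    height_shift_range=0.2,       # shift vertically by up to 20% of the height
    brightness_range=[0.2, 1.0],  # brightness factor (1.0 leaves the image unchanged)
    zoom_range=[0.5, 1.0],        # random zoom factor drawn from this range
    horizontal_flip=True,
    data_format="channels_last"
)

def make_train_gen():
    # Expects one subdirectory per class inside train_data_dir.
    return train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(IMG_HEIGHT, IMG_WIDTH),  # Keras expects (height, width)
        color_mode="rgb",
        class_mode="categorical",
        batch_size=BATCH_SIZE
    )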
And that’s all it takes! The API is pretty well documented, so go knock yourself out. The next step is just creating a dataset from the generator, which you can then pass to model.fit straightaway.
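Here's a rough sketch of what that next step could look like, assuming the make_train_gen function above and a reasonably recent TensorFlow 2.x release (one where tf.data.Dataset.from_generator accepts output_signature). The helper names steps_per_epoch and make_train_dataset are my own, not part of TensorFlow. One thing to watch out for: Keras iterators loop over the data forever, so a dataset built this way never ends, and model.fit needs to be told how many batches make up an epoch.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to see every sample once."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    return math.ceil(num_samples / batch_size)


def make_train_dataset(make_gen, img_height, img_width, num_classes):
    """Wrap a Keras iterator factory (e.g. make_train_gen) in a tf.data.Dataset."""
    # Imported here so the helper above works without TensorFlow installed.
    import tensorflow as tf

    return tf.data.Dataset.from_generator(
        make_gen,  # called with no arguments; must return an iterator of (images, labels)
        output_signature=(
            # Batches of RGB images; None leaves the batch dimension flexible.
            tf.TensorSpec(shape=(None, img_height, img_width, 3), dtype=tf.float32),
            # One-hot labels, as produced by class_mode="categorical".
            tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32),
        ),
    ).prefetch(tf.data.AUTOTUNE)


# Example: 1000 images in batches of 32 need 32 steps (the last batch is smaller).
print(steps_per_epoch(1000, 32))  # 32
```

With those in place, something like train_ds = make_train_dataset(make_train_gen, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES) followed by model.fit(train_ds, epochs=EPOCHS, steps_per_epoch=steps_per_epoch(num_train_images, BATCH_SIZE)) should do the trick. NUM_CLASSES, EPOCHS and num_train_images are placeholders for your own values.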
Check out my code if you want to train an image classifier.