Real-time Image Augmentation

// development

Training an image classifier from scratch requires a truckload of training data, and data preparation gets painful fast if you do it statically, generating and storing every augmented copy ahead of time. Transfer learning needs far less training data, but you still need enough variation in it for the model to generalize well.

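Fortunately, TensorFlow’s ImageDataGenerator makes real-time image augmentation easy: it applies random transformations to each batch as it is loaded, so the model sees a slightly different version of every image each epoch. I particularly enjoy using it because it makes my data much simpler to manage.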

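import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Placeholder settings; adjust these for your own dataset.
train_data_dir = "data/train"  # expects one subdirectory per class
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 32

train_datagen = ImageDataGenerator(
  rescale=1./255,               # scale pixel values to [0, 1]
  rotation_range=45,            # random rotation of up to 45 degrees
  width_shift_range=0.2,        # horizontal shift, as a fraction of width
  height_shift_range=0.2,       # vertical shift, as a fraction of height
  brightness_range=[0.2, 1.0],  # values below 1.0 darken the image
  zoom_range=[0.5, 1.0],        # values below 1.0 zoom in
  horizontal_flip=True,
  data_format="channels_last"
)

def make_train_gen():
  # Returns an iterator that yields (images, one-hot labels) batches.
  return train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),  # Keras expects (height, width)
    color_mode="rgb",
    class_mode="categorical",
    batch_size=BATCH_SIZE
  )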

And that’s all it takes! The API is pretty well-documented, so go knock yourself out. The next step is to wrap the generator in a dataset (for example with tf.data.Dataset.from_generator), which you can then pass to model.fit straightaway.
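
Here is a minimal sketch of that last step, assuming the dataset is built with tf.data.Dataset.from_generator, which is one reasonable route but not necessarily the only one. To keep it runnable without an image folder, make_fake_gen is a hypothetical stand-in for make_train_gen that yields random batches of the same shape; the tiny sizes, NUM_CLASSES, and the one-layer model are placeholders too, not part of the original setup.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes, kept tiny so the sketch runs quickly.
IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES, BATCH_SIZE = 32, 32, 3, 4

def make_fake_gen():
    """Stand-in for make_train_gen(): yields (images, one-hot labels) forever."""
    rng = np.random.default_rng(0)
    while True:
        images = rng.random((BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.float32)
        class_ids = rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)
        labels = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]
        yield images, labels

# from_generator takes a callable, so pass the function itself, not its result.
# The batch dimension is None because the final batch of a real directory
# iterator can be smaller than BATCH_SIZE.
train_ds = tf.data.Dataset.from_generator(
    make_fake_gen,
    output_signature=(
        tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, NUM_CLASSES), dtype=tf.float32),
    ),
).prefetch(tf.data.AUTOTUNE)

# Placeholder model, only here to show the dataset plugging into model.fit.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# The generator never ends, so steps_per_epoch tells Keras where an epoch stops.
history = model.fit(train_ds, steps_per_epoch=2, epochs=1, verbose=0)
```

With real data you would pass make_train_gen in place of make_fake_gen and set steps_per_epoch to roughly the number of training images divided by BATCH_SIZE, since the Keras directory iterator also loops indefinitely.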

Check out my code if you want to train an image classifier.