Categories
Misc

Plotting example images from TensorFlow dataset not working

I’m trying to plot a few example images using Matplotlib from a cats and dogs Kaggle dataset that I loaded into the script using TensorFlow’s import_image_dataset_from_directory, but the images aren’t displaying correctly. The Matplotlib plots are either empty or contain some speckled blue or yellow dots… Does anyone know how to fix this? (code below)

def normalize(x,y):
x = tf.cast(x,tf.float32) / 255.0
return x, y
def convert_to_categorical(input):
if input == 1:
return “Dog”
else:
return “Cat”
def to_list(ds):
ds_list = []
for sample in ds:
image, label = sample
ds_list.append((image, label))
return ds_list

# load dataset
directory = ‘train’
ds_train = tf.keras.utils.image_dataset_from_directory(
directory,
labels=’inferred’,
label_mode=’binary’,
batch_size=1,
shuffle=False,
validation_split=0.3,
subset=’training’,
image_size=(300,300)
)
ds_test = tf.keras.utils.image_dataset_from_directory(
directory,
labels=’inferred’,
label_mode=’binary’,
batch_size=1,
shuffle=False,
validation_split=0.3,
subset=’validation’,
image_size=(300,300)
)
# normalize data
ds_train.map(normalize)
ds_test.map(normalize)
# plot 10 random images from training set
num = len(ds_train)
ds_train_list = to_list(ds_train)
for i in range(1,11):
random_index = np.random.randint(num)
img, label = ds_train_list[random_index]
label = convert_to_categorical(np.array(label))
img = np.reshape(img,(300,300,3))
plt.subplot(2,5,i)
plt.imshow(img)
plt.title(label)
plt.savefig(‘figures/example_images.png’)

submitted by /u/berimbolo21
[visit reddit] [comments]

Leave a Reply

Your email address will not be published. Required fields are marked *