
Weird phenomenon with the Dataset API

I am developing a training pipeline whose tf.data nodes look as follows:

    index = tf.data.Dataset.from_tensor_slices(self.indices)
    if self.shuffle:
        index = index.shuffle(buffer_size=len(self.indices))
    # Two branches hang off `index`: one for images, one for coordinates.
    images = index.map(self.make_image)
    coordinates = index.map(self.get_coordinates)
    ground_truth = coordinates.map(self.make_ground_truth)
    images = images.padded_batch(...)
    ground_truth = ground_truth.batch(...)
    return tf.data.Dataset.zip((images, ground_truth))

If I run the above code with shuffle == False, everything works fine. If shuffle is set to True, the images and ground truths seem to be shuffled differently from each other, so they no longer match up.
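A minimal, self-contained reproduction of the symptom, assuming toy data: the two lambdas are hypothetical stand-ins for make_image and make_ground_truth, chosen so that a correct pair always satisfies `ground_truth == image + 0.5`. Both branches hang off one shuffled `index` node and are zipped back together, as in the snippet above. With shuffling enabled this typically reports mismatched pairs; removing the `.shuffle(...)` call makes the count drop to zero.

```python
import numpy as np
import tensorflow as tf

indices = np.arange(10)
index = tf.data.Dataset.from_tensor_slices(indices).shuffle(buffer_size=len(indices))

# Two branches hanging off the same shuffled node (hypothetical stand-ins).
images = index.map(lambda i: tf.cast(i, tf.float32))               # "make_image"
ground_truth = index.map(lambda i: tf.cast(i, tf.float32) + 0.5)   # "make_ground_truth"

# Zip the branches back together and count pairs that no longer belong together.
pairs = [(float(a), float(b)) for a, b in tf.data.Dataset.zip((images, ground_truth))]
mismatched = sum(1 for a, b in pairs if a + 0.5 != b)
print(f"{mismatched} of {len(pairs)} pairs are mismatched")
```

Each branch on its own is still a complete permutation of the indices; only the pairing between branches is lost.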

Is this intended behaviour? How can it be solved easily?

Edit: I am using TensorFlow 2.0, but 2.1 shows the same behaviour.

Edit 2: Further investigation revealed more weirdness. It seems the problem is not specific to the Dataset.shuffle() method; its root cause is that the chain of transformations branches. Correct me if I'm wrong, but the branching somehow causes the dataset to re-sample the same index once for every branch that originates from a given node. Without shuffling, the re-sampling works as expected, but with shuffling, different indices are fed to different branches.
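As far as I understand tf.data, this matches how the API works: `Dataset.zip` creates a separate iterator for each of its inputs, and each of those iterators runs the upstream pipeline, including `shuffle`, independently, so each branch can draw its own order. A common way around it is to avoid the branch altogether and do all per-example work in a single `map` call that returns a tuple. The sketch below assumes hypothetical toy versions of make_image, get_coordinates and make_ground_truth (images of varying height filled with their own index, so mismatches are easy to detect); they are not the original methods.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for self.make_image / self.get_coordinates /
# self.make_ground_truth. Each image is filled with its own index.
def make_image(i):
    height = tf.cast(i % 3 + 1, tf.int32)  # variable height, so batching needs padding
    return tf.fill([height, 2], tf.cast(i, tf.float32))

def get_coordinates(i):
    return tf.cast(i, tf.float32) * tf.ones([2])

def make_ground_truth(coordinates):
    return coordinates + 0.5

def make_example(i):
    # Image and ground truth are derived from the same index tensor inside one
    # map call, so they cannot drift apart.
    return make_image(i), make_ground_truth(get_coordinates(i))

def build_dataset(indices, shuffle=True, batch_size=4):
    index = tf.data.Dataset.from_tensor_slices(indices)
    if shuffle:
        index = index.shuffle(buffer_size=len(indices))
    pairs = index.map(make_example)
    # padded_batch pads each tuple component separately; padding the
    # fixed-size ground truth is a no-op. padded_shapes is passed explicitly
    # because TF 2.0/2.1 require it.
    return pairs.padded_batch(batch_size, padded_shapes=([None, 2], [2]))

dataset = build_dataset(np.arange(10))
```

Since both components now come out of one iterator, there is only one shuffle order and the pairing is preserved in every epoch.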

I also switched to Dataset.from_generator(), shuffling the indices inside the generator (in NumPy), and it still produces the same bug.
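This is consistent with `from_generator` calling the Python generator anew for every iterator created over the dataset: each branch gets a fresh generator and therefore its own NumPy shuffle. A minimal sketch of a single-stream alternative, assuming the image and ground truth can be computed in NumPy (the `make_pairs` generator and its toy arrays are hypothetical): the generator yields both tensors together, so nothing downstream needs to branch. It uses `output_types`/`output_shapes`, which TF 2.0/2.1 support; newer versions also accept `output_signature`.

```python
import numpy as np
import tensorflow as tf

def make_pairs(num_examples, seed=None):
    # Hypothetical generator: shuffle once in NumPy, then emit the image and
    # its ground truth together as one element.
    rng = np.random.default_rng(seed)
    for i in rng.permutation(num_examples):
        height = i % 3 + 1
        image = np.full((height, 2), i, dtype=np.float32)          # toy image
        ground_truth = np.full((2,), i + 0.5, dtype=np.float32)    # toy target
        yield image, ground_truth

# The lambda creates a new generator (and a new permutation) for every iterator,
# i.e. once per epoch here, since there is only one branch.
dataset = tf.data.Dataset.from_generator(
    lambda: make_pairs(10),
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None, 2]), tf.TensorShape([2])),
).padded_batch(4, padded_shapes=([None, 2], [2]))
```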

Do you guys think this is a bug in TF? Should I file an issue about this?

Or is my approach completely wrong? How could this situation be handled differently?
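Another way to handle it, assuming the per-branch map functions should stay separate: since the unshuffled branches already stay in lock-step, zip them first and shuffle the zipped dataset afterwards, so each (image, ground truth) tuple moves through the shuffle buffer as a unit. The map functions below are again hypothetical toy stand-ins.

```python
import numpy as np
import tensorflow as tf

def make_image(i):  # hypothetical stand-in: image filled with its own index
    return tf.fill([tf.cast(i % 3 + 1, tf.int32), 2], tf.cast(i, tf.float32))

def make_ground_truth(i):  # hypothetical stand-in: index + 0.5
    return tf.cast(i, tf.float32) * tf.ones([2]) + 0.5

indices = np.arange(10)
index = tf.data.Dataset.from_tensor_slices(indices)  # deliberately NOT shuffled

# Unshuffled branches iterate in the same deterministic order.
images = index.map(make_image)
ground_truth = index.map(make_ground_truth)

# Shuffle after zipping, so pairs stay together, then batch.
dataset = (
    tf.data.Dataset.zip((images, ground_truth))
    .shuffle(buffer_size=len(indices))
    .padded_batch(4, padded_shapes=([None, 2], [2]))
)
```

The trade-off is that the shuffle buffer now holds decoded images instead of integer indices, so with large images a smaller `buffer_size` may be needed, at the cost of a less thorough shuffle.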

submitted by /u/CarpenterAcademic
