My dataset uses tf.train.SequenceExample. Each example contains a sequence of N elements, and N varies from one sequence to another. From each sequence I want to select M uniformly spaced elements, where M is fixed for all sequences. For example, if a sequence has N=10 elements and M=2, I want the elements at index=0 and index=5. M is always smaller than every N in the dataset.
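To make the intended selection concrete, here is a minimal plain-Python sketch of the index rule. The formula floor(i * N / M) is an assumption: it is inferred from the single N=10, M=2 example above, and the helper name `uniform_indices` is hypothetical.

```python
def uniform_indices(n, m):
    """Return m evenly spaced indices into a sequence of length n, starting at 0."""
    if not 0 < m <= n:
        raise ValueError("need 0 < m <= n")
    # Assumed rule: index_i = floor(i * n / m) for i = 0 .. m-1.
    return [(i * n) // m for i in range(m)]


print(uniform_indices(10, 2))  # [0, 5]
print(uniform_indices(7, 3))   # [0, 2, 4]
```

Other spacings (for example, one that always includes the last element) would give different indices; this one is chosen only because it reproduces the example above.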
The issue is that when the dataset iterator calls the parser function through the ‘map’ method, the function is executed in ‘graph’ mode and the axis dimension corresponding to N is ‘None’. So I can’t iterate over that axis to find the value of N.
I worked around this by using tf.py_function, but it is 10X slower. I tried tf.data.AUTOTUNE for num_parallel_calls and for prefetch, and also set deterministic=False, but performance is still 10X slower.
What other solutions are possible?
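One possible direction, as a minimal sketch rather than a definitive answer: the static shape is ‘None’ in graph mode, but tf.shape returns the runtime shape as a tensor, so N can be read without leaving the graph and without tf.py_function. The sketch below assumes the floor(i * N / M) index rule and uses a toy dataset of variable-length vectors in place of the real parser; in a real pipeline `seq` would be a tensor produced by something like tf.io.parse_single_sequence_example. The names `select_uniform`, `lengths` and `ds` are hypothetical.

```python
import tensorflow as tf

M = 2  # fixed number of elements to keep from every sequence


def select_uniform(seq, m=M):
    # tf.shape gives the runtime shape as a tensor, so this works in graph
    # mode even though the static shape seq.shape[0] is None.
    n = tf.shape(seq)[0]
    # Evenly spaced indices floor(i * N / m); N=10, m=2 gives [0, 5].
    idx = (tf.range(m) * n) // m
    return tf.gather(seq, idx, axis=0)


# Toy stand-in for parsed sequences: vectors of length 10 and 7.
lengths = tf.constant([10, 7])
ds = tf.data.Dataset.from_tensor_slices(lengths).map(lambda n: tf.range(n))
ds = ds.map(select_uniform, num_parallel_calls=tf.data.AUTOTUNE).batch(2)

batch = next(iter(ds)).numpy().tolist()
print(batch)  # [[0, 5], [0, 3]]
```

Because every mapped element has length M along the first axis, the results can be batched directly. Whether this recovers the full 10X depends on the rest of the pipeline, which is not shown here.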