Hi everybody! This is my first post here, and I’m also a beginner with TF.
I’m trying to implement deep Q-learning for the game Connect 4. From reading some examples online, I understood that the tf.function decorator can speed up training a lot, but that it has no effect other than on performance.
However, I have noticed different behavior in my function:
```python
@tf.function  # do I need it?
def _train_step(self, boards_batch, scores_batch):
    with tf.GradientTape() as tape:
        batch_predictions = self.model(boards_batch, training=True)
        loss_on_batch = self.loss_object(scores_batch, batch_predictions)
    gradients = tape.gradient(loss_on_batch, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(
        gradients, self.model.trainable_variables
    ))
    self.loss(loss_on_batch)
```
In particular, if I train without tf.function, the agent doesn’t learn anything and performs like the random agent. If I use tf.function instead, the agent easily beats the random agent after only 1000 episodes.
Do you have any idea why this is happening? Have I misunderstood something about tf.function?
If I remove tf.function from the TF example notebooks, nothing changes apart from performance, as I would expect from my understanding of tf.function.
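One way to check whether the decorator alone changes the math of a single step is to run that step both ways from identical weights and compare the results. The sketch below is hypothetical: the `DQNAgent` class, the 6x7 board shape, the small dense network, the `MeanSquaredError` loss and the plain `SGD` optimizer are all assumptions, not the setup from this post; only the body of `_train_step` is copied from above. It uses `tf.config.run_functions_eagerly` (available since TF 2.3) to switch tf.function off globally without removing the decorator.

```python
import numpy as np
import tensorflow as tf

BOARD_SHAPE = (6, 7)  # assumption: standard 6x7 Connect 4 board, one value per cell


class DQNAgent:
    """Hypothetical container for the attributes _train_step relies on."""

    def __init__(self):
        # Assumed Q-network: flatten the board, one output per column.
        self.model = tf.keras.Sequential([
            tf.keras.Input(shape=BOARD_SHAPE),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(BOARD_SHAPE[1]),
        ])
        self.loss_object = tf.keras.losses.MeanSquaredError()  # assumed loss
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)  # assumed optimizer
        self.loss = tf.keras.metrics.Mean(name="train_loss")  # running mean of the loss

    @tf.function
    def _train_step(self, boards_batch, scores_batch):
        with tf.GradientTape() as tape:
            batch_predictions = self.model(boards_batch, training=True)
            loss_on_batch = self.loss_object(scores_batch, batch_predictions)
        gradients = tape.gradient(loss_on_batch, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        self.loss(loss_on_batch)


def weights_after_one_step(run_eagerly, start_weights, boards, scores):
    """Run one training step from start_weights, eagerly or as a graph."""
    tf.config.run_functions_eagerly(run_eagerly)
    agent = DQNAgent()
    agent.model.set_weights(start_weights)  # identical starting point for both runs
    agent._train_step(boards, scores)
    return agent.model.get_weights()


# Random stand-in data: boards with cells in {-1, 0, 1}, random target scores.
rng = np.random.default_rng(0)
boards = tf.constant(rng.integers(-1, 2, size=(16, *BOARD_SHAPE)).astype("float32"))
scores = tf.constant(rng.normal(size=(16, BOARD_SHAPE[1])).astype("float32"))

init_weights = DQNAgent().model.get_weights()
eager_weights = weights_after_one_step(True, init_weights, boards, scores)
graph_weights = weights_after_one_step(False, init_weights, boards, scores)
tf.config.run_functions_eagerly(False)  # restore the default

max_diff = max(float(np.max(np.abs(e - g))) for e, g in zip(eager_weights, graph_weights))
print(f"max |eager - graph| weight difference: {max_diff:.2e}")
```

If the two weight sets agree to within float tolerance, the single training step is not where eager and graph mode diverge, and the difference would have to come from the code around it.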