
Can @tf.function change the learning result?

Hi everybody! This is my first post here, and I'm also a beginner with TF.

I'm trying to implement deep Q-learning on the Connect 4 game. From reading some examples on the internet, I understood that the tf.function decorator can speed up training a lot, but that it has no effect other than on performance.
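For context, tf.function is not purely a performance switch in every case. For code built only from TensorFlow ops it should compute the same values as eager execution, but Python side effects (appending to a list, mutating attributes, printing) run only while the function is being traced, not on every call. A minimal sketch; `scale` and `python_calls` are made-up names for illustration:

```python
import tensorflow as tf

python_calls = []  # plain Python list, only touched while tracing


@tf.function
def scale(x):
    # Python side effect: runs when TF traces the function, not on every call
    python_calls.append(1)
    # TensorFlow op: becomes part of the graph and runs on every call
    return x * 2.0


for _ in range(3):
    # Same dtype and shape each time, so the function is traced only once
    result = scale(tf.constant(1.0))

print(float(result.numpy()))  # 2.0
print(len(python_calls))      # 1: the append ran only during tracing
```

So any Python-side state touched inside a decorated training step (a counter, a buffer, an attribute) can behave differently from the eager version, even though the tensor math is the same.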

However, I have noticed different behavior in my function:

```python
import tensorflow as tf

# Method of the agent class; self.model, self.loss_object,
# self.optimizer and self.loss are defined elsewhere in the class.
@tf.function  # do I need it?
def _train_step(self, boards_batch, scores_batch):
    with tf.GradientTape() as tape:
        batch_predictions = self.model(boards_batch, training=True)
        loss_on_batch = self.loss_object(scores_batch, batch_predictions)
    gradients = tape.gradient(loss_on_batch, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(
        gradients, self.model.trainable_variables
    ))
    self.loss(loss_on_batch)
```
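One way to check whether a single training step actually differs between the two modes is to run it once eagerly and once through tf.function, starting from identical weights and data, and compare the results. The sketch below assumes a stand-in network (a flattened 6x7 board, two Dense layers), MSE loss and plain SGD; none of these come from the post, and `make_model`/`train_step` are hypothetical helpers.

```python
import numpy as np
import tensorflow as tf


def make_model():
    # Hypothetical tiny network standing in for the Connect 4 Q-network
    return tf.keras.Sequential([
        tf.keras.Input(shape=(42,)),  # 6x7 board, flattened (assumption)
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


loss_object = tf.keras.losses.MeanSquaredError()


def train_step(model, optimizer, boards_batch, scores_batch):
    # Same structure as _train_step above, without the class
    with tf.GradientTape() as tape:
        predictions = model(boards_batch, training=True)
        loss = loss_object(scores_batch, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


rng = np.random.default_rng(0)
boards = tf.constant(rng.normal(size=(8, 42)).astype("float32"))
scores = tf.constant(rng.normal(size=(8, 1)).astype("float32"))

eager_model = make_model()
graph_model = make_model()
graph_model.set_weights(eager_model.get_weights())  # identical starting point

eager_opt = tf.keras.optimizers.SGD(learning_rate=0.1)
graph_opt = tf.keras.optimizers.SGD(learning_rate=0.1)
graph_step = tf.function(train_step)  # same Python function, compiled

eager_loss = train_step(eager_model, eager_opt, boards, scores)
graph_loss = graph_step(graph_model, graph_opt, boards, scores)

# Largest absolute difference between the updated weights of the two models
max_diff = max(
    float(np.max(np.abs(a - b)))
    for a, b in zip(eager_model.get_weights(), graph_model.get_weights())
)
print(float(eager_loss), float(graph_loss), max_diff)
```

If the losses and updated weights agree up to floating-point noise, the gap in the agent's learning is probably coming from somewhere else in the training loop rather than from the gradient step itself.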

In particular, if I train without tf.function, the agent learns nothing and performs like the random agent. With tf.function, however, the agent easily beats the random agent after only 1000 episodes.

Do you have any idea why this is happening? Have I misunderstood something about tf.function?

If I remove tf.function from the TF example notebooks, nothing changes apart from performance, which is what I would expect from my understanding of tf.function.
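A way to run the same comparison on your own code without deleting the decorator is tf.config.run_functions_eagerly, which makes every tf.function execute its Python body eagerly. A small sketch; the `step` function and `modes` list are illustrative only:

```python
import tensorflow as tf

modes = []  # records, per Python execution of the body, whether it ran eagerly


@tf.function
def step(x):
    # True when the body runs eagerly, False while it is being traced to a graph
    modes.append(tf.executing_eagerly())
    return x + 1.0


tf.config.run_functions_eagerly(True)   # decorator kept, body runs eagerly
step(tf.constant(1.0))

tf.config.run_functions_eagerly(False)  # back to normal graph execution
step(tf.constant(1.0))                  # first call in graph mode: traced

print(modes)  # [True, False]
```

Training the agent once with the flag set to True and once with False, with everything else (including random seeds) held fixed, would help show whether the decorator alone accounts for the difference.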

submitted by /u/Lake2034
