Custom accuracy metric

Hi, I have been trying to implement a custom masked accuracy metric that ignores the pad tokens (similar to the masked loss as shown here). I have been trying to create a new subclass of tf.keras.metrics.Metric, but I am very confused. Any leads would be really appreciated.
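A minimal sketch of such a subclass follows. It assumes the pad token id is 0 (change pad_id to match your tokenizer), that labels are integer token ids of shape (batch, seq_len), and that predictions are per-token scores of shape (batch, seq_len, vocab_size). The class name MaskedAccuracy and the pad_id parameter are illustrative, not part of Keras. The metric keeps two running counters created with add_weight, one for correctly predicted non-pad tokens and one for all non-pad tokens, and reports their ratio. sample_weight is accepted but ignored here for simplicity.

```python
import tensorflow as tf


class MaskedAccuracy(tf.keras.metrics.Metric):
    """Token accuracy that skips padding positions (sketch; pad_id=0 is an assumption).

    y_true: integer token ids, shape (batch, seq_len).
    y_pred: per-token scores (logits or probabilities), shape (batch, seq_len, vocab_size).
    """

    def __init__(self, pad_id=0, name="masked_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.pad_id = pad_id  # hypothetical default: assumes padding uses id 0
        # Running totals, accumulated over all batches until reset_state().
        self.correct = self.add_weight(name="correct", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch.
        y_true = tf.cast(y_true, tf.int64)
        # Predicted token id = index of the highest score on the last axis.
        y_pred_ids = tf.argmax(y_pred, axis=-1, output_type=tf.int64)
        # 1.0 where the target is a real token, 0.0 where it is padding.
        mask = tf.cast(tf.not_equal(y_true, self.pad_id), self.dtype)
        # Count a match only at non-padding positions.
        matches = tf.cast(tf.equal(y_true, y_pred_ids), self.dtype) * mask
        self.correct.assign_add(tf.reduce_sum(matches))
        self.count.assign_add(tf.reduce_sum(mask))

    def result(self):
        # divide_no_nan returns 0 instead of NaN if no real tokens were seen yet.
        return tf.math.divide_no_nan(self.correct, self.count)

    def reset_state(self):
        self.correct.assign(0.0)
        self.count.assign(0.0)


if __name__ == "__main__":
    metric = MaskedAccuracy(pad_id=0)
    # Targets: two real tokens followed by two pad tokens.
    y_true = tf.constant([[5, 2, 0, 0]])
    # One-hot "predictions" whose argmax is [5, 3, 0, 0].
    y_pred = tf.one_hot([[5, 3, 0, 0]], depth=6)
    metric.update_state(y_true, y_pred)
    # 1 of 2 real tokens is correct; the matching pad position is not counted.
    print(float(metric.result()))  # 0.5
```

It can then be passed like a built-in metric, e.g. model.compile(optimizer="adam", loss=..., metrics=[MaskedAccuracy(pad_id=0)]), with your masked loss as the loss. Accumulating raw counts across batches, rather than averaging per-batch accuracies, keeps the result exact when batches contain different numbers of pad tokens.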

