JAX Transformer implementation from scratch with notes

I have implemented a simple helper module to make coding layers easier. It includes embedding layers, layer normalization, multi-head attention, and an Adam optimizer, all implemented from the ground up. I'm new to JAX, so I may have made mistakes or missed some JAX best practices. Let me know if you see any opportunities for improvement.
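To give a rough idea of what two of these pieces can look like when written by hand, here is a minimal sketch of layer normalization and an Adam update step in JAX. It is not the code from the project; the names `layer_norm`, `adam_init`, and `adam_update`, the dictionary layout of the optimizer state, and the default hyperparameters are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gain, bias, eps=1e-5):
    """Normalize over the last axis, then apply a learned scale and shift."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gain * (x - mean) / jnp.sqrt(var + eps) + bias


def adam_init(params):
    """Hypothetical state layout: zeroed moment estimates shaped like params."""
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "step": 0}


def adam_update(params, grads, state, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step with bias-corrected first and second moments."""
    step = state["step"] + 1
    # Exponential moving averages of the gradient and the squared gradient.
    m = jax.tree_util.tree_map(
        lambda m_, g: b1 * m_ + (1 - b1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(
        lambda v_, g: b2 * v_ + (1 - b2) * g * g, state["v"], grads)

    def apply(p, m_, v_):
        # Bias correction compensates for the zero initialization of m and v.
        m_hat = m_ / (1 - b1 ** step)
        v_hat = v_ / (1 - b2 ** step)
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

    new_params = jax.tree_util.tree_map(apply, params, m, v)
    return new_params, {"m": m, "v": v, "step": step}


# Toy usage: fit the layer norm gain and bias so the output moves toward 1.0.
x = jnp.array([[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]])
params = {"gain": jnp.ones(4), "bias": jnp.zeros(4)}


def loss_fn(p):
    return jnp.mean((layer_norm(x, p["gain"], p["bias"]) - 1.0) ** 2)


state = adam_init(params)
loss_before = loss_fn(params)
for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    params, state = adam_update(params, grads, state)
loss_after = loss_fn(params)
```

Keeping parameters and optimizer state as plain dictionaries lets `jax.tree_util.tree_map` apply the same update to every array, and `jax.grad` returns gradients in the same structure as the parameters. In a real training loop the update step would usually be wrapped in `jax.jit`.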

I hope this is helpful, and I welcome any feedback.

submitted by /u/mlvpj