
How to implement model parallelism in TF2

Hey,

I want to split my custom model (e.g. class MyModel(keras.Model)) over multiple GPUs, and everything I’ve tried so far causes an OOM exception.

What should I do, and how?

Thanks in advance!

some background:

I’m trying to implement the following paper:

Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders – https://arxiv.org/abs/1704.01279

but the large input (a 64K-sample vector) combined with the deep WaveNet decoder causes an OOM exception.

I have multiple GPUs, and when I create a layer I do it under a specific GPU, but when I apply the gradients they all end up on the same GPU (the default one, GPU:0), and that causes the OOM.

The train step is custom, as is the model.
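
The custom train step is roughly the standard GradientTape pattern (again simplified; the loss is just a placeholder):

```python
import tensorflow as tf
from tensorflow import keras

optimizer = keras.optimizers.Adam(1e-4)
loss_fn = keras.losses.MeanSquaredError()  # placeholder loss

@tf.function
def train_step(model, x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    # In my runs this is where memory blows up: the gradient tensors and the
    # optimizer update all end up on the default GPU (GPU:0).
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```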

submitted by /u/ori_yt

