Hey,
I want to split my custom model (e.g. class MyModel(keras.Model)) over
multiple GPUs, and everything I've tried so far causes an OOM exception.
What should I do, and how?
Thanks in advance!
Some background:
I'm trying to implement the following paper:
Neural Audio Synthesis of Musical Notes with WaveNet
Autoencoders – https://arxiv.org/abs/1704.01279
but the large input (a 64K vector) combined with the deep WaveNet decoder
causes an OOM exception.
I have multiple GPUs, and I create each layer under a specific GPU,
but when I apply the gradients they are all located on the same GPU
(the default one, GPU0), and that causes the OOM.
Both the train step and the model are custom.
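For reference, here is a minimal sketch of the kind of setup described above, assuming TensorFlow 2.x with tf.keras (or Keras 3 on the TensorFlow backend). SplitModel, pick_devices, train_step, run_demo and the layer sizes are hypothetical stand-ins, not the WaveNet autoencoder from the paper. Each stage's forward computation runs under its own tf.device scope, and the custom train step records the forward pass on a GradientTape and then applies the gradients. If no GPU is visible, it falls back to the CPU.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Let ops that have no kernel on the requested device fall back to another one.
tf.config.set_soft_device_placement(True)


def pick_devices(n):
    """Hypothetical helper: return n device names, cycling over visible GPUs (CPU if none)."""
    names = [d.name for d in tf.config.list_logical_devices("GPU")] or ["/CPU:0"]
    return [names[i % len(names)] for i in range(n)]


class SplitModel(keras.Model):
    """Hypothetical two-stage model whose stages run on different devices."""

    def __init__(self, hidden=64, out_dim=1):
        super().__init__()
        self.stage_devices = pick_devices(2)
        self.stage1 = keras.layers.Dense(hidden, activation="relu")
        self.stage2 = keras.layers.Dense(out_dim)

    def call(self, x):
        # Dense layers create their weights on first call, so the device scope
        # active here also decides where each stage's weights are allocated.
        with tf.device(self.stage_devices[0]):
            h = self.stage1(x)
        # Only the activation h crosses devices between the two stages.
        with tf.device(self.stage_devices[1]):
            return self.stage2(h)


def train_step(model, optimizer, loss_fn, x, y):
    """Custom train step: forward pass under a tape, then apply the gradients."""
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)


def run_demo(steps=50, seed=0):
    """Fit a small random regression problem and return the loss per step."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(32, 8)).astype("float32")
    y = (x @ rng.normal(size=(8, 1))).astype("float32")
    model = SplitModel()
    optimizer = keras.optimizers.Adam(learning_rate=1e-2)
    loss_fn = keras.losses.MeanSquaredError()
    return [train_step(model, optimizer, loss_fn, x, y) for _ in range(steps)]


losses = run_demo()
print(losses[0], losses[-1])
```

To see where each forward and gradient op actually runs, one option is to call tf.debugging.set_log_device_placement(True) before training. This sketch pins only the forward ops. It does not by itself control where the gradient ops or the optimizer state end up.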
submitted by /u/ori_yt