I want to split my custom model (e.g. class MyModel(keras.Model)) across
multiple GPUs, but everything I've tried ends in an OOM exception.
What should I do, and how?
Thanks in advance!
I'm trying to implement the following paper:
Neural Audio Synthesis of Musical Notes with WaveNet
The large input (a 64K vector) combined with the deep WaveNet decoder
causes an OOM exception.
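A back-of-the-envelope estimate shows why a 64K-sample input can exhaust a single GPU: the activations kept for backprop scale with time steps × channels × layers. The sketch below uses hypothetical values (30 layers, 512 channels, 4 float32 tensors stored per layer, batch size 1, 4 GPUs). These are assumptions, not the paper's hyperparameters, and should be replaced with the real configuration.

```python
def activation_bytes(batch_size, time_steps, channels, num_layers,
                     tensors_per_layer, bytes_per_value=4):
    """Rough bytes of activations held for backprop (float32 by default).

    Assumes each layer keeps `tensors_per_layer` tensors of shape
    (batch_size, time_steps, channels) alive until the backward pass.
    """
    return (batch_size * time_steps * channels * num_layers
            * tensors_per_layer * bytes_per_value)


# Hypothetical configuration: NOT the paper's hyperparameters.
# "64K" is read as 64,000 samples per example here.
total = activation_bytes(batch_size=1, time_steps=64_000, channels=512,
                         num_layers=30, tensors_per_layer=4)
print(f"{total / 2**30:.1f} GiB")  # 14.6 GiB

num_gpus = 4  # hypothetical GPU count
print(f"{total / num_gpus / 2**30:.1f} GiB per GPU if layers are split evenly")
```

The per-GPU figure only applies if each GPU actually holds its own block of layers and their activations; if everything (including the gradient computation) ends up on one device, that device still needs the full amount.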
I have multiple GPUs, and I create each layer under a specific GPU, but
when I apply the gradients they are all located on the same GPU (the
default one, GPU0), which causes the OOM.
Both the model and the train step are custom.
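A minimal sketch of manual model parallelism with a custom train step, assuming TensorFlow 2 with tf.keras. DeviceSplitStack, pick_devices and make_train_step are hypothetical names, and the toy stack of dilated causal Conv1D layers only stands in for the real WaveNet decoder. Each contiguous block of layers runs under tf.device so its weights and activations should stay on one GPU. Whether TensorFlow places the matching gradient ops on the same GPU can depend on version and execution mode, so the usage part turns on tf.debugging.set_log_device_placement(True) to check rather than assume.

```python
import tensorflow as tf


def pick_devices():
    """Return the logical GPU names, or the CPU if no GPU is visible."""
    gpus = [d.name for d in tf.config.list_logical_devices("GPU")]
    return gpus or ["/CPU:0"]


class DeviceSplitStack(tf.keras.Model):
    """Hypothetical stand-in for a deep WaveNet-style decoder.

    Layer i runs under tf.device(devices[i * len(devices) // num_layers]),
    so each device holds one contiguous block of layers.
    """

    def __init__(self, num_layers=4, channels=16, devices=None):
        super().__init__()
        self.stage_devices = tuple(devices or pick_devices())
        self.convs = [
            tf.keras.layers.Conv1D(channels, kernel_size=2,
                                   dilation_rate=2 ** i, padding="causal",
                                   activation="relu")
            for i in range(num_layers)
        ]
        self.head = tf.keras.layers.Conv1D(1, kernel_size=1)

    def _device_for(self, i):
        # Contiguous blocks of layers per device (pipeline-style split).
        return self.stage_devices[i * len(self.stage_devices) // len(self.convs)]

    def call(self, x):
        for i, conv in enumerate(self.convs):
            with tf.device(self._device_for(i)):
                x = conv(x)  # weights are created here on the first call
        with tf.device(self.stage_devices[-1]):
            return self.head(x)


def make_train_step(model, optimizer):
    """Build a custom training step using tf.GradientTape."""

    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    return train_step


if __name__ == "__main__":
    tf.debugging.set_log_device_placement(True)  # log where each op runs
    x = tf.random.normal([2, 256, 1])
    y = tf.random.normal([2, 256, 1])
    model = DeviceSplitStack()
    model(x)  # build weights eagerly inside the per-layer device scopes
    step = make_train_step(model, tf.keras.optimizers.Adam(1e-3))
    for _ in range(3):
        print(float(step(x, y)))
```

This kind of split runs the GPUs one after another rather than concurrently, so it spreads memory but does not speed up a single step. If the placement log still shows every gradient op on GPU0, that would match the symptom described in the question.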