
Keras – How to save the VAE from the official example?

There is an official example of a Variational AutoEncoder running on MNIST:

https://github.com/keras-team/keras-io/blob/master/examples/generative/vae.py

I downloaded that code to test it on my machine, and I want to save the trained model, so I simply added this at the end of the file:

vae.save('model_keras_example') 

But that does not seem to work:

WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.VAE object at 0x2abb26a278b0>, because it is not built.
Traceback (most recent call last):
  File "/home/drozd/GAN/keras_example_vae.py", line 199, in <module>
    vae.save('model_keras_example')
  File "/opt/ebsofts/TensorFlow/2.6.0-foss-2021a-CUDA-11.3.1/lib/python3.9/site-packages/keras/engine/training.py", line 2145, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "/opt/ebsofts/TensorFlow/2.6.0-foss-2021a-CUDA-11.3.1/lib/python3.9/site-packages/keras/saving/save.py", line 149, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "/opt/ebsofts/TensorFlow/2.6.0-foss-2021a-CUDA-11.3.1/lib/python3.9/site-packages/keras/saving/saved_model/save.py", line 75, in save
    saving_utils.raise_model_input_error(model)
  File "/opt/ebsofts/TensorFlow/2.6.0-foss-2021a-CUDA-11.3.1/lib/python3.9/site-packages/keras/saving/saving_utils.py", line 84, in raise_model_input_error
    raise ValueError(
ValueError: Model <__main__.VAE object at 0x2abb26a278b0> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`. To manually set the shapes, call `model.build(input_shape)`.

I guess I’m not familiar enough with custom models defined as a class. What seems to be the problem here?

I found this: https://stackoverflow.com/questions/69311861/tf2-6-valueerror-model-cannot-be-saved-because-the-input-shapes-have-not-been
which suggests adding a call to compute_output_shape. When I do, it tells me that my custom model needs a call() method, but I have no idea how to implement that for a VAE.
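A minimal sketch of such a `call()` method, assuming a simplified dense encoder and decoder in place of the official example's Conv2D layers (`LATENT_DIM`, `make_encoder`, `make_decoder` and the layer sizes are hypothetical). The forward pass simply encodes the input to a latent sample `z` and decodes it back; running one batch through the model then sets the input shapes that the error message complains about.

```python
import tensorflow as tf
from tensorflow import keras

layers = keras.layers
LATENT_DIM = 2  # assumption: a 2-D latent space, as in the official MNIST example


class Sampling(layers.Layer):
    """Reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


def make_encoder():
    # Hypothetical dense encoder (the official example uses Conv2D layers instead)
    inputs = keras.Input(shape=(28, 28, 1))
    x = layers.Flatten()(inputs)
    x = layers.Dense(16, activation="relu")(x)
    z_mean = layers.Dense(LATENT_DIM, name="z_mean")(x)
    z_log_var = layers.Dense(LATENT_DIM, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")


def make_decoder():
    # Hypothetical dense decoder mapping a latent vector back to a 28x28 image
    latent_inputs = keras.Input(shape=(LATENT_DIM,))
    x = layers.Dense(28 * 28, activation="sigmoid")(latent_inputs)
    outputs = layers.Reshape((28, 28, 1))(x)
    return keras.Model(latent_inputs, outputs, name="decoder")


class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        # Forward pass: encode to a latent sample z, then decode it to an image
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


vae = VAE(make_encoder(), make_decoder())
# Running one batch through the model sets its input shapes,
# which .fit() or .predict() would otherwise do implicitly
reconstruction = vae(tf.zeros((4, 28, 28, 1)))
```

With `call()` defined and the model built this way (or after training with `.fit()`), the `vae.save('model_keras_example')` line from the question should no longer raise the "input shapes have not been set" error on TF 2.6. The official example's `train_step` should not be affected, since it calls `self.encoder` and `self.decoder` directly rather than going through `call()`.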

Any help would be much appreciated!

Edit: It seems I can save the encoder and decoder separately:

vae.decoder.save('model_keras_example_decoder')
vae.encoder.save('model_keras_example_encoder')

I suppose I can then rebuild the full VAE afterwards by reusing the same class…
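The encoder/decoder approach from the edit could look roughly like this, assuming simplified dense sub-models (hypothetical `make_encoder`, `make_decoder`, `LATENT_DIM`) and a temporary directory for the files. Because the encoder contains the custom `Sampling` layer, that class has to be handed to `load_model` through `custom_objects`; the `.keras` file suffix is an assumption aimed at recent Keras releases.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow import keras

layers = keras.layers
LATENT_DIM = 2  # assumption: a 2-D latent space, as in the official MNIST example


class Sampling(layers.Layer):
    """Reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


def make_encoder():
    # Hypothetical dense encoder (the official example uses Conv2D layers instead)
    inputs = keras.Input(shape=(28, 28, 1))
    x = layers.Flatten()(inputs)
    x = layers.Dense(16, activation="relu")(x)
    z_mean = layers.Dense(LATENT_DIM, name="z_mean")(x)
    z_log_var = layers.Dense(LATENT_DIM, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")


def make_decoder():
    # Hypothetical dense decoder mapping a latent vector back to a 28x28 image
    latent_inputs = keras.Input(shape=(LATENT_DIM,))
    x = layers.Dense(28 * 28, activation="sigmoid")(latent_inputs)
    outputs = layers.Reshape((28, 28, 1))(x)
    return keras.Model(latent_inputs, outputs, name="decoder")


class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


# 1) After training: save the two functional sub-models separately
encoder, decoder = make_encoder(), make_decoder()
save_dir = tempfile.mkdtemp()
encoder_path = os.path.join(save_dir, "vae_encoder.keras")
decoder_path = os.path.join(save_dir, "vae_decoder.keras")
encoder.save(encoder_path)
decoder.save(decoder_path)

# 2) Later: reload them; the custom Sampling layer is passed via custom_objects
loaded_encoder = keras.models.load_model(
    encoder_path, custom_objects={"Sampling": Sampling}
)
loaded_decoder = keras.models.load_model(decoder_path)

# 3) Rebuild the VAE by wrapping the reloaded parts in the same class
vae = VAE(loaded_encoder, loaded_decoder)
x = tf.random.uniform((3, 28, 28, 1))
reconstruction = vae(x)
```

Only the two sub-models are persisted here. The `VAE` wrapper has no trainable weights of its own, and no optimizer state is saved because the sub-models were never compiled, so further training would start with a fresh optimizer.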

submitted by /u/Milleuros
