
TensorFlow Probability: BatchNormalization Bijector Gives Wrong Results with "prob" Method

I am trying to implement a normalizing flow according to the RealNVP model for density estimation. First, I am trying to make it work on the “moons” toy dataset.

The model produces the expected result when not using the BatchNormalization bijector. However, when adding the BatchNormalization bijector to the model, the methods prob and log_prob return unexpected results.

Following is a code snippet setting up the model:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

layers = 6
dimensions = 2
hidden_units = [512, 512]
bijectors = []

base_dist = tfd.Normal(loc=0.0, scale=1.0)  # specify base distribution

for i in range(layers):
    # Adding the BatchNormalization bijector corrupts the results
    bijectors.append(tfb.BatchNormalization())
    # RealNVP is the custom coupling bijector defined in the linked notebook (not tfb.RealNVP)
    bijectors.append(RealNVP(input_shape=dimensions, n_hidden=hidden_units))
    bijectors.append(tfb.Permute([1, 0]))

bijector = tfb.Chain(bijectors=list(reversed(bijectors))[:-1], name='chain_of_real_nvp')

flow = tfd.TransformedDistribution(
    distribution=tfd.Sample(base_dist, sample_shape=[dimensions]),
    bijector=bijector
)
```

When the BatchNormalization bijector is omitted, both sampling and evaluating the probability return the expected results:

Heatmap of probabilities and samples without BN

However, when the BatchNormalization bijector is added, sampling is as expected but evaluating the probability seems wrong:

Heatmap of probabilities and samples with BN
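A plausible explanation for sampling looking fine while `prob` looks wrong is a direction asymmetry. As far as I understand the TFP implementation, `tfb.BatchNormalization` in its default training mode normalizes in the inverse direction (data to latent, which `prob`/`log_prob` use) with the statistics of whatever batch it receives, while the forward direction (latent to data, used by `sample`) uses the layer's moving mean and variance. The NumPy sketch below reproduces that in one dimension. It is a toy model, not TFP code: the helper names, the data distribution N(3, 0.5²), gamma=1, beta=0 and the epsilon of 1e-3 are all assumptions chosen for illustration.

```python
import numpy as np

EPS = 1e-3  # assumed epsilon (the Keras BatchNormalization default)
# Toy 1-D batch-norm bijector with gamma=1, beta=0 (hypothetical helpers, not TFP).

def bn_forward(z, moving_mean, moving_var):
    # Sampling direction (latent -> data): uses the stored moving statistics.
    return z * np.sqrt(moving_var + EPS) + moving_mean

def bn_inverse(x, mean, var):
    # Density direction (data -> latent): normalizes with whatever statistics are passed in.
    return (x - mean) / np.sqrt(var + EPS)

def log_prob(x, mean, var):
    # Change of variables with a standard-normal base: log N(z) + log|dz/dx|.
    z = bn_inverse(x, mean, var)
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi) - 0.5 * np.log(var + EPS)

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=0.5, size=10_000)
moving_mean, moving_var = data.mean(), data.var()  # roughly what the moving averages converge to

# Sampling only ever uses the stored statistics, so samples look right.
samples = bn_forward(rng.standard_normal(10_000), moving_mean, moving_var)

grid = np.linspace(-2.0, 8.0, 201)  # heatmap-style evaluation points
lp_stored = log_prob(grid, moving_mean, moving_var)  # inference-mode behaviour
lp_batch = log_prob(grid, grid.mean(), grid.var())   # training-mode behaviour: grid's own stats

print("sample mean/std:", samples.mean(), samples.std())
print("peak density, stored stats:", np.exp(lp_stored.max()))
print("peak density, grid stats:  ", np.exp(lp_batch.max()))
```

In this toy, the grid-normalized density is centred on the grid's mean and spread over the grid's variance, so its peak comes out several times too low, while samples drawn through the forward pass still match the data.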

Because I am interested in density estimation, the prob method is crucial. The full code can be found in the following Jupyter notebook: https://github.com/mmsbrggr/normalizing-flows/blob/master/moons_training_rnvp.ipynb

I know that the BatchNormalization bijector behaves differently during training and inference. Could the problem be that the BN bijector is still in training mode? If so, how can I switch the flow to inference mode?
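On switching to inference mode: the `tfb.BatchNormalization` constructor takes `batchnorm_layer` and `training` arguments, and `training` defaults to `True`. One possible route, not tested against the linked notebook or across TFP/Keras versions, is to keep references to the trained Keras `BatchNormalization` layers and build a second chain for evaluation using `tfb.BatchNormalization(batchnorm_layer=layer, training=False)`, so that `prob` uses the moving statistics. In training mode, evaluating the heatmap grid is likely also a training-mode call of the underlying Keras layer, which would pull the moving averages toward the grid's statistics. The class below is a hypothetical NumPy stand-in (not the TFP class) that only illustrates what such a flag changes; the momentum of 0.99, the epsilon of 1e-3 and the N(3, 0.5²) minibatches are assumptions.

```python
import numpy as np

class ToyBatchNormBijector:
    """Hypothetical 1-D stand-in for a batch-norm bijector (gamma=1, beta=0; not TFP)."""

    def __init__(self, momentum=0.99, eps=1e-3, training=True):
        self.momentum, self.eps, self.training = momentum, eps, training
        self.moving_mean, self.moving_var = 0.0, 1.0  # Keras-style initial values

    def inverse(self, x):
        if self.training:
            # Training mode: normalize with this batch's statistics...
            mean, var = x.mean(), x.var()
            # ...and pull the moving averages toward them.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: use the stored moving statistics.
            mean, var = self.moving_mean, self.moving_var
        z = (x - mean) / np.sqrt(var + self.eps)
        log_det = -0.5 * np.log(var + self.eps)  # log|dz/dx|
        return z, log_det

    def log_prob(self, x):
        # Standard-normal base density plus the inverse log-det-Jacobian.
        z, log_det = self.inverse(x)
        return -0.5 * z**2 - 0.5 * np.log(2 * np.pi) + log_det

rng = np.random.default_rng(1)
bn = ToyBatchNormBijector(training=True)
for _ in range(1000):  # "training" on minibatches from N(3, 0.5^2)
    bn.inverse(rng.normal(3.0, 0.5, size=256))

grid = np.linspace(-2.0, 8.0, 201)
bn.training = False  # switch to inference mode before evaluating the density
density = np.exp(bn.log_prob(grid))
print("moving stats:", bn.moving_mean, bn.moving_var)
print("density mass on grid:", density.sum() * (grid[1] - grid[0]))
```

With `training=False` the density peaks at the data mean and integrates to about one over the grid. Calling `bn.log_prob(grid)` while `bn.training` is still `True` would instead normalize by the grid's own mean and variance and nudge `moving_mean`/`moving_var` toward them.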

submitted by /u/marcelmoosbrugger
