Implementation of an Elman Neural Network via subclassing TensorFlow’s Model class


I’m pretty new to the TensorFlow world and have implemented some models using existing layers like convolution, dense and LSTM. Now I want to implement an Elman Neural Network (wiki page) in TensorFlow 2. Since there doesn’t seem to be an existing implementation in TensorFlow, I decided to write one myself. There are a couple of implementations of the Elman neural network on the web, but none of them subclasses TensorFlow’s Model class, and I thought it would be a lot easier to implement it that way. Following is my implementation. The problem is that when I call .fit on a training dataset, the batch_size variable in the call method is None, and I couldn’t figure out how to solve this. Any ideas?
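For reference, the Elman update in the Wikipedia article uses the same weight names as the code (W_h, U_h, b_h, W_y, b_y). Written with tanh for both activations, as in the code, and in the row-vector convention that matches the declared weight shapes (each row of $x_t$ is one example at time step $t$), it reads:

$$
\begin{aligned}
h_t &= \tanh\left(x_t W_h + h_{t-1} U_h + b_h\right), \qquad h_0 = 0,\\
y_t &= \tanh\left(h_t W_y + b_y\right).
\end{aligned}
$$

Here $x_t$ has shape (batch, input_dim), $W_h$ is (input_dim, hidden_units), $U_h$ is (hidden_units, hidden_units) and $W_y$ is (hidden_units, output_dim). The initial state $h_0$ needs one row per example, which is why its shape depends on the batch size.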

```python
class ElmanNeuralNetwork(tf.keras.Model):
    def __init__(self, input_dim, hidden_units, output_dim, n_classes):
        super(ElmanNeuralNetwork, self).__init__()
        self.hidden_units = hidden_units
        self.U_h = self.add_weight(shape=(self.hidden_units, self.hidden_units), initializer='random_normal', trainable=True)
        self.W_h = self.add_weight(shape=(input_dim, self.hidden_units), initializer='random_normal', trainable=True)
        self.b_h = self.add_weight(shape=(self.hidden_units,), initializer='random_normal', trainable=True)
        self.W_y = self.add_weight(shape=(self.hidden_units, output_dim), initializer='random_normal', trainable=True)
        self.b_y = self.add_weight(shape=(output_dim,), initializer='random_normal', trainable=True)
        self.softmax_layer = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, x):
        batch_size, n_step, n_feature = x.shape
        h = tf.zeros((batch_size, self.hidden_units))
        for _ in range(n_step):
            h = tf.keras.activations.tanh(tf.matmul(self.W_h, x) + tf.matmul(self.U_h, h) + self.b_h)
        y = tf.keras.activations.tanh(tf.matmul(self.W_y, h) + self.b_y)
        preds = self.softmax_layer(y)
        return preds
```
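A minimal sketch of one possible fix, assuming fixed-length input sequences of shape (batch, time, features) and keeping the original weight names. During .fit, Keras traces call with a symbolic input whose batch dimension is None, so x.shape[0] is None; tf.shape(x)[0] instead returns the batch size as a tensor when the graph runs. The sketch also makes two other changes the original call appears to need: it feeds one time step x[:, t, :] per iteration instead of the whole x, and it multiplies in row-vector order (x_t @ W_h, h @ U_h, h @ W_y), which is the order the declared weight shapes support.

```python
import numpy as np
import tensorflow as tf


class ElmanNeuralNetwork(tf.keras.Model):
    """Elman (simple recurrent) network; sketch that tolerates a None batch dim."""

    def __init__(self, input_dim, hidden_units, output_dim, n_classes):
        super().__init__()
        self.hidden_units = hidden_units
        # Context (previous hidden state) -> hidden weights.
        self.U_h = self.add_weight(shape=(hidden_units, hidden_units),
                                   initializer="random_normal", trainable=True)
        # Input -> hidden weights and hidden bias.
        self.W_h = self.add_weight(shape=(input_dim, hidden_units),
                                   initializer="random_normal", trainable=True)
        self.b_h = self.add_weight(shape=(hidden_units,),
                                   initializer="random_normal", trainable=True)
        # Hidden -> output weights and output bias.
        self.W_y = self.add_weight(shape=(hidden_units, output_dim),
                                   initializer="random_normal", trainable=True)
        self.b_y = self.add_weight(shape=(output_dim,),
                                   initializer="random_normal", trainable=True)
        self.softmax_layer = tf.keras.layers.Dense(n_classes, activation="softmax")

    def call(self, x):
        # The batch size can be None while Keras traces the model, so read it
        # from the runtime shape instead of the static x.shape.
        batch_size = tf.shape(x)[0]
        # Assumption: the number of time steps is static (fixed-length sequences).
        n_step = x.shape[1]
        # h_0 = 0, one row per example in the batch.
        h = tf.zeros(tf.stack([batch_size, self.hidden_units]), dtype=x.dtype)
        for t in range(n_step):
            x_t = x[:, t, :]  # (batch, input_dim)
            h = tf.tanh(tf.matmul(x_t, self.W_h) + tf.matmul(h, self.U_h) + self.b_h)
        # Output computed from the last hidden state only.
        y = tf.tanh(tf.matmul(h, self.W_y) + self.b_y)
        return self.softmax_layer(y)
```

Because the loop runs over a Python range, it is unrolled when the model is traced; that is fine for short, fixed-length sequences. For comparison, Keras's built-in tf.keras.layers.SimpleRNN is a fully connected RNN whose hidden state is fed back at the next step, i.e. the same hidden-state recurrence, so stacking it with a softmax Dense layer may be a simpler alternative if a hand-written cell is not required.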

submitted by /u/engmrgh