
Implementing an Elman neural network by subclassing TensorFlow’s Model class

Hello

I’m pretty new to the TensorFlow world and have implemented some models using existing layers such as convolution, dense and LSTM. Now I want to implement an Elman neural network (wiki page) in TensorFlow 2. Since there doesn’t seem to be an existing implementation in TensorFlow, I decided to write one myself. There are a couple of Elman neural network implementations on the web, but none of them subclass TensorFlow’s Model class, and I thought subclassing Model would make it a lot easier. My implementation follows. The problem is that when I call .fit on a training dataset, the batch_size variable in the call method is None, and I couldn’t figure out how to solve this. Any ideas?
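For reference, the Elman recurrence that the code below aims to compute can be written with the weight names from the question. The batched row-vector convention used here (each $x_t$ is a $B \times \text{input\_dim}$ slice of the input, $B$ being the batch size) is an assumption chosen to match the declared weight shapes, e.g. `W_h` of shape `(input_dim, hidden_units)`:

$$
\begin{aligned}
h_0 &= 0 \in \mathbb{R}^{B \times \text{hidden\_units}},\\
h_t &= \tanh\left(x_t W_h + h_{t-1} U_h + b_h\right), \qquad t = 1, \dots, T,\\
y &= \tanh\left(h_T W_y + b_y\right),
\end{aligned}
$$

where $T$ is `n_step`. Because $h_0$ needs one row per example, its shape depends on $B$, which is exactly the quantity that shows up as `None` when Keras builds the training graph.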

```python
import tensorflow as tf


class ElmanNeuralNetwork(tf.keras.Model):
    def __init__(self, input_dim, hidden_units, output_dim, n_classes):
        super(ElmanNeuralNetwork, self).__init__()
        self.hidden_units = hidden_units
        # Hidden-to-hidden (recurrent) weights, input weights and hidden bias
        self.U_h = self.add_weight(shape=(self.hidden_units, self.hidden_units), initializer='random_normal', trainable=True)
        self.W_h = self.add_weight(shape=(input_dim, self.hidden_units), initializer='random_normal', trainable=True)
        self.b_h = self.add_weight(shape=(self.hidden_units,), initializer='random_normal', trainable=True)
        # Output projection, followed by a softmax classifier
        self.W_y = self.add_weight(shape=(self.hidden_units, output_dim), initializer='random_normal', trainable=True)
        self.b_y = self.add_weight(shape=(output_dim,), initializer='random_normal', trainable=True)
        self.softmax_layer = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, x):
        # Static shape: the batch dimension can be None while .fit() traces call()
        batch_size, n_step, n_feature = x.shape
        h = tf.zeros((batch_size, self.hidden_units))  # initial hidden state
        for _ in range(n_step):
            h = tf.keras.activations.tanh(tf.matmul(self.W_h, x) + tf.matmul(self.U_h, h) + self.b_h)
        y = tf.keras.activations.tanh(tf.matmul(self.W_y, h) + self.b_y)
        preds = self.softmax_layer(y)
        return preds
```
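One common way around the `None` batch dimension is to read the batch size at run time with `tf.shape(x)[0]` instead of the static `x.shape`. The following is a minimal sketch, not a drop-in answer. It assumes a fixed, known sequence length (`x.shape[1]` is not `None`) and the row-vector convention `x_t @ W_h`, which matches the declared weight shapes. Unlike the loop above, it also feeds each time step `x[:, t, :]` into the recurrence rather than the whole 3-D input.

```python
import tensorflow as tf


class ElmanNeuralNetwork(tf.keras.Model):
    """Sketch of an Elman RNN whose initial state works with an unknown batch size."""

    def __init__(self, input_dim, hidden_units, output_dim, n_classes):
        super().__init__()
        self.hidden_units = hidden_units
        # Row-vector convention: each time step x_t has shape (batch, input_dim)
        self.W_h = self.add_weight(shape=(input_dim, hidden_units),
                                   initializer="random_normal", trainable=True)
        self.U_h = self.add_weight(shape=(hidden_units, hidden_units),
                                   initializer="random_normal", trainable=True)
        self.b_h = self.add_weight(shape=(hidden_units,),
                                   initializer="random_normal", trainable=True)
        self.W_y = self.add_weight(shape=(hidden_units, output_dim),
                                   initializer="random_normal", trainable=True)
        self.b_y = self.add_weight(shape=(output_dim,),
                                   initializer="random_normal", trainable=True)
        self.softmax_layer = tf.keras.layers.Dense(n_classes, activation="softmax")

    def call(self, x):
        # tf.shape returns the runtime batch size even when x.shape[0] is None
        batch_size = tf.shape(x)[0]
        # Assumption: the sequence length is fixed and known when the graph is built
        n_step = x.shape[1]
        h = tf.zeros((batch_size, self.hidden_units), dtype=x.dtype)
        for t in range(n_step):
            x_t = x[:, t, :]  # (batch, input_dim)
            h = tf.tanh(tf.matmul(x_t, self.W_h) + tf.matmul(h, self.U_h) + self.b_h)
        y = tf.tanh(tf.matmul(h, self.W_y) + self.b_y)
        return self.softmax_layer(y)


model = ElmanNeuralNetwork(input_dim=3, hidden_units=8, output_dim=5, n_classes=2)
print(model(tf.random.normal((4, 6, 3))).shape)
```

For comparison, Keras also provides `tf.keras.layers.SimpleRNN`, whose hidden-state update is `tanh(x_t @ kernel + h_prev @ recurrent_kernel + bias)`, the same recurrence as the loop in this sketch, and which handles variable batch sizes on its own.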

submitted by /u/engmrgh
