Hi everyone, I am a novice with TensorFlow and I am trying to build a simple model with one custom layer that has one trainable parameter M, plus a skip connection between the network input and that final custom layer. The custom layer should calculate
$$M + \ln(x) + 0.5 \cdot \text{inputpreviouslayer} \cdot x^2$$
where x is the input of the network (hence the skip connection). In other words, I want the neural network to learn M and inputpreviouslayer.
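For reference, the target formula for a single scalar sample can be written in plain Python. `target_output`, `m` and `d` are hypothetical names standing for the layer's output, M and inputpreviouslayer; this is only a sketch for checking a layer's numbers by hand, assuming x > 0 so that ln(x) is defined.

```python
import math


def target_output(x: float, m: float, d: float) -> float:
    """Return M + ln(x) + 0.5 * d * x**2 for one scalar input x.

    `m` stands for the trainable parameter M and `d` for the value coming
    from the previous layer (inputpreviouslayer). math.log raises
    ValueError for x <= 0, since ln(x) is only defined for positive x.
    """
    return m + math.log(x) + 0.5 * d * x ** 2


print(target_output(1.0, 0.0, 2.0))  # 1.0
```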
I tried this:
```python
class SimpleLayer(tf.keras.layers.Layer):
    def __init__(self):
        """Initializes the instance attributes"""
        super(SimpleDense, self).__init__()

    def build(self, input_shape):
        """Create the state of the layer (weights)"""
        q_init = tf.zeros_initializer()
        self.M = tf.Variable(name="Nuisance", initial_value=q_init(shape=(1), dtype='float32'), trainable=True)

    def call(self, inputs):
        """Defines the computation from inputs to outputs"""
        tf.math.log
        return (self.M + tf.math.log(inputs[0], name=None) + 0.5 (1 - inputs[1]*inputs[0]))
```
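A minimal sketch of how the layer could be written, assuming the intended output is exactly the formula above. Note that the posted `call` computes `0.5 (1 - inputs[1]*inputs[0])`, which is not `0.5 * inputpreviouslayer * x^2` and is also missing the `*` after `0.5`. Other changes in this sketch: `super()` refers to this class (the posted code names `SimpleDense`), the weight is created with `add_weight` so Keras tracks it, and columns are selected by slicing, because after `concatenate` the input has shape `(batch, 2)` and `inputs[0]` would pick the first *sample*, not the first column. Column 0 is assumed to be x and column 1 the previous layer's output, matching the order in `concatenate([inputs, deceleration])`. It has not been tested against the poster's data.

```python
import tensorflow as tf


class SimpleLayer(tf.keras.layers.Layer):
    """Outputs M + ln(x) + 0.5 * d * x**2 with one trainable scalar M.

    Assumes `inputs` has shape (batch, 2): column 0 is the network input x
    (the skip connection) and column 1 is the previous layer's output d.
    """

    def __init__(self, **kwargs):
        # Pass optional keyword arguments (e.g. name) on to the base Layer.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # A single trainable scalar, initialised to zero.
        self.M = self.add_weight(
            name="Nuisance", shape=(1,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        x = inputs[:, 0:1]  # network input x, shape (batch, 1); must be > 0
        d = inputs[:, 1:2]  # previous layer's output, shape (batch, 1)
        return self.M + tf.math.log(x) + 0.5 * d * tf.square(x)
```

Because `self.M` has shape `(1,)`, it broadcasts against the `(batch, 1)` terms, so the layer returns one value per sample.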
and the model:
```python
from keras.layers import Input, concatenate
from keras.models import Model
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers

def prova():
    inputs = Input(shape=(1,))
    hidden = Dense(200, activation="relu")(inputs)
    hidden = Dropout(0.1)(hidden, training=True)
    hidden = Dense(300, activation="relu")(hidden)
    hidden = Dropout(0.1)(hidden, training=True)
    hidden = Dense(200, activation="relu")(hidden)
    hidden = Dropout(0.1)(hidden, training=True)
    deceleration = Dense(1)(hidden)
    hidden = concatenate([inputs, deceleration])
    params_mc = SimpleLayer(hidden)
    testmodel = Model(inputs=inputs, outputs=params_mc)
    return testmodel
```
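The reported error comes from `SimpleLayer(hidden)`: writing the class name followed by a tensor passes that tensor to the constructor, i.e. to `__init__`, which only accepts `self`. Keras layers are used in two steps: construct the instance, then call it on a tensor. The sketch below reproduces this with a hypothetical toy layer `AddScalar` (not part of the original code) that holds one trainable scalar.

```python
import tensorflow as tf


class AddScalar(tf.keras.layers.Layer):
    """Hypothetical toy layer: adds one trainable scalar to its input."""

    def __init__(self, **kwargs):
        # No positional arguments besides self, as in the posted SimpleLayer.
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.b = self.add_weight(name="b", shape=(1,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return inputs + self.b


inputs = tf.keras.Input(shape=(1,))

# Wrong: the tensor is handed to the constructor, so it reaches __init__.
message = ""
try:
    AddScalar(inputs)
except TypeError as err:
    message = str(err)
print(message)  # mentions "takes 1 positional argument but 2 were given"

# Right: construct the layer first, then call the instance on the tensor.
outputs = AddScalar()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```

For the posted `prova()`, the analogous change would be `params_mc = SimpleLayer()(hidden)`; the `super(SimpleDense, self)` and `0.5 (...)` problems inside the posted layer would still need fixing after that.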
and the training code:
```python
nn = prova()
nn.compile(Adam(learning_rate=0.02), loss="mse")
history = nn.fit(x, y, epochs=15000, verbose=0, batch_size=1048)
```
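The post does not show how `x` and `y` are built. A hedged sketch for generating synthetic data to sanity-check the setup: `make_synthetic_data`, `m_true` and `d_true` are hypothetical, and d is held constant here, which is a simplification since in the model it is produced by the Dense stack. The arrays have shape `(n, 1)` to match `Input(shape=(1,))`, and x is kept strictly positive because of ln(x).

```python
import numpy as np


def make_synthetic_data(n=4096, m_true=1.5, d_true=-0.3, seed=0):
    """Hypothetical helper: sample x > 0 and y = M + ln(x) + 0.5*d*x**2.

    m_true and d_true are made-up "ground truth" values, useful only for
    checking whether a trained model ends up somewhere near them.
    """
    rng = np.random.default_rng(seed)
    # Keep x strictly positive because ln(x) is undefined for x <= 0.
    x = rng.uniform(0.1, 5.0, size=(n, 1)).astype("float32")
    y = (m_true + np.log(x) + 0.5 * d_true * x ** 2).astype("float32")
    return x, y


x, y = make_synthetic_data()
```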
but when I try to run it, I get:
> __init__() takes 1 positional argument but 2 were given
Can anybody tell me how to correctly modify the custom layer? How can I solve this issue? Thanks!
submitted by /u/ilrazziatore