Categories
Misc

ValueError: The first argument to `Layer.call` must always be passed

I have two classes and am trying to build a TensorFlow ranking model.

When I run `model.fit(cached_train, epochs=3)`, I get the following error:

ValueError: The first argument to `Layer.call` must always be passed

class ProdRankingModel(tf.keras.Model):
    """Score (user, product) pairs with an embedding + MLP network.

    Each string id is mapped to an integer index by a ``StringLookup`` layer
    and then to a dense vector by an ``Embedding`` layer. The user and product
    vectors are concatenated and passed through Dense(256, relu) ->
    Dense(64, relu) -> Dense(1) to produce one score per pair.
    """

    def __init__(self, user_vocabulary=None, item_vocabulary=None,
                 embedding_dimension=32):
        """Build the two lookup/embedding towers and the rating head.

        Args:
            user_vocabulary: Known user ids. Defaults to the module-level
                ``unique_user_ids`` (the original hard-coded behaviour).
            item_vocabulary: Known product names. Defaults to the
                module-level ``unique_items``.
            embedding_dimension: Width of each embedding vector (default 32).
        """
        super().__init__()

        # Fall back to the module-level vocabularies the original code used.
        if user_vocabulary is None:
            user_vocabulary = unique_user_ids
        if item_vocabulary is None:
            item_vocabulary = unique_items

        # StringLookup with mask_token=None reserves one out-of-vocabulary
        # index in addition to the vocabulary entries, hence the "+ 1".
        self.user_embeddings = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=user_vocabulary, mask_token=None),
            tf.keras.layers.Embedding(
                len(user_vocabulary) + 1, embedding_dimension),
        ])

        self.prod_embeddings = tf.keras.Sequential([
            tf.keras.layers.StringLookup(
                vocabulary=item_vocabulary, mask_token=None),
            tf.keras.layers.Embedding(
                len(item_vocabulary) + 1, embedding_dimension),
        ])

        # Compute predictions from the concatenated embeddings.
        self.ratings = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(1),
        ])

    def call(self, inputs):
        """Return one predicted score per (user_id, product) pair.

        Args:
            inputs: A 2-tuple ``(user_id, products)`` of string tensors.
                Presumably shaped (batch,) as produced by the caller;
                TODO confirm against the dataset.
        """
        user_id, products = inputs

        user_embedding = self.user_embeddings(user_id)
        product_embedding = self.prod_embeddings(products)

        # Concatenate along the embedding (last) axis. For rank-1 ids this is
        # the same as axis=1, and it also stays correct for (batch, 1) ids.
        return self.ratings(
            tf.concat([user_embedding, product_embedding], axis=-1))

class ProductModel(tfrs.models.Model):
    """TFRS model that trains ``ProdRankingModel`` with an MSE ranking task."""

    def __init__(self):
        super().__init__()

        self.prodranking_model: tf.keras.Model = ProdRankingModel()

        self.task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
            loss=tf.keras.losses.MeanSquaredError(),
            metrics=[tf.keras.metrics.RootMeanSquaredError()],
        )

    def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
        """Predict a score from the "user_id" and "prod_name" entries."""
        return self.prodranking_model(
            (features["user_id"], features["prod_name"]))

    def compute_loss(self, features: Dict[Text, tf.Tensor],
                     training=False) -> tf.Tensor:
        """Return the ranking loss for one batch.

        Args:
            features: Batch dict holding "user_id", "prod_name" and the
                "prod_count" label.
            training: Accepted for the TFRS ``compute_loss`` signature;
                unused here.
        """
        # Work on a shallow copy so popping the label does not mutate the
        # caller's dict (a second call on the same dict would otherwise
        # raise KeyError).
        features = dict(features)
        labels = features.pop("prod_count")

        rating_predictions = self(features)

        # The task computes the loss and the metrics. `labels` is passed
        # positionally. Keras raises "The first argument to `Layer.call`
        # must always be passed" when a layer gets no positional argument
        # and call()'s first parameter name is not among the kwargs.
        return self.task(labels, predictions=rating_predictions)

submitted by /u/TimelyAbbreviations1
[visit reddit] [comments]

Leave a Reply

Your email address will not be published. Required fields are marked *