
Set TensorFlow2 tensor to zero based on condition

I have a tensor in TensorFlow 2.5, and the goal is to set to zero the p% of ‘slices’ that have the smallest sum of magnitudes. The sample code to create the example tensor is as follows:

import tensorflow as tf

# Create example tensor
input_shape = (1, 4, 4, 6)
y = tf.random.normal(input_shape)

The goal is to find and zero out the (say) p = 20% of the 4×4 slices along the last axis (size 6) that have the smallest sum of magnitudes. Since floor(20% of 6) = 1, after the operation 5 of the 4×4 ‘slices’ should keep their non-zero values while the 6th slice should be all zeros.

import math

# Sum of magnitudes for each 4x4 slice
filter_sum = tf.math.reduce_sum(input_tensor=tf.math.abs(y), axis=[1, 2], keepdims=False)
filter_sum.shape
# TensorShape([1, 6])

indices = tf.argsort(filter_sum)
indices.shape
# TensorShape([1, 6])
indices.numpy()
# array([[0, 1, 4, 2, 5, 3]])

# 20% of 6 values = 1
math.floor(0.2 * 6)
# 1
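A possible next step, sketched under assumptions: the first k = floor(p × 6) columns of the ascending ‘indices’ are the slices to drop, and tf.one_hot can turn them into a 0/1 mask over the last axis. The ‘filter_sum’ values below are made up (chosen so that tf.argsort reproduces the order printed above), and the names k, to_remove, remove_mask and keep_mask are hypothetical.

```python
import math
import tensorflow as tf

# Made-up stand-in for filter_sum (shape (1, 6)); its argsort matches the question's output.
filter_sum = tf.constant([[1.0, 2.0, 4.0, 6.0, 3.0, 5.0]])

indices = tf.argsort(filter_sum)                 # ascending order: [[0, 1, 4, 2, 5, 3]]
num_slices = filter_sum.shape[-1]                # 6
k = math.floor(0.2 * num_slices)                 # 1 slice to remove

# The first k columns of the ascending argsort are the smallest-sum slices.
to_remove = indices[:, :k]                       # shape (1, k)

# one_hot gives shape (1, k, 6); summing over the k axis gives a (1, 6) mask
# with 1.0 at every slice that should be zeroed.
remove_mask = tf.reduce_sum(tf.one_hot(to_remove, depth=num_slices), axis=1)
keep_mask = 1.0 - remove_mask                    # 0.0 for removed slices, 1.0 otherwise
print(keep_mask.numpy())
```

Multiplying ‘y’ by keep_mask reshaped to (1, 1, 1, 6) would then zero the selected slice through broadcasting over the two 4×4 spatial axes.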

In this example, the 4×4 ‘slice’ in ‘y’ with the smallest sum of magnitudes is the first one (index 0), as given by ‘indices’. How do I proceed?
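An end-to-end sketch that avoids building one-hot vectors: applying tf.argsort twice gives each slice's rank by magnitude sum (0 = smallest), and slices with rank below k are replaced with zeros via tf.where. The function name zero_smallest_slices is hypothetical, and it assumes a channels-last tensor of shape (N, H, W, C) whose slices are ranked separately for each batch item.

```python
import math
import tensorflow as tf

def zero_smallest_slices(y, p):
    """Zero the floor(p * C) slices along the last axis of y (shape (N, H, W, C))
    with the smallest sum of absolute values, separately for each batch item."""
    num_slices = y.shape[-1]
    k = math.floor(p * num_slices)
    # Sum of magnitudes per slice, shape (N, C), as in the question.
    filter_sum = tf.math.reduce_sum(tf.math.abs(y), axis=[1, 2])
    # argsort of argsort gives each slice's rank in ascending order of its sum.
    ranks = tf.argsort(tf.argsort(filter_sum, axis=-1), axis=-1)
    # Keep every slice whose rank is k or higher (bool mask, shape (N, C)).
    keep = ranks >= k
    # Add singleton H and W axes so the mask broadcasts over each 4x4 slice.
    keep = keep[:, tf.newaxis, tf.newaxis, :]
    return tf.where(keep, y, tf.zeros_like(y))

y = tf.random.normal((1, 4, 4, 6))
y_pruned = zero_smallest_slices(y, p=0.2)   # one of the six slices becomes all zeros
```

Ties between equal sums are broken arbitrarily by tf.argsort, but exactly k slices are still zeroed because the ranks always form a permutation of 0..C-1.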

Thanks

submitted by /u/grid_world
