# Get device configs

import tensorflow as tf
from tensorflow.python.client import device_lib

def get_available_devices(cpu: bool = True, gpu: bool = True):
    """Return the names of local TensorFlow devices of the requested kinds.

    Args:
        cpu: Include devices whose ``device_type`` is ``'CPU'``.
        gpu: Include devices whose ``device_type`` is ``'GPU'``.

    Returns:
        A list of device ``name`` strings. All selected CPU entries come
        before all selected GPU entries. Within each kind, the order is the
        order reported by ``device_lib.list_local_devices()``. The list is
        empty when both flags are False.
    """
    all_devices = device_lib.list_local_devices()

    # Collect the requested kinds in a fixed order (CPU before GPU) so the
    # result is grouped by kind rather than interleaved.
    wanted_types = []
    if cpu:
        wanted_types.append('CPU')
    if gpu:
        wanted_types.append('GPU')

    return [
        proto.name
        for device_type in wanted_types
        for proto in all_devices
        if proto.device_type == device_type
    ]

# Check GPU availability: print the physical GPU devices TensorFlow can see.
# An empty list usually means no GPU was found, or TensorFlow cannot use the
# local CUDA setup.
print(tf.config.list_physical_devices(device_type='GPU'))



# Fetch one value per row: for each row i of x, select x[i, indices[i]].
x = tf.random.normal([3, 2])  # tf.random.normal already returns a tf.Tensor

indices = tf.convert_to_tensor([0, 1, 0])  # column to take from each row
# Row ids 0 .. len(indices) - 1. Named row_ids (not `range`) so the builtin
# range() is not shadowed for the rest of the module.
row_ids = tf.range(tf.shape(indices)[0])
# Pair each row id with its column index -> [[0, 0], [1, 1], [2, 0]],
# i.e. one (row, column) coordinate per output element for gather_nd.
ind = tf.stack([row_ids, indices], axis=1)
row_values = tf.gather_nd(x, ind)
# Equivalent without building `ind`: tf.gather(x, indices, batch_dims=1).
print(row_values)