Select an item from a list of objects of any type when using TensorFlow 2.x

Submitted 3 years, 7 months ago
Ticket #134
Views 210
Language/Framework Python
Priority Medium
Status Closed

Given a list of instances of class A, e.g. `[A(i) for i in range(5)]`, I want to randomly select one of them and call it (see the following code for an example):

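import random

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

def f():
    a_list = [A(i) for i in range(5)]
    # randrange(5) yields 0..4; randint(0, 5) can also return 5, which is out of range
    a = a_list[random.randrange(len(a_list))]()
    return a

f()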

Is there a way to decorate f with `@tf.function` without changing what f does and without calling all items in `a_list`?

Note that directly decorating f with `@tf.function`, without any other changes to the above code, does not work: the Python `random` call runs only once, when the function is traced, so f will always return the same result. Also, I know that this can be achieved by calling all elements in `a_list` first and then indexing the results with `tf.gather_nd`. But this incurs a large overhead if calling an object of type A involves a deep neural network.
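The "always returns the same result" behaviour can be reproduced with a small sketch. The function names `python_random_index` and `tf_random_index` are made up for illustration; the only assumption is that TensorFlow 2.x is installed.

```python
import random
import tensorflow as tf

@tf.function
def python_random_index():
    # random.randrange runs once, while the function is being traced;
    # the resulting integer is baked into the graph as a constant.
    return tf.constant(random.randrange(1_000_000))

@tf.function
def tf_random_index():
    # tf.random.uniform is a graph op, so it is re-evaluated on every call.
    return tf.random.uniform(shape=[], maxval=1_000_000, dtype=tf.int32)

# Collect the distinct values seen over 20 calls of each function.
frozen = {int(python_random_index()) for _ in range(20)}
fresh = {int(tf_random_index()) for _ in range(20)}

print("distinct values from Python random:", len(frozen))
print("distinct values from tf.random:", len(fresh))
```

The first set should contain a single value, while the second should contain many, which is why the index has to be drawn with a TensorFlow op rather than with Python's `random` module.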

Submitted on Sep 10, 20

1 Answer

Verified

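import tensorflow as tf

@tf.function
def f2():
    a_list = [A(i) for i in range(5)]
    # Draw an integer index in [0, 5) inside the graph so it is resampled on every call
    idx = tf.random.uniform(shape=[], maxval=len(a_list), dtype=tf.int32)
    # switch_case runs only the branch selected by idx
    return tf.switch_case(idx, a_list)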
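One way to check that `tf.switch_case` executes only the selected branch is to give each branch a visible side effect. The sketch below is not part of the original answer: `calls`, `make_branch` and `pick_one` are hypothetical names, and each branch increments its own slot of a counter variable before returning its index.

```python
import tensorflow as tf

NUM_BRANCHES = 5

# Hypothetical counter: one slot per branch, incremented when that branch runs.
calls = tf.Variable(tf.zeros([NUM_BRANCHES], dtype=tf.int32))

def make_branch(i):
    def branch():
        # Side effect that records that branch i was actually executed.
        calls.assign_add(tf.one_hot(i, NUM_BRANCHES, dtype=tf.int32))
        return tf.constant(i)
    return branch

branches = [make_branch(i) for i in range(NUM_BRANCHES)]

@tf.function
def pick_one():
    # Integer index drawn uniformly from [0, NUM_BRANCHES) on every call.
    idx = tf.random.uniform(shape=[], maxval=NUM_BRANCHES, dtype=tf.int32)
    return tf.switch_case(idx, branches)

picked = [int(pick_one()) for _ in range(50)]
counts = calls.numpy().tolist()

print("picked indices:", picked)
print("executions per branch:", counts)
```

If only the chosen branch runs, the counter total equals the number of calls, and each slot matches how often that index was returned.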

For a speed comparison, I made the `__call__` method of A perform expensive matrix algebra. Then consider an alternative function that invokes every element:

@tf.function
def f3():
    a_list = [A(i) for i in range(40)]
    results = [a() for a in a_list]
    return results

Running f2 with 40 elements: 0.42643 seconds

Running f3 with 40 elements: 14.9153 seconds

So that is roughly a 35x speedup, close to the 40x expected from executing only one of the 40 branches.
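The answer does not show the expensive version of A or the timing harness, so the following is only a sketch of how such a comparison might be set up. `ExpensiveA`, `one_branch`, `all_branches` and `time_calls` are hypothetical names, and the matrix size, step count and number of elements are arbitrary small values chosen so the script finishes quickly. The absolute timings will differ from the figures quoted above.

```python
import time
import tensorflow as tf

class ExpensiveA:
    """Hypothetical stand-in for A whose call does repeated matrix products."""

    def __init__(self, a, size=64, steps=10):
        self.a = float(a)
        self.size = size
        self.steps = steps

    def __call__(self):
        x = tf.random.normal([self.size, self.size])
        for _ in range(self.steps):
            # tanh keeps the values bounded across repeated products.
            x = tf.tanh(tf.matmul(x, x))
        return tf.reduce_mean(x) + self.a

N = 8
a_list = [ExpensiveA(i) for i in range(N)]

@tf.function
def one_branch():
    # Executes only the branch selected by idx.
    idx = tf.random.uniform(shape=[], maxval=N, dtype=tf.int32)
    return tf.switch_case(idx, a_list)

@tf.function
def all_branches():
    # Executes every element of a_list.
    return [a() for a in a_list]

def time_calls(fn, repeats=5):
    fn()  # warm-up call so that tracing is not included in the timing
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

t_one = time_calls(one_branch)
t_all = time_calls(all_branches)
print(f"one branch: {t_one:.5f} s, all branches: {t_all:.5f} s")
```

With a sufficiently expensive `__call__`, the ratio `t_all / t_one` should approach the number of elements; with work this small, fixed per-call overhead is likely to dominate.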

Submitted 3 years, 6 months ago

