Passing a comparison operator in Python
October 28, 2013

So, earlier today I was refactoring some code to be more general purpose, and said code involved a comparison operator. I was already using the same code in two places, but copy-pasted with different comparison operators switched in. Eww.
I didn't quite know how to pull this off, given that I'm not that well versed in Python's introspection faculties... but I gave it a whack. My first thought was the naive approach of passing in a string to specify the mode...
```python
## approach 1
def myfun(a, b, mode='<'):
    valid_modes = ['<', '<=', '>', '>=', '==', '!=']
    if mode not in valid_modes:
        print('Error!')
        ...  # error handling code here
        return
    if mode == '<':
        if a < b:
            ...
    elif mode == '<=':
        if a <= b:
            ...
    elif mode == '>':
        if a > b:
            ...
    elif mode == '>=':
        if a >= b:
            ...
    elif mode == '==':
        if a == b:
            ...
    elif mode == '!=':
        if a != b:
            ...
    else:
        print('Something went horribly wrong. Blame cosmic rays.')
```
Yuck.
Then I realized that functions are first-class objects in Python, so they can be passed around like any other value. Additionally, lambdas are short and sweet. A few minutes later, I was rolling with this instead:
```python
## approach 2
# comparators
cmp_lt = lambda x, y: x < y
cmp_gt = lambda x, y: x > y
cmp_eq = lambda x, y: x == y
cmp_lte = lambda x, y: x <= y
cmp_gte = lambda x, y: x >= y
cmp_neq = lambda x, y: x != y

def myfun(a, b, comparator):
    if comparator(a, b):
        ...
```
So much nicer. Instead of passing a mode string, now you just pass one of those comparators, à la:

```python
myfun(4, 7, comparator=cmp_lt)
```
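The `...` in `myfun` stands in for whatever the real code does. As a hypothetical illustration of the kind of generalization this buys you, here is a sketch where the comparator decides which items of a list get counted; `count_where`, `data`, and the threshold are made up for this example and are not from the original code:

```python
# Comparator lambdas, as in approach 2 above
cmp_lt = lambda x, y: x < y
cmp_gte = lambda x, y: x >= y


def count_where(values, threshold, comparator):
    """Count items in values for which comparator(item, threshold) is true.

    Hypothetical example function; the post's own myfun body is elided.
    """
    return sum(1 for v in values if comparator(v, threshold))


data = [3, 8, 1, 9, 4]
print(count_where(data, 5, cmp_lt))   # items below 5: 3, 1, 4 -> 3
print(count_where(data, 5, cmp_gte))  # items at or above 5: 8, 9 -> 2
```

One function body, two behaviors, and no copy-pasting: the comparison is the only thing that varies, so it's the only thing the caller has to supply.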
Update: A friend pointed out two other ways of accomplishing this: one more Pythonic, and one incredibly unsafe.
The Pythonic solution relies on the `operator` module:
```python
## Python is batteries included!
import operator

def myfun(a, b, comparator):
    if comparator(a, b):
        ...
```
That's much like approach 2, but the call is now:

```python
myfun(4, 7, comparator=operator.lt)
```
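If the mode really does have to arrive as a string (read from a config file or the command line, say), one common pattern keeps approach 1's interface but swaps the `if`/`elif` chain for a dictionary lookup into the `operator` module. This is a minimal sketch; `COMPARATORS` and `compare` are hypothetical names, not part of the standard library:

```python
import operator

# Hypothetical lookup table mapping approach 1's mode strings
# onto the operator module's comparison functions.
COMPARATORS = {
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '==': operator.eq,
    '!=': operator.ne,
}


def compare(a, b, mode='<'):
    """Apply the comparison named by mode to a and b."""
    if mode not in COMPARATORS:
        # Fail loudly instead of printing and carrying on
        raise ValueError('unknown comparison mode: %r' % (mode,))
    return COMPARATORS[mode](a, b)


print(compare(4, 7))               # True  (4 < 7)
print(compare(4, 7, '>='))         # False (4 >= 7)
print(compare('apple', 'banana'))  # True  (works on strings too)
```

Unknown modes raise a `ValueError` rather than falling through, and supporting a new comparison is one more dictionary entry.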
The unsafe way is to build the comparison as a string and evaluate it! Here we go:
```python
## don't do this, dummy
def myfun(a, b, mode='<'):
    if eval(str(a) + mode + str(b)):
        ...
```
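To make "incredibly unsafe" concrete, here is a hypothetical `myfun_eval` that returns the `eval` result instead of running an elided body. It shows two problems: values whose `str()` isn't valid Python source (like plain strings) break it, and whatever is passed as `mode` gets executed as code:

```python
def myfun_eval(a, b, mode='<'):
    # Hypothetical variant of the eval approach above that returns the
    # result directly. Do not use this in real code.
    return eval(str(a) + mode + str(b))


print(myfun_eval(4, 7))  # True: evaluates the string "4<7"

# str() drops the quotes from strings, so this evaluates "apple<banana",
# which looks up two undefined names and raises NameError.
try:
    myfun_eval('apple', 'banana')
except NameError as exc:
    print('NameError:', exc)

# The mode string is executed as code, so a caller can slip in any
# expression. This evaluates "4 < 0 or print('arbitrary code ran') or 7",
# which prints the message and then returns 7.
print(myfun_eval(4, 7, " < 0 or print('arbitrary code ran') or "))
```

Swap that `print` for something nastier and you can see why the `operator` version is the one to reach for.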
Man, I love Python. We're not starved for options!
Hope this helps someone else!