Passing a comparison operator in Python
October 28, 2013

So, earlier today I was refactoring some code to be more general purpose, and said code involved a comparison operator. I was already using the same code in two places, but copy-pasted with different comparison operators switched in. Eww.
I didn't quite know how to pull this off, given that I'm not that well versed in Python's introspection faculties... but I gave it a whack. My first thought was the naive approach of passing in a string to specify the mode:
```python
## approach 1
def myfun(a, b, mode='<'):
    valid_modes = ['<', '<=', '>', '>=', '==', '!=']
    if mode not in valid_modes:
        print('Error!')
        ...  # error handling code here
        return
    if mode == '<':
        if a < b:
            ...
    elif mode == '<=':
        if a <= b:
            ...
    elif mode == '>':
        if a > b:
            ...
    elif mode == '>=':
        if a >= b:
            ...
    elif mode == '==':
        if a == b:
            ...
    elif mode == '!=':
        if a != b:
            ...
    else:
        print('Something went horribly wrong. Blame cosmic rays.')
```
Yuck.
Then I realized that functions are first-class objects in Python, so they can be passed around like any other value. Additionally, lambdas are short and sweet. A few minutes later, I was rolling with this instead:
```python
## approach 2
# comparators
cmp_lt = lambda x, y: x < y
cmp_gt = lambda x, y: x > y
cmp_eq = lambda x, y: x == y
cmp_lte = lambda x, y: x <= y
cmp_gte = lambda x, y: x >= y
cmp_neq = lambda x, y: x != y

def myfun(a, b, comparator):
    if comparator(a, b):
        ...
```
So much nicer. Instead of passing a mode string, you now just pass one of those comparators, à la:
myfun(4,7,comparator=cmp_lt)
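Since myfun's body is elided, here's a small self-contained sketch of the same pattern with a concrete body. The helper count_where is hypothetical, not part of the original code; it simply counts how many values pass whichever comparator you hand it.

```python
# Comparators from approach 2 (only the two used here).
cmp_lt = lambda x, y: x < y
cmp_gte = lambda x, y: x >= y

def count_where(values, threshold, comparator):
    """Hypothetical stand-in for myfun: count the values v for which
    comparator(v, threshold) is true."""
    return sum(1 for v in values if comparator(v, threshold))

data = [3, 8, 1, 7, 5]
print(count_where(data, 5, comparator=cmp_lt))   # 2  (3 and 1)
print(count_where(data, 5, comparator=cmp_gte))  # 3  (8, 7 and 5)
```

Because the comparison is just a parameter, supporting a new kind of test means writing one more two-argument function rather than adding another elif branch.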
Update: A friend pointed out two other ways of accomplishing this: one more Pythonic, and one incredibly unsafe.
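The Pythonic solution relies on the 'operator' module, which provides each comparison operator as a ready-made function (operator.lt, operator.le, operator.eq, operator.ne, operator.ge, operator.gt):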
```python
## python is batteries included!
import operator

def myfun(a, b, comparator):
    if comparator(a, b):
        ...
```
That's much like approach 2, but the call is now:
myfun(4,7,comparator=operator.lt)
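The operator functions are drop-in replacements for the lambdas from approach 2, though the names differ slightly (le, ge and ne rather than lte, gte and neq). Here's a quick sketch that pairs each lambda with its operator counterpart and spot-checks that they agree on a small grid of integers; the EQUIVALENTS table is just for illustration.

```python
import itertools
import operator

# Each approach-2 lambda paired with the operator-module function
# that computes the same comparison.
EQUIVALENTS = {
    "cmp_lt": (lambda x, y: x < y, operator.lt),
    "cmp_gt": (lambda x, y: x > y, operator.gt),
    "cmp_eq": (lambda x, y: x == y, operator.eq),
    "cmp_lte": (lambda x, y: x <= y, operator.le),
    "cmp_gte": (lambda x, y: x >= y, operator.ge),
    "cmp_neq": (lambda x, y: x != y, operator.ne),
}

# Check every pair agrees on all (a, b) with a, b in -2..2.
for a, b in itertools.product(range(-2, 3), repeat=2):
    for name, (lam, op) in EQUIVALENTS.items():
        assert lam(a, b) == op(a, b), (name, a, b)

print(operator.lt(4, 7))  # True
```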
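The unsafe way is to build the comparison as a string and eval it! eval runs whatever Python code the string contains, so a crafted mode string can execute arbitrary code. It also breaks for values whose str() form isn't a valid Python expression (most strings, for example). Here we go: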
```python
## don't do this, dummy
def myfun(a, b, mode='<'):
    if eval(str(a) + mode + str(b)):
        ...
```
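If you still want the string-based mode parameter from approach 1, a safer option than eval is to look the mode up in a dictionary that maps each allowed string to its operator function. This is a minimal sketch; the compare function and the _COMPARATORS table are hypothetical names, not from the original code.

```python
import operator

# Whitelist of accepted mode strings, each mapped to an operator function.
_COMPARATORS = {
    "<": operator.lt,
    "<=": operator.le,
    ">": operator.gt,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,
}

def compare(a, b, mode="<"):
    """Hypothetical helper: compare a and b using a mode string, without eval."""
    try:
        comparator = _COMPARATORS[mode]
    except KeyError:
        raise ValueError(f"unsupported comparison mode: {mode!r}") from None
    return comparator(a, b)

print(compare(4, 7, "<"))               # True
print(compare("apple", "banana", "<"))  # True
```

Unknown modes are rejected up front, and since a and b are never turned into source code, string arguments work fine. The eval version would instead evaluate apple<banana and fail with a NameError.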
Man, I love Python. We're not starved for options!
Hope this helps someone else!