You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

621 lines
16KB

  1. """Imported from the recipes section of the itertools documentation.
  2. All functions taken from the recipes section of the itertools library docs
  3. [1]_.
  4. Some backward-compatible usability improvements have been made.
  5. .. [1] http://docs.python.org/library/itertools.html#recipes
  6. """
  7. import warnings
  8. from collections import deque
  9. from itertools import (
  10. chain,
  11. combinations,
  12. count,
  13. cycle,
  14. groupby,
  15. islice,
  16. repeat,
  17. starmap,
  18. tee,
  19. zip_longest,
  20. )
  21. import operator
  22. from random import randrange, sample, choice
  23. __all__ = [
  24. 'all_equal',
  25. 'consume',
  26. 'convolve',
  27. 'dotproduct',
  28. 'first_true',
  29. 'flatten',
  30. 'grouper',
  31. 'iter_except',
  32. 'ncycles',
  33. 'nth',
  34. 'nth_combination',
  35. 'padnone',
  36. 'pad_none',
  37. 'pairwise',
  38. 'partition',
  39. 'powerset',
  40. 'prepend',
  41. 'quantify',
  42. 'random_combination_with_replacement',
  43. 'random_combination',
  44. 'random_permutation',
  45. 'random_product',
  46. 'repeatfunc',
  47. 'roundrobin',
  48. 'tabulate',
  49. 'tail',
  50. 'take',
  51. 'unique_everseen',
  52. 'unique_justseen',
  53. ]
  54. def take(n, iterable):
  55. """Return first *n* items of the iterable as a list.
  56. >>> take(3, range(10))
  57. [0, 1, 2]
  58. If there are fewer than *n* items in the iterable, all of them are
  59. returned.
  60. >>> take(10, range(3))
  61. [0, 1, 2]
  62. """
  63. return list(islice(iterable, n))
  64. def tabulate(function, start=0):
  65. """Return an iterator over the results of ``func(start)``,
  66. ``func(start + 1)``, ``func(start + 2)``...
  67. *func* should be a function that accepts one integer argument.
  68. If *start* is not specified it defaults to 0. It will be incremented each
  69. time the iterator is advanced.
  70. >>> square = lambda x: x ** 2
  71. >>> iterator = tabulate(square, -3)
  72. >>> take(4, iterator)
  73. [9, 4, 1, 0]
  74. """
  75. return map(function, count(start))
  76. def tail(n, iterable):
  77. """Return an iterator over the last *n* items of *iterable*.
  78. >>> t = tail(3, 'ABCDEFG')
  79. >>> list(t)
  80. ['E', 'F', 'G']
  81. """
  82. return iter(deque(iterable, maxlen=n))
  83. def consume(iterator, n=None):
  84. """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
  85. entirely.
  86. Efficiently exhausts an iterator without returning values. Defaults to
  87. consuming the whole iterator, but an optional second argument may be
  88. provided to limit consumption.
  89. >>> i = (x for x in range(10))
  90. >>> next(i)
  91. 0
  92. >>> consume(i, 3)
  93. >>> next(i)
  94. 4
  95. >>> consume(i)
  96. >>> next(i)
  97. Traceback (most recent call last):
  98. File "<stdin>", line 1, in <module>
  99. StopIteration
  100. If the iterator has fewer items remaining than the provided limit, the
  101. whole iterator will be consumed.
  102. >>> i = (x for x in range(3))
  103. >>> consume(i, 5)
  104. >>> next(i)
  105. Traceback (most recent call last):
  106. File "<stdin>", line 1, in <module>
  107. StopIteration
  108. """
  109. # Use functions that consume iterators at C speed.
  110. if n is None:
  111. # feed the entire iterator into a zero-length deque
  112. deque(iterator, maxlen=0)
  113. else:
  114. # advance to the empty slice starting at position n
  115. next(islice(iterator, n, n), None)
  116. def nth(iterable, n, default=None):
  117. """Returns the nth item or a default value.
  118. >>> l = range(10)
  119. >>> nth(l, 3)
  120. 3
  121. >>> nth(l, 20, "zebra")
  122. 'zebra'
  123. """
  124. return next(islice(iterable, n, None), default)
  125. def all_equal(iterable):
  126. """
  127. Returns ``True`` if all the elements are equal to each other.
  128. >>> all_equal('aaaa')
  129. True
  130. >>> all_equal('aaab')
  131. False
  132. """
  133. g = groupby(iterable)
  134. return next(g, True) and not next(g, False)
  135. def quantify(iterable, pred=bool):
  136. """Return the how many times the predicate is true.
  137. >>> quantify([True, False, True])
  138. 2
  139. """
  140. return sum(map(pred, iterable))
  141. def pad_none(iterable):
  142. """Returns the sequence of elements and then returns ``None`` indefinitely.
  143. >>> take(5, pad_none(range(3)))
  144. [0, 1, 2, None, None]
  145. Useful for emulating the behavior of the built-in :func:`map` function.
  146. See also :func:`padded`.
  147. """
  148. return chain(iterable, repeat(None))
  149. padnone = pad_none
  150. def ncycles(iterable, n):
  151. """Returns the sequence elements *n* times
  152. >>> list(ncycles(["a", "b"], 3))
  153. ['a', 'b', 'a', 'b', 'a', 'b']
  154. """
  155. return chain.from_iterable(repeat(tuple(iterable), n))
  156. def dotproduct(vec1, vec2):
  157. """Returns the dot product of the two iterables.
  158. >>> dotproduct([10, 10], [20, 20])
  159. 400
  160. """
  161. return sum(map(operator.mul, vec1, vec2))
  162. def flatten(listOfLists):
  163. """Return an iterator flattening one level of nesting in a list of lists.
  164. >>> list(flatten([[0, 1], [2, 3]]))
  165. [0, 1, 2, 3]
  166. See also :func:`collapse`, which can flatten multiple levels of nesting.
  167. """
  168. return chain.from_iterable(listOfLists)
  169. def repeatfunc(func, times=None, *args):
  170. """Call *func* with *args* repeatedly, returning an iterable over the
  171. results.
  172. If *times* is specified, the iterable will terminate after that many
  173. repetitions:
  174. >>> from operator import add
  175. >>> times = 4
  176. >>> args = 3, 5
  177. >>> list(repeatfunc(add, times, *args))
  178. [8, 8, 8, 8]
  179. If *times* is ``None`` the iterable will not terminate:
  180. >>> from random import randrange
  181. >>> times = None
  182. >>> args = 1, 11
  183. >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
  184. [2, 4, 8, 1, 8, 4]
  185. """
  186. if times is None:
  187. return starmap(func, repeat(args))
  188. return starmap(func, repeat(args, times))
  189. def _pairwise(iterable):
  190. """Returns an iterator of paired items, overlapping, from the original
  191. >>> take(4, pairwise(count()))
  192. [(0, 1), (1, 2), (2, 3), (3, 4)]
  193. On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
  194. """
  195. a, b = tee(iterable)
  196. next(b, None)
  197. yield from zip(a, b)
  198. try:
  199. from itertools import pairwise as itertools_pairwise
  200. except ImportError:
  201. pairwise = _pairwise
  202. else:
  203. def pairwise(iterable):
  204. yield from itertools_pairwise(iterable)
  205. pairwise.__doc__ = _pairwise.__doc__
  206. def grouper(iterable, n, fillvalue=None):
  207. """Collect data into fixed-length chunks or blocks.
  208. >>> list(grouper('ABCDEFG', 3, 'x'))
  209. [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
  210. """
  211. if isinstance(iterable, int):
  212. warnings.warn(
  213. "grouper expects iterable as first parameter", DeprecationWarning
  214. )
  215. n, iterable = iterable, n
  216. args = [iter(iterable)] * n
  217. return zip_longest(fillvalue=fillvalue, *args)
  218. def roundrobin(*iterables):
  219. """Yields an item from each iterable, alternating between them.
  220. >>> list(roundrobin('ABC', 'D', 'EF'))
  221. ['A', 'D', 'E', 'B', 'F', 'C']
  222. This function produces the same output as :func:`interleave_longest`, but
  223. may perform better for some inputs (in particular when the number of
  224. iterables is small).
  225. """
  226. # Recipe credited to George Sakkis
  227. pending = len(iterables)
  228. nexts = cycle(iter(it).__next__ for it in iterables)
  229. while pending:
  230. try:
  231. for next in nexts:
  232. yield next()
  233. except StopIteration:
  234. pending -= 1
  235. nexts = cycle(islice(nexts, pending))
  236. def partition(pred, iterable):
  237. """
  238. Returns a 2-tuple of iterables derived from the input iterable.
  239. The first yields the items that have ``pred(item) == False``.
  240. The second yields the items that have ``pred(item) == True``.
  241. >>> is_odd = lambda x: x % 2 != 0
  242. >>> iterable = range(10)
  243. >>> even_items, odd_items = partition(is_odd, iterable)
  244. >>> list(even_items), list(odd_items)
  245. ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
  246. If *pred* is None, :func:`bool` is used.
  247. >>> iterable = [0, 1, False, True, '', ' ']
  248. >>> false_items, true_items = partition(None, iterable)
  249. >>> list(false_items), list(true_items)
  250. ([0, False, ''], [1, True, ' '])
  251. """
  252. if pred is None:
  253. pred = bool
  254. evaluations = ((pred(x), x) for x in iterable)
  255. t1, t2 = tee(evaluations)
  256. return (
  257. (x for (cond, x) in t1 if not cond),
  258. (x for (cond, x) in t2 if cond),
  259. )
  260. def powerset(iterable):
  261. """Yields all possible subsets of the iterable.
  262. >>> list(powerset([1, 2, 3]))
  263. [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  264. :func:`powerset` will operate on iterables that aren't :class:`set`
  265. instances, so repeated elements in the input will produce repeated elements
  266. in the output. Use :func:`unique_everseen` on the input to avoid generating
  267. duplicates:
  268. >>> seq = [1, 1, 0]
  269. >>> list(powerset(seq))
  270. [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
  271. >>> from more_itertools import unique_everseen
  272. >>> list(powerset(unique_everseen(seq)))
  273. [(), (1,), (0,), (1, 0)]
  274. """
  275. s = list(iterable)
  276. return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
  277. def unique_everseen(iterable, key=None):
  278. """
  279. Yield unique elements, preserving order.
  280. >>> list(unique_everseen('AAAABBBCCDAABBB'))
  281. ['A', 'B', 'C', 'D']
  282. >>> list(unique_everseen('ABBCcAD', str.lower))
  283. ['A', 'B', 'C', 'D']
  284. Sequences with a mix of hashable and unhashable items can be used.
  285. The function will be slower (i.e., `O(n^2)`) for unhashable items.
  286. Remember that ``list`` objects are unhashable - you can use the *key*
  287. parameter to transform the list to a tuple (which is hashable) to
  288. avoid a slowdown.
  289. >>> iterable = ([1, 2], [2, 3], [1, 2])
  290. >>> list(unique_everseen(iterable)) # Slow
  291. [[1, 2], [2, 3]]
  292. >>> list(unique_everseen(iterable, key=tuple)) # Faster
  293. [[1, 2], [2, 3]]
  294. Similary, you may want to convert unhashable ``set`` objects with
  295. ``key=frozenset``. For ``dict`` objects,
  296. ``key=lambda x: frozenset(x.items())`` can be used.
  297. """
  298. seenset = set()
  299. seenset_add = seenset.add
  300. seenlist = []
  301. seenlist_add = seenlist.append
  302. use_key = key is not None
  303. for element in iterable:
  304. k = key(element) if use_key else element
  305. try:
  306. if k not in seenset:
  307. seenset_add(k)
  308. yield element
  309. except TypeError:
  310. if k not in seenlist:
  311. seenlist_add(k)
  312. yield element
  313. def unique_justseen(iterable, key=None):
  314. """Yields elements in order, ignoring serial duplicates
  315. >>> list(unique_justseen('AAAABBBCCDAABBB'))
  316. ['A', 'B', 'C', 'D', 'A', 'B']
  317. >>> list(unique_justseen('ABBCcAD', str.lower))
  318. ['A', 'B', 'C', 'A', 'D']
  319. """
  320. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  321. def iter_except(func, exception, first=None):
  322. """Yields results from a function repeatedly until an exception is raised.
  323. Converts a call-until-exception interface to an iterator interface.
  324. Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
  325. to end the loop.
  326. >>> l = [0, 1, 2]
  327. >>> list(iter_except(l.pop, IndexError))
  328. [2, 1, 0]
  329. """
  330. try:
  331. if first is not None:
  332. yield first()
  333. while 1:
  334. yield func()
  335. except exception:
  336. pass
  337. def first_true(iterable, default=None, pred=None):
  338. """
  339. Returns the first true value in the iterable.
  340. If no true value is found, returns *default*
  341. If *pred* is not None, returns the first item for which
  342. ``pred(item) == True`` .
  343. >>> first_true(range(10))
  344. 1
  345. >>> first_true(range(10), pred=lambda x: x > 5)
  346. 6
  347. >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
  348. 'missing'
  349. """
  350. return next(filter(pred, iterable), default)
  351. def random_product(*args, repeat=1):
  352. """Draw an item at random from each of the input iterables.
  353. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
  354. ('c', 3, 'Z')
  355. If *repeat* is provided as a keyword argument, that many items will be
  356. drawn from each iterable.
  357. >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
  358. ('a', 2, 'd', 3)
  359. This equivalent to taking a random selection from
  360. ``itertools.product(*args, **kwarg)``.
  361. """
  362. pools = [tuple(pool) for pool in args] * repeat
  363. return tuple(choice(pool) for pool in pools)
  364. def random_permutation(iterable, r=None):
  365. """Return a random *r* length permutation of the elements in *iterable*.
  366. If *r* is not specified or is ``None``, then *r* defaults to the length of
  367. *iterable*.
  368. >>> random_permutation(range(5)) # doctest:+SKIP
  369. (3, 4, 0, 1, 2)
  370. This equivalent to taking a random selection from
  371. ``itertools.permutations(iterable, r)``.
  372. """
  373. pool = tuple(iterable)
  374. r = len(pool) if r is None else r
  375. return tuple(sample(pool, r))
  376. def random_combination(iterable, r):
  377. """Return a random *r* length subsequence of the elements in *iterable*.
  378. >>> random_combination(range(5), 3) # doctest:+SKIP
  379. (2, 3, 4)
  380. This equivalent to taking a random selection from
  381. ``itertools.combinations(iterable, r)``.
  382. """
  383. pool = tuple(iterable)
  384. n = len(pool)
  385. indices = sorted(sample(range(n), r))
  386. return tuple(pool[i] for i in indices)
  387. def random_combination_with_replacement(iterable, r):
  388. """Return a random *r* length subsequence of elements in *iterable*,
  389. allowing individual elements to be repeated.
  390. >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
  391. (0, 0, 1, 2, 2)
  392. This equivalent to taking a random selection from
  393. ``itertools.combinations_with_replacement(iterable, r)``.
  394. """
  395. pool = tuple(iterable)
  396. n = len(pool)
  397. indices = sorted(randrange(n) for i in range(r))
  398. return tuple(pool[i] for i in indices)
  399. def nth_combination(iterable, r, index):
  400. """Equivalent to ``list(combinations(iterable, r))[index]``.
  401. The subsequences of *iterable* that are of length *r* can be ordered
  402. lexicographically. :func:`nth_combination` computes the subsequence at
  403. sort position *index* directly, without computing the previous
  404. subsequences.
  405. >>> nth_combination(range(5), 3, 5)
  406. (0, 3, 4)
  407. ``ValueError`` will be raised If *r* is negative or greater than the length
  408. of *iterable*.
  409. ``IndexError`` will be raised if the given *index* is invalid.
  410. """
  411. pool = tuple(iterable)
  412. n = len(pool)
  413. if (r < 0) or (r > n):
  414. raise ValueError
  415. c = 1
  416. k = min(r, n - r)
  417. for i in range(1, k + 1):
  418. c = c * (n - k + i) // i
  419. if index < 0:
  420. index += c
  421. if (index < 0) or (index >= c):
  422. raise IndexError
  423. result = []
  424. while r:
  425. c, n, r = c * r // n, n - 1, r - 1
  426. while index >= c:
  427. index -= c
  428. c, n = c * (n - r) // n, n - 1
  429. result.append(pool[-1 - n])
  430. return tuple(result)
  431. def prepend(value, iterator):
  432. """Yield *value*, followed by the elements in *iterator*.
  433. >>> value = '0'
  434. >>> iterator = ['1', '2', '3']
  435. >>> list(prepend(value, iterator))
  436. ['0', '1', '2', '3']
  437. To prepend multiple values, see :func:`itertools.chain`
  438. or :func:`value_chain`.
  439. """
  440. return chain([value], iterator)
  441. def convolve(signal, kernel):
  442. """Convolve the iterable *signal* with the iterable *kernel*.
  443. >>> signal = (1, 2, 3, 4, 5)
  444. >>> kernel = [3, 2, 1]
  445. >>> list(convolve(signal, kernel))
  446. [3, 8, 14, 20, 26, 14, 5]
  447. Note: the input arguments are not interchangeable, as the *kernel*
  448. is immediately consumed and stored.
  449. """
  450. kernel = tuple(kernel)[::-1]
  451. n = len(kernel)
  452. window = deque([0], maxlen=n) * n
  453. for x in chain(signal, repeat(0, n - 1)):
  454. window.append(x)
  455. yield sum(map(operator.mul, kernel, window))