You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

372 lines
11KB

  1. import os
  2. import sys
  3. import py
  4. import tempfile
  5. try:
  6. from io import StringIO
  7. except ImportError:
  8. from StringIO import StringIO
  if sys.version_info >= (3, 0):
      # io.StringIO already accepts (and requires) text on Python 3.
      TextIO = StringIO
  else:
      class TextIO(StringIO):
          """In-memory text buffer that coerces written byte strings to unicode.

          Non-unicode data is decoded with ``self._encoding`` when that
          attribute exists, otherwise UTF-8; undecodable bytes are replaced.
          """

          def write(self, data):
              if isinstance(data, unicode):
                  text = data
              else:
                  text = unicode(data, getattr(self, '_encoding', 'UTF-8'), 'replace')
              return StringIO.write(self, text)
  try:
      from io import BytesIO
  except ImportError:
      # Fallback for interpreters without io.BytesIO: a StringIO that
      # refuses unicode so that only byte strings can be written.
      class BytesIO(StringIO):
          def write(self, data):
              if not isinstance(data, unicode):
                  return StringIO.write(self, data)
              raise TypeError("not a byte value: %r" % (data,))
  # Standard file descriptor number -> name of the matching ``sys`` attribute.
  patchsysdict = dict(enumerate(('stdin', 'stdout', 'stderr')))
  class FDCapture:
      """ Capture IO to/from a given os-level filedescriptor. """

      def __init__(self, targetfd, tmpfile=None, now=True, patchsys=False):
          """ save targetfd descriptor and prepare a capture file for it.

              If no tmpfile is given and targetfd is not 0, an anonymous
              binary temporary file is created and wrapped by dupfile()
              with UTF-8 encoding.  For targetfd 0 without a tmpfile,
              start() connects fd 0 to the null device instead.

              If patchsys is true, the matching sys.stdin/stdout/stderr
              attribute is swapped too while capturing.  If now is true,
              capturing starts immediately.
          """
          self.targetfd = targetfd
          if patchsys:
              # Look this up first: an fd without a sys attribute raises
              # KeyError here, before any descriptor has been duplicated.
              self._oldsys = getattr(sys, patchsysdict[targetfd])
          if tmpfile is None and targetfd != 0:
              f = tempfile.TemporaryFile('wb+')
              tmpfile = dupfile(f, encoding="UTF-8")
              # dupfile() holds its own duplicate descriptor; drop ours.
              f.close()
          self.tmpfile = tmpfile
          # Keep a copy of the real target so done()/writeorg() can reach it.
          self._savefd = os.dup(self.targetfd)
          if now:
              self.start()

      def start(self):
          """ redirect targetfd (and optionally the sys attribute) to the capture.

              Raises ValueError if the saved descriptor is no longer valid,
              e.g. because done() already closed it.
          """
          try:
              os.fstat(self._savefd)
          except OSError:
              raise ValueError("saved filedescriptor not valid, "
                               "did you call start() twice?")
          if self.targetfd == 0 and not self.tmpfile:
              # Nothing to feed stdin from: read from the null device.
              fd = os.open(devnullpath, os.O_RDONLY)
              os.dup2(fd, 0)
              os.close(fd)
              if hasattr(self, '_oldsys'):
                  setattr(sys, patchsysdict[self.targetfd], DontReadFromInput())
          else:
              os.dup2(self.tmpfile.fileno(), self.targetfd)
              if hasattr(self, '_oldsys'):
                  setattr(sys, patchsysdict[self.targetfd], self.tmpfile)

      def done(self):
          """ unpatch and clean up, returns the self.tmpfile (file object)

              The tmpfile is rewound to the start unless targetfd is 0.
              The saved descriptor is closed, so start() cannot be called
              again on this instance afterwards.
          """
          os.dup2(self._savefd, self.targetfd)
          os.close(self._savefd)
          if self.targetfd != 0:
              self.tmpfile.seek(0)
          if hasattr(self, '_oldsys'):
              setattr(sys, patchsysdict[self.targetfd], self._oldsys)
          return self.tmpfile

      def writeorg(self, data):
          """ write data to the original (saved) file descriptor.

              Text is encoded as UTF-8 first, since os.write() only accepts
              bytes-like objects.  Raises OSError if the saved descriptor
              has already been closed by done().
          """
          if hasattr(data, 'encode') and not isinstance(data, bytes):
              data = data.encode("UTF-8")
          # os.write() may write only part of the buffer; loop until done.
          while data:
              written = os.write(self._savefd, data)
              data = data[written:]
  def dupfile(f, mode=None, buffering=0, raising=False, encoding=None):
      """ return a new open file object that's a duplicate of f

          mode is duplicated if not given, 'buffering' controls
          buffer size (defaulting to no buffering) and 'raising'
          defines whether an exception is raised when an incompatible
          file object is passed in (if raising is False, the file
          object itself will be returned).

          If encoding is given, the result writes text in that encoding:
          on Python 3 the 'b' flag is dropped from the mode and line
          buffering is used; on Python 2 the file is wrapped in EncodedFile.

          On Python 3 the io module rejects unbuffered text I/O, so a
          text-mode duplicate requested with buffering=0 uses line
          buffering instead of raising ValueError.
      """
      try:
          fd = f.fileno()
          mode = mode or f.mode
      except AttributeError:
          if raising:
              raise
          return f
      newfd = os.dup(fd)
      if sys.version_info >= (3, 0):
          if encoding is not None:
              mode = mode.replace("b", "")
              buffering = 1  # line buffering
          elif buffering == 0 and "b" not in mode:
              # "can't have unbuffered text I/O" -- closest valid choice.
              buffering = 1
          return os.fdopen(newfd, mode, buffering, encoding, closefd=True)
      else:
          f = os.fdopen(newfd, mode, buffering)
          if encoding is not None:
              return EncodedFile(f, encoding)
          return f
  class EncodedFile(object):
      """Write-side adapter that encodes text before handing it to a stream.

      Unicode objects are encoded with ``encoding``; ``str`` objects pass
      through unchanged; anything else is converted with ``str()``.  Relies
      on the Python 2 ``unicode`` builtin.  Any other attribute access is
      forwarded to the wrapped stream.
      """

      def __init__(self, _stream, encoding):
          self.encoding = encoding
          self._stream = _stream

      def write(self, obj):
          if isinstance(obj, unicode):
              obj = obj.encode(self.encoding)
          elif not isinstance(obj, str):
              obj = str(obj)
          self._stream.write(obj)

      def writelines(self, linelist):
          self.write(''.join(linelist))

      def __getattr__(self, name):
          # Only reached for names not found on the adapter itself.
          return getattr(self._stream, name)
  class Capture(object):
      """ Common base for the capturing classes.

          Subclasses provide done(save=...) returning an (outfile, errfile)
          pair (either may be None) and readouterr() returning the current
          (out, err) snapshot.
      """

      @classmethod
      def call(cls, func, *args, **kwargs):
          """ return a (res, out, err) tuple where
              out and err represent the output/error output
              during function execution.
              call the given function with args/kwargs
              and capture output/error during its execution.
          """
          so = cls()
          try:
              res = func(*args, **kwargs)
          finally:
              out, err = so.reset()
          return res, out, err

      def reset(self):
          """ reset sys.stdout/stderr and return captured output as strings.

              The capture files are closed afterwards.  Raises ValueError
              if called a second time.
          """
          if hasattr(self, '_reset'):
              raise ValueError("was already reset")
          self._reset = True
          outfile, errfile = self.done(save=False)
          out, err = "", ""
          if outfile and not outfile.closed:
              out = outfile.read()
              outfile.close()
          # When out and err share one file (mixed mode) read it only once.
          if errfile and errfile is not outfile and not errfile.closed:
              err = errfile.read()
              errfile.close()
          return out, err

      def suspend(self):
          """ return current snapshot captures, memorize tempfiles. """
          outerr = self.readouterr()
          self.done()
          return outerr
  class StdCaptureFD(Capture):
      """ This class allows to capture writes to FD1 and FD2
          and may connect a NULL file to FD0 (and prevent
          reads from sys.stdin).  If any of the 0,1,2 file descriptors
          is invalid it will not be captured.
      """
      def __init__(self, out=True, err=True, mixed=False,
                   in_=True, patchsys=True, now=True):
          # The options are kept so that _save() can re-create the FDCapture
          # objects after done(save=True); _save() overwrites 'out'/'err'
          # with the capture files so resume() keeps writing to them.
          self._options = {
              "out": out,
              "err": err,
              "mixed": mixed,
              "in_": in_,
              "patchsys": patchsys,
              "now": now,
          }
          self._save()
          if now:
              self.startall()

      def _save(self):
          """ create (but do not start) FDCapture objects per self._options.

              A descriptor whose FDCapture raises OSError is skipped, so the
              corresponding attribute (in_/out/err) is simply not set.
          """
          in_ = self._options['in_']
          out = self._options['out']
          err = self._options['err']
          mixed = self._options['mixed']
          patchsys = self._options['patchsys']
          if in_:
              try:
                  self.in_ = FDCapture(0, tmpfile=None, now=False,
                                       patchsys=patchsys)
              except OSError:
                  pass
          if out:
              tmpfile = None
              if hasattr(out, 'write'):
                  tmpfile = out
              try:
                  self.out = FDCapture(1, tmpfile=tmpfile,
                                       now=False, patchsys=patchsys)
                  self._options['out'] = self.out.tmpfile
              except OSError:
                  pass
          if err:
              # Share stdout's file only if stdout capture actually exists:
              # when fd 1 is invalid, self.out was never created above.
              if out and mixed and hasattr(self, 'out'):
                  tmpfile = self.out.tmpfile
              elif hasattr(err, 'write'):
                  tmpfile = err
              else:
                  tmpfile = None
              try:
                  self.err = FDCapture(2, tmpfile=tmpfile,
                                       now=False, patchsys=patchsys)
                  self._options['err'] = self.err.tmpfile
              except OSError:
                  pass

      def startall(self):
          """ start every FDCapture that was successfully created. """
          if hasattr(self, 'in_'):
              self.in_.start()
          if hasattr(self, 'out'):
              self.out.start()
          if hasattr(self, 'err'):
              self.err.start()

      def resume(self):
          """ resume capturing with original temp files. """
          self.startall()

      def done(self, save=True):
          """ return (outfile, errfile) and stop capturing.

              With save=True, fresh FDCapture objects are prepared (reusing
              the same capture files) so that resume() can restart.
          """
          outfile = errfile = None
          if hasattr(self, 'out') and not self.out.tmpfile.closed:
              outfile = self.out.done()
          if hasattr(self, 'err') and not self.err.tmpfile.closed:
              errfile = self.err.done()
          if hasattr(self, 'in_'):
              self.in_.done()
          if save:
              self._save()
          return outfile, errfile

      def readouterr(self):
          """ return snapshot value of stdout/stderr capturings. """
          if hasattr(self, "out"):
              out = self._readsnapshot(self.out.tmpfile)
          else:
              out = ""
          if hasattr(self, "err"):
              err = self._readsnapshot(self.err.tmpfile)
          else:
              err = ""
          return out, err

      def _readsnapshot(self, f):
          """ read all of f, then empty it so the next snapshot starts fresh.

              If f has a truthy 'encoding' attribute, the data is passed
              through py.builtin._totext with errors="replace".
          """
          f.seek(0)
          res = f.read()
          enc = getattr(f, "encoding", None)
          if enc:
              res = py.builtin._totext(res, enc, "replace")
          f.truncate(0)
          f.seek(0)
          return res
  class StdCapture(Capture):
      """ This class allows to capture writes to sys.stdout|stderr "in-memory"
          and will raise errors on tries to read from sys.stdin.  It only
          modifies sys.stdout|stderr|stdin attributes and does not
          touch underlying File Descriptors (use StdCaptureFD for that).
      """
      def __init__(self, out=True, err=True, in_=True, mixed=False, now=True):
          self._oldout = sys.stdout
          self._olderr = sys.stderr
          self._oldin = sys.stdin
          # NOTE(review): this checks for a 'file' attribute while the err
          # branch below checks for 'write'; objects lacking 'file' are
          # replaced by a fresh TextIO.  Kept unchanged -- confirm intent.
          if out and not hasattr(out, 'file'):
              out = TextIO()
          self.out = out
          if err:
              if mixed:
                  err = out
              elif not hasattr(err, 'write'):
                  err = TextIO()
          self.err = err
          self.in_ = in_
          if now:
              self.startall()

      def startall(self):
          """ install the capture objects on sys.stdout/stderr/stdin. """
          if self.out:
              sys.stdout = self.out
          if self.err:
              sys.stderr = self.err
          if self.in_:
              sys.stdin = self.in_ = DontReadFromInput()

      def done(self, save=True):
          """ return (outfile, errfile) and stop capturing.

              The original sys.stdout/stderr are always restored, even if
              the captured code closed the capture file; a closed capture
              file is reported as None.  'save' is accepted for interface
              compatibility with StdCaptureFD and is unused here.
          """
          outfile = errfile = None
          if self.out:
              sys.stdout = self._oldout
              if not self.out.closed:
                  outfile = self.out
                  outfile.seek(0)
          if self.err:
              sys.stderr = self._olderr
              if not self.err.closed:
                  errfile = self.err
                  errfile.seek(0)
          if self.in_:
              sys.stdin = self._oldin
          return outfile, errfile

      def resume(self):
          """ resume capturing with original temp files. """
          self.startall()

      def readouterr(self):
          """ return snapshot value of stdout/stderr capturings. """
          out = err = ""
          if self.out:
              out = self.out.getvalue()
              self.out.truncate(0)
              self.out.seek(0)
          if self.err:
              err = self.err.getvalue()
              self.err.truncate(0)
              self.err.seek(0)
          return out, err
  class DontReadFromInput:
      """Stand-in for sys.stdin while output is being captured.

      Every read attempt (read, readline, readlines, iteration) raises
      IOError instead of blocking.  Ideally, touching stdin would suspend
      capturing and send what was captured so far to the screen; that
      should be configurable, since in automated test runs crashing is
      better than hanging indefinitely.
      """

      def read(self, *args):
          raise IOError("reading from stdin while output is captured")

      readline = readlines = __iter__ = read

      def fileno(self):
          raise ValueError("redirected Stdin is pseudofile, has no fileno()")

      def isatty(self):
          return False

      def close(self):
          # Closing the stand-in is deliberately a no-op.
          pass
  # Path of the platform's null device; prefer os.devnull when available.
  if hasattr(os, 'devnull'):
      devnullpath = os.devnull
  elif os.name == 'nt':
      devnullpath = 'NUL'
  else:
      devnullpath = '/dev/null'