#11 Fix several issues around quoting, background processes, and encoding
Closed 6 years ago by pviktori. Opened 6 years ago by pviktori.
pviktori/python-pytest-multihost encodings-background  into  master

file modified
@@ -172,6 +172,22 @@ 

  To use YAML files, the PyYAML package is required. Without it only JSON files

  can be used.



+ Encoding and bytes/text

+ -----------------------


+ When writing files or issuing commands, bytestrings are passed through

+ unchanged, and text strings (``unicode`` in Python 2) are encoded using

+ a configurable encoding (``utf-8`` by default).


+ When reading files, bytestrings are returned by default,

+ but an encoding can be given to get a text string.


+ For command output, separate ``stdout_bytes`` and ``stdout_text`` attributes

+ are provided.

+ The latter uses a configurable encoding (``utf-8`` by default).






file modified
+46 -19
@@ -25,7 +25,7 @@ 

      See README for an overview of the core classes.


      transport_class = transport.SSHTransport

-     command_prelude = ''

+     command_prelude = b''


      def __init__(self, domain, hostname, role, ip=None,

                   external_hostname=None, username=None, password=None,
@@ -190,9 +190,9 @@ 

          """Shortcut for transport.get_file_contents"""

          return self.transport.get_file_contents(filename, encoding=encoding)


-     def put_file_contents(self, filename, contents):

+     def put_file_contents(self, filename, contents, encoding='utf-8'):

          """Shortcut for transport.put_file_contents"""

-         self.transport.put_file_contents(filename, contents)

+         self.transport.put_file_contents(filename, contents, encoding=encoding)


      def collect_log(self, filename):

          """Call all registered log collectors on the given filename"""
@@ -201,7 +201,7 @@ 


      def run_command(self, argv, set_env=True, stdin_text=None,

                      log_stdout=True, raiseonerr=True,

-                     cwd=None, bg=False):

+                     cwd=None, bg=False, encoding='utf-8'):

          """Run the given command on this host


          Returns a Command instance. The command will have already run in the
@@ -218,45 +218,72 @@ 

          :param raiseonerr: If true, an exception will be raised if the command

                             does not exit with return code 0

          :param cwd: The working directory for the command

-         :param bg: If True, runs command in background

+         :param bg: If True, runs command in background.

+                    In this case, either the result should be used in a ``with``

+                    statement, or ``wait()`` should be called explicitly

+                    when the command is finished.

+         :param encoding: Encoding for the resulting Command instance's

+                          ``stdout_text`` and ``stderr_text``, and for

+                          ``stdin_text``, ``argv``, etc. if they are not

+                          bytestrings already.


-         command = self.transport.start_shell(argv, log_stdout=log_stdout)

+         def encode(string):

+             if not isinstance(string, bytes):

+                 return string.encode(encoding)

+             else:

+                 return string


+         command = self.transport.start_shell(argv, log_stdout=log_stdout,

+                                              encoding=encoding)

          # Set working directory

          if cwd is None:

              cwd = self.test_dir

-         command.stdin.write('cd %s\n' % shell_quote(cwd))

+         command.stdin.write(b'cd %s\n' % shell_quote(encode(cwd)))


          # Set the environment

          if set_env:

-             command.stdin.write('. %s\n' % shell_quote(self.env_sh_path))

+             quoted = shell_quote(encode(self.env_sh_path))

+             command.stdin.write(b'. %s\n' % quoted)


          if self.command_prelude:

-             command.stdin.write(self.command_prelude)

+             command.stdin.write(encode(self.command_prelude))


+         if stdin_text:

+             command.stdin.write(b"echo -e ")

+             command.stdin.write(_echo_quote(encode(stdin_text)))

+             command.stdin.write(b" | ")


          if isinstance(argv, basestring):

              # Run a shell command given as a string

-             command.stdin.write('(')

-             command.stdin.write(argv)

-             command.stdin.write(')')

+             command.stdin.write(b'(')

+             command.stdin.write(encode(argv))

+             command.stdin.write(b')')


              # Run a command given as a popen-style list (no shell expansion)

              for arg in argv:

-                 command.stdin.write(shell_quote(arg))

-                 command.stdin.write(' ')

+                 command.stdin.write(shell_quote(encode(arg)))

+                 command.stdin.write(b' ')


-         command.stdin.write(';exit\n')

-         if stdin_text:

-             command.stdin.write(stdin_text)

+         command.stdin.write(b'\nexit\n')


+         command.raiseonerr = raiseonerr

          if not bg:

-             command.wait(raiseonerr=raiseonerr)

+             command.wait()

Hello! You also need to change the wait method of the Command class, as for now it does not use the raiseonerr you've set on line 96.

          return command



+ def _echo_quote(bytestring):

+     """Encode a bytestring for use with bash & "echo -e"

+     """

+     bytestring = bytestring.replace(b"\\", br"\\")

+     bytestring = bytestring.replace(b"\0", br"\x00")

+     bytestring = bytestring.replace(b"'", br"'\''")

+     return b"'" + bytestring + b"'"



  class Host(BaseHost):

      """A Unix host"""

-     command_prelude = 'set -e\n'

+     command_prelude = b'set -e\n'



  class WinHost(BaseHost):

file modified
+107 -31
@@ -30,6 +30,8 @@ 

      have_paramiko = False



+ DEFAULT = object()


  class Transport(object):

      """Mechanism for communicating with remote hosts

@@ -44,11 +46,19 @@ 

          self._command_index = 0


      def get_file_contents(self, filename, encoding=None):

-         """Read the named remote file and return the contents as a string"""

+         """Read the named remote file and return the contents


+         The string will be decoded using the given encoding;

+         if encoding is None (default), it will be returned as a bytestring.

+         """

          raise NotImplementedError('Transport.get_file_contents')


-     def put_file_contents(self, filename, contents):

-         """Write the given string to the named remote file"""

+     def put_file_contents(self, filename, contents, encoding='utf-8'):

+         """Write the given string (or bytestring) to the named remote file


+         The contents string will be encoded using the given encoding

+         (default: ``'utf-8'``), unless already a bytestring.

+         """

          raise NotImplementedError('Transport.put_file_contents')


      def file_exists(self, filename):
@@ -59,18 +69,22 @@ 

          """Make the named directory"""

          raise NotImplementedError('Transport.mkdir')


-     def start_shell(self, argv, log_stdout=True):

+     def start_shell(self, argv, log_stdout=True, encoding=None):

          """Start a Shell


          :param argv: The command this shell is intended to run (used for

                       logging only)

          :param log_stdout: If false, the stdout will not be logged (useful when

                             binary output is expected)

+         :param encoding: Encoding for the resulting Command's ``stdout_text``

+                          and ``stderr_text``.


          Given a `shell` from this method, the caller can then use

          ``shell.stdin.write()`` to input any command(s), call ``shell.wait()``

          to let the command run, and then inspect ``returncode``,

          ``stdout_text`` or ``stderr_text``.


+         Note that ``shell.stdin`` uses bytes I/O.


          raise NotImplementedError('Transport.start_shell')

@@ -84,7 +98,7 @@ 


      def get_file(self, remotepath, localpath):

          """Copy a file from the remote host to a local file"""

-         contents = self.get_file_contents(remotepath)

+         contents = self.get_file_contents(remotepath, encoding=None)

          with open(localpath, 'wb') as local_file:


@@ -92,7 +106,7 @@ 

          """Copy a local file to the remote host"""

          with open(localpath, 'rb') as local_file:

              contents = local_file.read()

-         self.put_file_contents(remotepath, contents)

+         self.put_file_contents(remotepath, contents, encoding=None)


      def get_next_command_logger_name(self):

          self._command_index += 1
@@ -111,6 +125,28 @@ 

          raise NotImplementedError('Transport.remove_file')



+ class _decoded_output_property(object):

+     """Descriptor for on-demand decoding of a Command's output stream

+     """

+     def __init__(self, name):

+         self.name = name


+     def __set_name__(self, cls, name):

+         # Sanity check (called only on Python 3.6+).

+         # This property expects to handle attributes named '<foo>_text'.

+         assert name == self.name + '_text'


+     def __get__(self, instance, cls=None):

+         if instance is None:

+             return self

+         else:

+             bytestring = getattr(instance, self.name + '_bytes')

+             print(bytestring, instance.encoding)

+             decoded = bytestring.decode(instance.encoding)

+             setattr(instance, self.name + '_text', decoded)

+             return decoded



  class Command(object):

      """A Popen-style object representing a remote command

@@ -122,12 +158,22 @@ 

      To make sure reading doesn't stall after one buffer fills up, they are read

      in parallel using threads.


-     After calling wait(), ``stdout_text`` and ``stderr_text`` attributes will

-     be strings containing the output, and ``returncode`` will contain the

+     After calling wait(), ``stdout_bytes`` and ``stderr_bytes`` attributes will

+     be bytestrings containing the output, and ``returncode`` will contain the

      exit code.


+     The ``stdout_text`` and ``stderr_text`` will be the corresponding output

+     decoded using the given ``encoding`` (default: ``'utf-8'``).

+     These are decoded on-demand; do not access them if a command

+     produces binary output.


+     A Command may be used as a context manager (in the ``with`` statement).

+     Exiting the context will automatically call ``wait()``.

+     This raises an exception if the exit code is not 0, unless the

+     ``raiseonerr`` attribute is set to false before exiting the context.


      def __init__(self, argv, logger_name=None, log_stdout=True,

-                  get_logger=None):

+                  get_logger=None, encoding='utf-8'):

          self.returncode = None

          self.argv = argv

          self._done = False
@@ -140,13 +186,24 @@ 

              get_logger = logging.getLogger

          self.get_logger = get_logger

          self.log = get_logger(self.logger_name)

+         self.encoding = encoding

+         self.raiseonerr = True


+     stdout_text = _decoded_output_property('stdout')

+     stderr_text = _decoded_output_property('stderr')


-     def wait(self, raiseonerr=True):

+     def wait(self, raiseonerr=DEFAULT):

          """Wait for the remote process to exit


-         Raises an excption if the exit code is not 0, unless raiseonerr is

+         Raises an exception if the exit code is not 0, unless ``raiseonerr`` is



+         When ``raiseonerr`` is not specified as argument, the ``raiseonerr``

+         attribute is used.


+         if raiseonerr is DEFAULT:

+             raiseonerr = self.raiseonerr


          if self._done:

              return self.returncode

@@ -168,6 +225,12 @@ 


          raise NotImplementedError()


+     def __enter__(self):

+         return self


+     def __exit__(self, *exc_info):

+         self.wait(raiseonerr=self.raiseonerr)



  class ParamikoTransport(Transport):

      """Transport that uses the Paramiko SSH2 library"""
@@ -220,16 +283,18 @@ 

      def get_file_contents(self, filename, encoding=None):

          """Read the named remote file and return the contents as a string"""

          self.log.debug('READ %s', filename)

-         with self.sftp_open(filename) as f:

+         with self.sftp_open(filename, 'rb') as f:

              result = f.read()

          if encoding:

              result = result.decode(encoding)

          return result


-     def put_file_contents(self, filename, contents):

+     def put_file_contents(self, filename, contents, encoding=None):

          """Write the given string to the named remote file"""

          self.log.info('WRITE %s', filename)

-         with self.sftp_open(filename, 'w') as f:

+         if encoding and not isinstance(contents, bytes):

+             contents = contents.encode(encoding)

+         with self.sftp_open(filename, 'wb') as f:



      def file_exists(self, filename):
@@ -248,13 +313,14 @@ 

          self.log.info('MKDIR %s', path)



-     def start_shell(self, argv, log_stdout=True):

+     def start_shell(self, argv, log_stdout=True, encoding='utf-8'):

          logger_name = self.get_next_command_logger_name()

          ssh = self._transport.open_channel('session')

          self.log.info('RUN %s', argv)

          return SSHCommand(ssh, argv, logger_name=logger_name,


-                           get_logger=self.host.config.get_logger)

+                           get_logger=self.host.config.get_logger,

+                           encoding=encoding)


      def get_file(self, remotepath, localpath):

          self.log.debug('GET %s', remotepath)
@@ -322,12 +388,14 @@ 


          return argv


-     def start_shell(self, argv, log_stdout=True):

+     def start_shell(self, argv, log_stdout=True, encoding='utf-8'):

          self.log.info('RUN %s', argv)

-         command = self._run(['bash'], argv=argv, log_stdout=log_stdout)

+         command = self._run(['bash'], argv=argv, log_stdout=log_stdout,

+                             encoding=encoding)

          return command


-     def _run(self, command, log_stdout=True, argv=None, collect_output=True):

+     def _run(self, command, log_stdout=True, argv=None, collect_output=True,

+              encoding='utf-8'):

          """Run the given command on the remote host


          :param command: Command to run (appended to the common SSH invocation)
@@ -341,7 +409,8 @@ 

          ssh = SSHCallWrapper(self.ssh_argv + list(command))

          return SSHCommand(ssh, argv, logger_name, log_stdout=log_stdout,


-                           get_logger=self.host.config.get_logger)

+                           get_logger=self.host.config.get_logger,

+                           encoding=encoding)


      def file_exists(self, path):

          self.log.info('STAT %s', path)
@@ -355,19 +424,21 @@ 

          cmd = self._run(['mkdir', path])



-     def put_file_contents(self, filename, contents):

+     def put_file_contents(self, filename, contents, encoding='utf-8'):

          self.log.info('PUT %s', filename)

+         if encoding and not isinstance(contents, bytes):

+             contents = contents.encode(encoding)

          cmd = self._run(['tee', filename], log_stdout=False)



-         assert cmd.stdout_text == contents

+         assert cmd.stdout_bytes == contents


      def get_file_contents(self, filename, encoding=None):

          self.log.info('GET %s', filename)

          cmd = self._run(['cat', filename], log_stdout=False)


          if cmd.returncode == 0:

-             result = cmd.stdout_text

+             result = cmd.stdout_bytes

              if encoding:

                  result = result.decode(encoding)

              return result
@@ -432,7 +503,8 @@ 

                   collect_output=True, encoding='utf-8', get_logger=None):

          super(SSHCommand, self).__init__(argv, logger_name,


-                                          get_logger=get_logger)

+                                          get_logger=get_logger,

+                                          encoding=encoding)

          self._stdout_lines = []

          self._stderr_lines = []

          self.running_threads = set()
@@ -443,14 +515,16 @@ 




+         self._use_bytes = (encoding is None)


          def wrap_file(file, encoding):

-             if encoding is None or sys.version_info < (3, 0):

+             if self._use_bytes:

                  return file


                  return io.TextIOWrapper(file, encoding=encoding)

-         self.stdin = wrap_file(self._ssh.makefile('wb'), 'utf-8')

-         stdout = wrap_file(self._ssh.makefile('rb'), encoding)

-         stderr = wrap_file(self._ssh.makefile_stderr('rb'), encoding)

+         self.stdin = self._ssh.makefile('wb')

+         stdout = self._ssh.makefile('rb')

+         stderr = self._ssh.makefile_stderr('rb')


          if collect_output:

              self._start_pipe_thread(self._stdout_lines, stdout, 'out',
@@ -463,8 +537,9 @@ 

          while self.running_threads:



-         self.stdout_text = ''.join(self._stdout_lines)

-         self.stderr_text = ''.join(self._stderr_lines)

+         self.stdout_bytes = b''.join(self._stdout_lines)

+         self.stderr_bytes = b''.join(self._stderr_lines)


          self.returncode = self._ssh.recv_exit_status()


@@ -480,7 +555,8 @@ 

          def read_stream():

              for line in stream:

                  if do_log:

-                     log.debug(line.rstrip('\n'))

+                     log.debug(line.rstrip(b'\n').decode('utf-8',

+                                                         errors='replace'))



          thread = threading.Thread(target=read_stream)

file modified
+3 -3
@@ -15,9 +15,9 @@ 

                           (name, ', '.join(dct)))



- def shell_quote(string):

-     """Quotes a string for the Bash shell"""

-     return "'" + string.replace("'", "'\\''") + "'"

+ def shell_quote(bytestring):

+     """Quotes a bytestring for the Bash shell"""

+     return b"'" + bytestring.replace(b"'", b"'\\''") + b"'"



  class TempDir(object):

@@ -6,6 +6,7 @@ 

  import pytest

  from subprocess import CalledProcessError

  import contextlib

+ import sys

  import os


  import pytest_multihost
@@ -121,6 +122,9 @@ 

  def _first_command(host):

      """If managed command fails, prints a message to help debugging"""


+         # Run dummy command first; this should catch spurious SSH messages.

+         host.run_command(['echo', 'hello', 'world'])

+         # Now, run the actual command


      except (AuthenticationException, CalledProcessError):

          print (
@@ -159,6 +163,38 @@ 

          with pytest.raises(IOError):



+     def test_get_put_file_contents_bytes(self, multihost, tmpdir):

+         host = multihost.host

+         filename = str(tmpdir.join('test-bytes.txt'))

+         testbytes = u'test \0 \N{WHITE SMILING FACE}'.encode('utf-8')

+         with _first_command(host):

+             host.put_file_contents(filename, testbytes, encoding=None)

+         result = host.get_file_contents(filename, encoding=None)

+         assert result == testbytes


+     @pytest.mark.parametrize('encoding', ('utf-8', 'utf-16'))

+     def test_put_file_contents_utf(self, multihost, tmpdir, encoding):

+         host = multihost.host

+         filename = str(tmpdir.join('test-{}.txt'.format(encoding)))

+         teststring = u'test \N{WHITE SMILING FACE}'

+         with _first_command(host):

+             host.put_file_contents(filename, teststring, encoding=encoding)

+         result = host.get_file_contents(filename, encoding=None)

+         assert result == teststring.encode(encoding)

+         with open(filename, 'rb') as f:

+             assert f.read() == teststring.encode(encoding)


+     @pytest.mark.parametrize('encoding', ('utf-8', 'utf-16'))

+     def test_get_file_contents_encoding(self, multihost, tmpdir, encoding):

+         host = multihost.host

+         filename = str(tmpdir.join('test-{}.txt'.format(encoding)))

+         teststring = u'test \N{WHITE SMILING FACE}'

+         with open(filename, 'wb') as f:

+             f.write(teststring.encode(encoding))

+         result = host.get_file_contents(filename, encoding=encoding)

+         assert result == teststring

+         assert type(result) == type(u'')


      def test_rename_file(self, multihost, tmpdir):

          host = multihost.host

          filename = str(tmpdir.join('test.txt'))
@@ -196,6 +232,98 @@ 


          assert not os.path.exists(filename)


+     def test_escaping(self, multihost, tmpdir):

+         host = multihost.host

+         test_file_path = str(tmpdir.join('testfile.txt'))


+         stdin_text = '"test", test, "test", $test, '

+         stdin_text += ''.join(chr(x) for x in range(32, 127))

+         stdin_text += r', \x66\0111\x00, '

+         stdin_text += ''.join('\\' + chr(x) for x in range(32, 127))

+         tee = host.run_command(

+             ["tee", test_file_path],

+             stdin_text=stdin_text,

+             raiseonerr=False,

+         )

+         print(tee.stderr_text)

+         assert tee.stdout_text == stdin_text + '\n'

+         with open(test_file_path, "r") as f:

+             assert f.read() == tee.stdout_text


+     def test_escaping_binary(self, multihost, tmpdir):

+         host = multihost.host

+         test_file_path = str(tmpdir.join('testfile.txt'))


+         stdin_bytes = b'"test", test, "test", $test, '

+         stdin_bytes += bytes(range(0, 256))

+         stdin_bytes += br', \x66\0111\x00'

+         tee = host.run_command(

+             ["tee", test_file_path],

+             stdin_text=stdin_bytes,

+             raiseonerr=False,

+         )

+         assert tee.stdout_bytes == stdin_bytes + b'\n'

+         with open(test_file_path, "rb") as f:

+             assert f.read() == tee.stdout_bytes


+     def test_background_explicit_wait(self, multihost, tmpdir):

+         host = multihost.host


+         pipe_filename = str(tmpdir.join('test.pipe'))


+         with _first_command(host):

+             host.run_command(['mkfifo', pipe_filename])


+         cat = host.run_command(['cat', pipe_filename], bg=True)

+         host.run_command('cat > ' + pipe_filename, stdin_text='expected value')


+         cat.wait()

+         assert cat.stdout_text == 'expected value\n'

+         assert cat.returncode == 0


+     def test_background_context(self, multihost, tmpdir):

+         host = multihost.host


+         pipe_filename = str(tmpdir.join('test.pipe'))


+         with _first_command(host):

+             host.run_command(['mkfifo', pipe_filename])


+         with host.run_command(['cat', pipe_filename], bg=True) as cat:

+             host.run_command('cat > ' + pipe_filename,

+                              stdin_text='expected value')


+         assert cat.stdout_text == 'expected value\n'

+         assert cat.returncode == 0



+     def test_background_raiseonerr_false(self, multihost, tmpdir):

+         host = multihost.host

+         with _first_command(host):

+             false = host.run_command(['false'], raiseonerr=False, bg=True)


+         assert false.returncode != 0



+     def test_background_raiseonerr_with(self, multihost, tmpdir):

+         host = multihost.host

+         with _first_command(host):

+             with pytest.raises(CalledProcessError):

+                 with host.run_command(['false'], raiseonerr=True, bg=True):

+                     pass


+     def test_background_raiseonerr_wait(self, multihost, tmpdir):

+         host = multihost.host

+         with _first_command(host):

+             false = host.run_command(['false'], raiseonerr=True, bg=True)


+             with pytest.raises(CalledProcessError):

+                 false.wait()




+ @pytest.mark.needs_ssh

+ class TestLocalhostBadConnection(object):

      def test_reset(self, multihost):

          host = multihost.host

          with _first_command(host):
@@ -224,11 +352,3 @@ 

          host = multihost_badpassword.host

          with pytest.raises((AuthenticationException, RuntimeError)):

              echo = host.run_command(['echo', 'hello', 'world'])


-     def test_background(self, multihost):

- 	host = multihost.host

- 	run_nc = 'nc -l 12080 > /tmp/filename.out'

- 	cmd = host.run_command(run_nc, bg=True, raiseonerr=False)

- 	send_file = 'nc localhost 12080 < /root/anaconda-ks.cfg'

- 	cmd = host.run_command(send_file)

- 	assert cmd.returncode == 0

file added
@@ -0,0 +1,14 @@ 

+ # Tox (http://tox.testrun.org/) is a tool for running tests

+ # in multiple virtualenvs. This configuration file will run the

+ # test suite on all supported python versions. To use it, "pip install tox"

+ # and then run "tox" from this directory.


+ [tox]

+ envlist = py2,py36

+ minver = 1.8


+ [testenv]

+ deps =

+     pytest

+     paramiko

+ commands = python -m pytest -vv test_pytestmultihost/

I've spent a day on python-pytest-multihost, fixing several issues that piled up in the meanwhile. They're all subtly connected.

  • Set a policy regarding bytes, text, strings and encodings when dealing with file contents -- see the change in README.rst. Update Host and Transport to match.

  • Previously, when stdin was given to a command, no EOF was sent. We assumed the program would only read as much stdin as it needed. This is an issue, especially with background tasks. To fix this, stdin is now piped to the command through echo.
    Also, stdin_text passed to commands is now binary-safe, if given as a bytestring.

  • Command (i.e. what host.run_command returns) is now a context manager, which is useful for background tasks:

    with host.run_command(['cat', pipe_filename], bg=True) as cat:
        host.run_command('cat > ' + pipe_filename, stdin_text='foo')
  • Clarify that wait() should be called for all background tasks (either explicitly, or by using a with statement).

  • Remove the buggy test test_background, instead add tests for the above.

  • Add a Tox configuration file.

Fixes: https://pagure.io/python-pytest-multihost/issue/6
Fixes: https://pagure.io/python-pytest-multihost/issue/7
Fixes: https://pagure.io/python-pytest-multihost/pull-request/9

Hello! You also need to change the wait method of the Command class, as for now it does not use the raiseonerr you've set on line 96.

1 new commit added

  • Honor instance attribute if raiseonerr is not passed to Command.wait()
6 years ago

Thanks for noticing; that should be fixed now.

reviewed out-of-band by aslaikov

Pull-Request has been closed by pviktori

6 years ago