| |
@@ -30,6 +30,8 @@
|
| |
have_paramiko = False
|
| |
|
| |
|
| |
+ DEFAULT = object()
|
| |
+
|
| |
class Transport(object):
|
| |
"""Mechanism for communicating with remote hosts
|
| |
|
| |
@@ -44,11 +46,19 @@
|
| |
self._command_index = 0
|
| |
|
| |
def get_file_contents(self, filename, encoding=None):
|
| |
- """Read the named remote file and return the contents as a string"""
|
| |
+ """Read the named remote file and return the contents
|
| |
+
|
| |
+ The string will be decoded using the given encoding;
|
| |
+ if encoding is None (default), it will be returned as a bytestring.
|
| |
+ """
|
| |
raise NotImplementedError('Transport.get_file_contents')
|
| |
|
| |
- def put_file_contents(self, filename, contents):
|
| |
- """Write the given string to the named remote file"""
|
| |
+ def put_file_contents(self, filename, contents, encoding='utf-8'):
|
| |
+ """Write the given string (or bytestring) to the named remote file
|
| |
+
|
| |
+ The contents string will be encoded using the given encoding
|
| |
+ (default: ``'utf-8'``), unless already a bytestring.
|
| |
+ """
|
| |
raise NotImplementedError('Transport.put_file_contents')
|
| |
|
| |
def file_exists(self, filename):
|
| |
@@ -59,18 +69,22 @@
|
| |
"""Make the named directory"""
|
| |
raise NotImplementedError('Transport.mkdir')
|
| |
|
| |
- def start_shell(self, argv, log_stdout=True):
|
| |
+ def start_shell(self, argv, log_stdout=True, encoding=None):
|
| |
"""Start a Shell
|
| |
|
| |
:param argv: The command this shell is intended to run (used for
|
| |
logging only)
|
| |
:param log_stdout: If false, the stdout will not be logged (useful when
|
| |
binary output is expected)
|
| |
+ :param encoding: Encoding for the resulting Command's ``stdout_text``
|
| |
+ and ``stderr_text``.
|
| |
|
| |
Given a `shell` from this method, the caller can then use
|
| |
``shell.stdin.write()`` to input any command(s), call ``shell.wait()``
|
| |
to let the command run, and then inspect ``returncode``,
|
| |
``stdout_text`` or ``stderr_text``.
|
| |
+
|
| |
+ Note that ``shell.stdin`` uses bytes I/O.
|
| |
"""
|
| |
raise NotImplementedError('Transport.start_shell')
|
| |
|
| |
@@ -84,7 +98,7 @@
|
| |
|
| |
def get_file(self, remotepath, localpath):
|
| |
"""Copy a file from the remote host to a local file"""
|
| |
- contents = self.get_file_contents(remotepath)
|
| |
+ contents = self.get_file_contents(remotepath, encoding=None)
|
| |
with open(localpath, 'wb') as local_file:
|
| |
local_file.write(contents)
|
| |
|
| |
@@ -92,7 +106,7 @@
|
| |
"""Copy a local file to the remote host"""
|
| |
with open(localpath, 'rb') as local_file:
|
| |
contents = local_file.read()
|
| |
- self.put_file_contents(remotepath, contents)
|
| |
+ self.put_file_contents(remotepath, contents, encoding=None)
|
| |
|
| |
def get_next_command_logger_name(self):
|
| |
self._command_index += 1
|
| |
@@ -111,6 +125,28 @@
|
| |
raise NotImplementedError('Transport.remove_file')
|
| |
|
| |
|
| |
+ class _decoded_output_property(object):
|
| |
+ """Descriptor for on-demand decoding of a Command's output stream
|
| |
+ """
|
| |
+ def __init__(self, name):
|
| |
+ self.name = name
|
| |
+
|
| |
+ def __set_name__(self, cls, name):
|
| |
+ # Sanity check (called only on Python 3.6+).
|
| |
+ # This property expects to handle attributes named '<foo>_text'.
|
| |
+ assert name == self.name + '_text'
|
| |
+
|
| |
+ def __get__(self, instance, cls=None):
|
| |
+ if instance is None:
|
| |
+ return self
|
| |
+ else:
|
| |
+ bytestring = getattr(instance, self.name + '_bytes')
|
| |
+ # Decode once; setattr below caches the text on the instance, shadowing this non-data descriptor.
|
| |
+ decoded = bytestring.decode(instance.encoding)
|
| |
+ setattr(instance, self.name + '_text', decoded)
|
| |
+ return decoded
|
| |
+
|
| |
+
|
| |
class Command(object):
|
| |
"""A Popen-style object representing a remote command
|
| |
|
| |
@@ -122,12 +158,22 @@
|
| |
To make sure reading doesn't stall after one buffer fills up, they are read
|
| |
in parallel using threads.
|
| |
|
| |
- After calling wait(), ``stdout_text`` and ``stderr_text`` attributes will
|
| |
- be strings containing the output, and ``returncode`` will contain the
|
| |
+ After calling wait(), ``stdout_bytes`` and ``stderr_bytes`` attributes will
|
| |
+ be bytestrings containing the output, and ``returncode`` will contain the
|
| |
exit code.
|
| |
+
|
| |
+ The ``stdout_text`` and ``stderr_text`` attributes will be the corresponding output
|
| |
+ decoded using the given ``encoding`` (default: ``'utf-8'``).
|
| |
+ These are decoded on-demand; do not access them if a command
|
| |
+ produces binary output.
|
| |
+
|
| |
+ A Command may be used as a context manager (in the ``with`` statement).
|
| |
+ Exiting the context will automatically call ``wait()``.
|
| |
+ This raises an exception if the exit code is not 0, unless the
|
| |
+ ``raiseonerr`` attribute is set to false before exiting the context.
|
| |
"""
|
| |
def __init__(self, argv, logger_name=None, log_stdout=True,
|
| |
- get_logger=None):
|
| |
+ get_logger=None, encoding='utf-8'):
|
| |
self.returncode = None
|
| |
self.argv = argv
|
| |
self._done = False
|
| |
@@ -140,13 +186,24 @@
|
| |
get_logger = logging.getLogger
|
| |
self.get_logger = get_logger
|
| |
self.log = get_logger(self.logger_name)
|
| |
+ self.encoding = encoding
|
| |
+ self.raiseonerr = True
|
| |
+
|
| |
+ stdout_text = _decoded_output_property('stdout')
|
| |
+ stderr_text = _decoded_output_property('stderr')
|
| |
|
| |
- def wait(self, raiseonerr=True):
|
| |
+ def wait(self, raiseonerr=DEFAULT):
|
| |
"""Wait for the remote process to exit
|
| |
|
| |
- Raises an excption if the exit code is not 0, unless raiseonerr is
|
| |
+ Raises an exception if the exit code is not 0 and ``raiseonerr`` is
|
| |
true.
|
| |
+
|
| |
+ When ``raiseonerr`` is not specified as argument, the ``raiseonerr``
|
| |
+ attribute is used.
|
| |
"""
|
| |
+ if raiseonerr is DEFAULT:
|
| |
+ raiseonerr = self.raiseonerr
|
| |
+
|
| |
if self._done:
|
| |
return self.returncode
|
| |
|
| |
@@ -168,6 +225,12 @@
|
| |
"""
|
| |
raise NotImplementedError()
|
| |
|
| |
+ def __enter__(self):
|
| |
+ return self
|
| |
+
|
| |
+ def __exit__(self, *exc_info):
|
| |
+ self.wait(raiseonerr=self.raiseonerr)
|
| |
+
|
| |
|
| |
class ParamikoTransport(Transport):
|
| |
"""Transport that uses the Paramiko SSH2 library"""
|
| |
@@ -220,16 +283,18 @@
|
| |
def get_file_contents(self, filename, encoding=None):
|
| |
"""Read the named remote file and return the contents as a string"""
|
| |
self.log.debug('READ %s', filename)
|
| |
- with self.sftp_open(filename) as f:
|
| |
+ with self.sftp_open(filename, 'rb') as f:
|
| |
result = f.read()
|
| |
if encoding:
|
| |
result = result.decode(encoding)
|
| |
return result
|
| |
|
| |
- def put_file_contents(self, filename, contents):
|
| |
+ def put_file_contents(self, filename, contents, encoding='utf-8'):
|
| |
"""Write the given string to the named remote file"""
|
| |
self.log.info('WRITE %s', filename)
|
| |
- with self.sftp_open(filename, 'w') as f:
|
| |
+ if encoding and not isinstance(contents, bytes):
|
| |
+ contents = contents.encode(encoding)
|
| |
+ with self.sftp_open(filename, 'wb') as f:
|
| |
f.write(contents)
|
| |
|
| |
def file_exists(self, filename):
|
| |
@@ -248,13 +313,14 @@
|
| |
self.log.info('MKDIR %s', path)
|
| |
self.sftp.mkdir(path)
|
| |
|
| |
- def start_shell(self, argv, log_stdout=True):
|
| |
+ def start_shell(self, argv, log_stdout=True, encoding='utf-8'):
|
| |
logger_name = self.get_next_command_logger_name()
|
| |
ssh = self._transport.open_channel('session')
|
| |
self.log.info('RUN %s', argv)
|
| |
return SSHCommand(ssh, argv, logger_name=logger_name,
|
| |
log_stdout=log_stdout,
|
| |
- get_logger=self.host.config.get_logger)
|
| |
+ get_logger=self.host.config.get_logger,
|
| |
+ encoding=encoding)
|
| |
|
| |
def get_file(self, remotepath, localpath):
|
| |
self.log.debug('GET %s', remotepath)
|
| |
@@ -322,12 +388,14 @@
|
| |
|
| |
return argv
|
| |
|
| |
- def start_shell(self, argv, log_stdout=True):
|
| |
+ def start_shell(self, argv, log_stdout=True, encoding='utf-8'):
|
| |
self.log.info('RUN %s', argv)
|
| |
- command = self._run(['bash'], argv=argv, log_stdout=log_stdout)
|
| |
+ command = self._run(['bash'], argv=argv, log_stdout=log_stdout,
|
| |
+ encoding=encoding)
|
| |
return command
|
| |
|
| |
- def _run(self, command, log_stdout=True, argv=None, collect_output=True):
|
| |
+ def _run(self, command, log_stdout=True, argv=None, collect_output=True,
|
| |
+ encoding='utf-8'):
|
| |
"""Run the given command on the remote host
|
| |
|
| |
:param command: Command to run (appended to the common SSH invocation)
|
| |
@@ -341,7 +409,8 @@
|
| |
ssh = SSHCallWrapper(self.ssh_argv + list(command))
|
| |
return SSHCommand(ssh, argv, logger_name, log_stdout=log_stdout,
|
| |
collect_output=collect_output,
|
| |
- get_logger=self.host.config.get_logger)
|
| |
+ get_logger=self.host.config.get_logger,
|
| |
+ encoding=encoding)
|
| |
|
| |
def file_exists(self, path):
|
| |
self.log.info('STAT %s', path)
|
| |
@@ -355,19 +424,21 @@
|
| |
cmd = self._run(['mkdir', path])
|
| |
cmd.wait()
|
| |
|
| |
- def put_file_contents(self, filename, contents):
|
| |
+ def put_file_contents(self, filename, contents, encoding='utf-8'):
|
| |
self.log.info('PUT %s', filename)
|
| |
+ if encoding and not isinstance(contents, bytes):
|
| |
+ contents = contents.encode(encoding)
|
| |
cmd = self._run(['tee', filename], log_stdout=False)
|
| |
cmd.stdin.write(contents)
|
| |
cmd.wait()
|
| |
- assert cmd.stdout_text == contents
|
| |
+ assert cmd.stdout_bytes == contents
|
| |
|
| |
def get_file_contents(self, filename, encoding=None):
|
| |
self.log.info('GET %s', filename)
|
| |
cmd = self._run(['cat', filename], log_stdout=False)
|
| |
cmd.wait(raiseonerr=False)
|
| |
if cmd.returncode == 0:
|
| |
- result = cmd.stdout_text
|
| |
+ result = cmd.stdout_bytes
|
| |
if encoding:
|
| |
result = result.decode(encoding)
|
| |
return result
|
| |
@@ -432,7 +503,8 @@
|
| |
collect_output=True, encoding='utf-8', get_logger=None):
|
| |
super(SSHCommand, self).__init__(argv, logger_name,
|
| |
log_stdout=log_stdout,
|
| |
- get_logger=get_logger)
|
| |
+ get_logger=get_logger,
|
| |
+ encoding=encoding)
|
| |
self._stdout_lines = []
|
| |
self._stderr_lines = []
|
| |
self.running_threads = set()
|
| |
@@ -443,14 +515,16 @@
|
| |
|
| |
self._ssh.invoke_shell()
|
| |
|
| |
+ self._use_bytes = (encoding is None)
|
| |
+
|
| |
def wrap_file(file, encoding):
|
| |
- if encoding is None or sys.version_info < (3, 0):
|
| |
+ if self._use_bytes:
|
| |
return file
|
| |
else:
|
| |
return io.TextIOWrapper(file, encoding=encoding)
|
| |
- self.stdin = wrap_file(self._ssh.makefile('wb'), 'utf-8')
|
| |
- stdout = wrap_file(self._ssh.makefile('rb'), encoding)
|
| |
- stderr = wrap_file(self._ssh.makefile_stderr('rb'), encoding)
|
| |
+ self.stdin = self._ssh.makefile('wb')
|
| |
+ stdout = self._ssh.makefile('rb')
|
| |
+ stderr = self._ssh.makefile_stderr('rb')
|
| |
|
| |
if collect_output:
|
| |
self._start_pipe_thread(self._stdout_lines, stdout, 'out',
|
| |
@@ -463,8 +537,9 @@
|
| |
while self.running_threads:
|
| |
self.running_threads.pop().join()
|
| |
|
| |
- self.stdout_text = ''.join(self._stdout_lines)
|
| |
- self.stderr_text = ''.join(self._stderr_lines)
|
| |
+ self.stdout_bytes = b''.join(self._stdout_lines)
|
| |
+ self.stderr_bytes = b''.join(self._stderr_lines)
|
| |
+
|
| |
self.returncode = self._ssh.recv_exit_status()
|
| |
self._ssh.close()
|
| |
|
| |
@@ -480,7 +555,8 @@
|
| |
def read_stream():
|
| |
for line in stream:
|
| |
if do_log:
|
| |
- log.debug(line.rstrip('\n'))
|
| |
+ log.debug(line.rstrip(b'\n').decode('utf-8',
|
| |
+ errors='replace'))
|
| |
result_list.append(line)
|
| |
|
| |
thread = threading.Thread(target=read_stream)
|
| |
Hello! You also need to change the `wait` method of the Command class; as it stands, it does not use the `raiseonerr` attribute you set on line 96.