From d20f0a74a30f01ce4f98b421c6549af72f7889bc Mon Sep 17 00:00:00 2001 From: mprahl Date: Aug 18 2020 17:49:48 +0000 Subject: Add a timeout to xmlrpc calls to prevent the socket waiting forever --- diff --git a/greenwave/consumers/resultsdb.py b/greenwave/consumers/resultsdb.py index daefb98..cab5ff7 100644 --- a/greenwave/consumers/resultsdb.py +++ b/greenwave/consumers/resultsdb.py @@ -17,8 +17,7 @@ from greenwave.subjects.factory import ( create_subject_from_data, UnknownSubjectDataError, ) - -import xmlrpc.client +from greenwave.xmlrpc_server_proxy import get_server_proxy log = logging.getLogger(__name__) @@ -63,7 +62,7 @@ class ResultsDBHandler(Consumer): koji_base_url = self.flask_app.config['KOJI_BASE_URL'] if koji_base_url: - self.koji_proxy = xmlrpc.client.ServerProxy(koji_base_url) + self.koji_proxy = get_server_proxy(koji_base_url) else: self.koji_proxy = None diff --git a/greenwave/product_versions.py b/greenwave/product_versions.py index e65f268..f9b1107 100644 --- a/greenwave/product_versions.py +++ b/greenwave/product_versions.py @@ -5,7 +5,7 @@ Product version guessing for subject identifiers import logging import re - +import socket import xmlrpc.client log = logging.getLogger(__name__) @@ -51,6 +51,8 @@ def _guess_koji_build_product_version( target = koji_proxy.getTaskRequest(koji_task_id)[1] return _guess_product_version(target, koji_build=True) + except socket.error as err: + raise ConnectionError('Could not reach Koji: {}'.format(err)) except xmlrpc.client.Fault: pass diff --git a/greenwave/resources.py b/greenwave/resources.py index 72b786e..60f7bde 100644 --- a/greenwave/resources.py +++ b/greenwave/resources.py @@ -17,6 +17,7 @@ from werkzeug.exceptions import BadGateway, NotFound from greenwave.cache import cached from greenwave.request_session import get_requests_session +from greenwave.xmlrpc_server_proxy import get_server_proxy log = logging.getLogger(__name__) @@ -117,7 +118,7 @@ def retrieve_scm_from_koji(nvr): """ Retrieve cached rev and namespace from koji using the nvr """ koji_url = current_app.config['KOJI_BASE_URL'] try: - proxy = xmlrpc.client.ServerProxy(koji_url) + proxy = get_server_proxy(koji_url) build = proxy.getBuild(nvr) except (xmlrpc.client.ProtocolError, socket.error) as err: raise ConnectionError('Could not reach Koji: {}'.format(err)) diff --git a/greenwave/tests/test_policies.py b/greenwave/tests/test_policies.py index 12c9413..06d4384 100644 --- a/greenwave/tests/test_policies.py +++ b/greenwave/tests/test_policies.py @@ -1295,7 +1295,7 @@ def test_on_demand_policy_match(two_rules): app = create_app('greenwave.config.TestingConfig') with app.app_context(): - with mock.patch('xmlrpc.client.ServerProxy') as koji_server: + with mock.patch('greenwave.resources.get_server_proxy') as koji_server: koji_server_instance = mock.MagicMock() koji_server_instance.getBuild.return_value = {'extra': {'source': None}} koji_server.return_value = koji_server_instance diff --git a/greenwave/tests/test_product_versions.py b/greenwave/tests/test_product_versions.py new file mode 100644 index 0000000..fcf0a69 --- /dev/null +++ b/greenwave/tests/test_product_versions.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: GPL-2.0+ + +import socket + +import mock +import pytest + +from greenwave import product_versions + + +@pytest.mark.parametrize('task_id', (None, 3)) +def test_guess_koji_build_product_version_socket_error(task_id): + subject_identifier = 'release-e2e-test-1.0.1685-1.el5' + mock_proxy = mock.Mock() + mock_proxy.getBuild.side_effect = mock_proxy.getTaskRequest.side_effect = ( + socket.timeout('timed out') + ) + expected = 'Could not reach Koji: timed out' + with pytest.raises(ConnectionError, match=expected): + product_versions._guess_koji_build_product_version(subject_identifier, mock_proxy, task_id) diff --git a/greenwave/tests/test_retrieve_gating_yaml.py b/greenwave/tests/test_retrieve_gating_yaml.py index c18528c..9d2141e 100644 --- a/greenwave/tests/test_retrieve_gating_yaml.py +++ b/greenwave/tests/test_retrieve_gating_yaml.py @@ -104,7 +104,7 @@ def test_retrieve_scm_from_koji_build_not_found(): expected_error = '404 Not Found: Failed to find Koji build for "{}" at "{}"'.format( nvr, app.config['KOJI_BASE_URL'] ) - with mock.patch('xmlrpc.client.ServerProxy') as koji_server: + with mock.patch('greenwave.resources.get_server_proxy') as koji_server: proxy = mock.MagicMock() proxy.getBuild.return_value = {} koji_server.return_value = proxy @@ -170,7 +170,7 @@ def test_retrieve_yaml_remote_rule_connection_error(): ) -@mock.patch('greenwave.resources.xmlrpc.client.ServerProxy') +@mock.patch('greenwave.resources.get_server_proxy') def test_retrieve_scm_from_koji_build_socket_error(mock_xmlrpc_client): mock_auth_server = mock_xmlrpc_client.return_value mock_auth_server.getBuild.side_effect = socket.error('Socket is closed') diff --git a/greenwave/tests/test_xmlrpc_server_proxy.py b/greenwave/tests/test_xmlrpc_server_proxy.py new file mode 100644 index 0000000..bf59cb6 --- /dev/null +++ b/greenwave/tests/test_xmlrpc_server_proxy.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: GPL-2.0+ + +import mock +import pytest + +from greenwave import xmlrpc_server_proxy + + +@pytest.mark.parametrize( + 'url, expected_transport, timeout, expected_timeout', + ( + ('http://localhost:5000/api', xmlrpc_server_proxy.Transport, 15, 15), + ('https://localhost:5000/api', xmlrpc_server_proxy.SafeTransport, 15, 15), + ('https://localhost:5000/api', xmlrpc_server_proxy.SafeTransport, (3, 12), 12), + ), +) +@mock.patch('greenwave.xmlrpc_server_proxy.Transport') +@mock.patch('greenwave.xmlrpc_server_proxy.SafeTransport') +def test_get_server_proxy_app_context( + mock_safe_transport, + mock_transport, + url, + expected_transport, + timeout, + expected_timeout, + app, +): + with app.app_context(): + app.config['REQUESTS_TIMEOUT'] = timeout + xmlrpc_server_proxy.get_server_proxy(url) + + if expected_transport == xmlrpc_server_proxy.Transport: + mock_transport.__init__.assert_called_once_with(url, expected_timeout) + mock_safe_transport.__init__.assert_not_called() + elif expected_transport == xmlrpc_server_proxy.SafeTransport: + mock_safe_transport.__init__.assert_called_once_with(url, expected_timeout) + mock_transport.__init__.assert_not_called() diff --git a/greenwave/xmlrpc_server_proxy.py b/greenwave/xmlrpc_server_proxy.py new file mode 100644 index 0000000..543e4cc --- /dev/null +++ b/greenwave/xmlrpc_server_proxy.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# SPDX-License-Identifier: GPL-2.0+ +""" +Provides an "xmlrpc.client.ServerProxy" object with a timeout on the socket. +""" +import urllib.parse +import xmlrpc.client + +from flask import current_app, has_app_context + + +def get_server_proxy(uri, timeout=None): + """ + Create an :py:class:`xmlrpc.client.ServerProxy` instance with a socket timeout. + + This is a workaround for https://bugs.python.org/issue14134. + + Args: + uri (str): The connection point on the server in the format of scheme://host/target. + timeout (int): The timeout to set on the transport socket. This defaults to the Flask + configuration `REQUESTS_TIMEOUT` if there is an application context. + + Returns: + xmlrpc.client.ServerProxy: An instance of :py:class:`xmlrpc.client.ServerProxy` with + a socket timeout set. + """ + if timeout is None and has_app_context(): + if isinstance(current_app.config['REQUESTS_TIMEOUT'], tuple): + timeout = current_app.config['REQUESTS_TIMEOUT'][1] + else: + timeout = current_app.config['REQUESTS_TIMEOUT'] + + parsed_uri = urllib.parse.urlparse(uri) + if parsed_uri.scheme == 'https': + transport = SafeTransport(timeout=timeout) + else: + transport = Transport(timeout=timeout) + + return xmlrpc.client.ServerProxy(uri, transport=transport) + + +class Transport(xmlrpc.client.Transport): + def __init__(self, *args, timeout=None, **kwargs): # pragma: no cover + super().__init__(*args, **kwargs) + self._timeout = timeout + + def make_connection(self, *args, **kwargs): # pragma: no cover + connection = super().make_connection(*args, **kwargs) + connection.timeout = self._timeout + return connection + + +class SafeTransport(xmlrpc.client.SafeTransport): + def __init__(self, *args, timeout=None, **kwargs): # pragma: no cover + super().__init__(*args, **kwargs) + self._timeout = timeout + + def make_connection(self, *args, **kwargs): # pragma: no cover + connection = super().make_connection(*args, **kwargs) + connection.timeout = self._timeout + return connection