Skip to content
Snippets Groups Projects
Commit 0c1c7e00 authored by Marc Abramowitz's avatar Marc Abramowitz
Browse files

Merge pull request #14 from esben/consolidate

Consolidate outstanding forks
parents 17d845e4 4a029131
No related branches found
No related tags found
No related merge requests found
...@@ -3,8 +3,6 @@ import sys ...@@ -3,8 +3,6 @@ import sys
from .adapters import UnixAdapter from .adapters import UnixAdapter
__all__ = ['monkeypatch', 'Session']
DEFAULT_SCHEME = 'http+unix://' DEFAULT_SCHEME = 'http+unix://'
...@@ -18,10 +16,17 @@ class monkeypatch(object): ...@@ -18,10 +16,17 @@ class monkeypatch(object):
def __init__(self, url_scheme=DEFAULT_SCHEME): def __init__(self, url_scheme=DEFAULT_SCHEME):
self.session = Session() self.session = Session()
requests = self._get_global_requests_module() requests = self._get_global_requests_module()
self.orig_requests_get = requests.get
requests.get = self.session.get # Methods to replace
self.orig_requests_request = requests.request self.methods = ('request', 'get', 'head', 'post',
requests.request = self.session.request 'patch', 'put', 'delete', 'options')
# Store the original methods
self.orig_methods = dict(
(m, requests.__dict__[m]) for m in self.methods)
# Monkey patch
g = globals()
for m in self.methods:
requests.__dict__[m] = g[m]
def _get_global_requests_module(self): def _get_global_requests_module(self):
return sys.modules['requests'] return sys.modules['requests']
...@@ -31,5 +36,42 @@ class monkeypatch(object): ...@@ -31,5 +36,42 @@ class monkeypatch(object):
def __exit__(self, *args): def __exit__(self, *args):
requests = self._get_global_requests_module() requests = self._get_global_requests_module()
requests.get = self.orig_requests_get for m in self.methods:
requests.request = self.orig_requests_request requests.__dict__[m] = self.orig_methods[m]
# These are the same methods defined for the global requests object
def request(method, url, **kwargs):
session = Session()
return session.request(method=method, url=url, **kwargs)
def get(url, **kwargs):
kwargs.setdefault('allow_redirects', True)
return request('get', url, **kwargs)
def head(url, **kwargs):
kwargs.setdefault('allow_redirects', False)
return request('head', url, **kwargs)
def post(url, data=None, json=None, **kwargs):
return request('post', url, data=data, json=json, **kwargs)
def patch(url, data=None, **kwargs):
return request('patch', url, data=data, **kwargs)
def put(url, data=None, **kwargs):
return request('put', url, data=data, **kwargs)
def delete(url, **kwargs):
return request('delete', url, **kwargs)
def options(url, **kwargs):
kwargs.setdefault('allow_redirects', True)
return request('options', url, **kwargs)
...@@ -13,6 +13,7 @@ except ImportError: ...@@ -13,6 +13,7 @@ except ImportError:
# The following was adapted from some code from docker-py # The following was adapted from some code from docker-py
# https://github.com/docker/docker-py/blob/master/docker/unixconn/unixconn.py # https://github.com/docker/docker-py/blob/master/docker/unixconn/unixconn.py
class UnixHTTPConnection(HTTPConnection): class UnixHTTPConnection(HTTPConnection):
def __init__(self, unix_socket_url, timeout=60): def __init__(self, unix_socket_url, timeout=60):
"""Create an HTTP connection to a unix domain socket """Create an HTTP connection to a unix domain socket
...@@ -33,6 +34,7 @@ class UnixHTTPConnection(HTTPConnection): ...@@ -33,6 +34,7 @@ class UnixHTTPConnection(HTTPConnection):
class UnixHTTPConnectionPool(HTTPConnectionPool): class UnixHTTPConnectionPool(HTTPConnectionPool):
def __init__(self, socket_path, timeout=60): def __init__(self, socket_path, timeout=60):
HTTPConnectionPool.__init__(self, 'localhost', timeout=timeout) HTTPConnectionPool.__init__(self, 'localhost', timeout=timeout)
self.socket_path = socket_path self.socket_path = socket_path
...@@ -43,12 +45,16 @@ class UnixHTTPConnectionPool(HTTPConnectionPool): ...@@ -43,12 +45,16 @@ class UnixHTTPConnectionPool(HTTPConnectionPool):
class UnixAdapter(HTTPAdapter): class UnixAdapter(HTTPAdapter):
def __init__(self, timeout=60): def __init__(self, timeout=60):
super(UnixAdapter, self).__init__() super(UnixAdapter, self).__init__()
self.timeout = timeout self.timeout = timeout
def get_connection(self, socket_path, proxies=None): def get_connection(self, socket_path, proxies=None):
if proxies: proxies = proxies or {}
proxy = proxies.get(urlparse(socket_path.lower()).scheme)
if proxy:
raise ValueError('%s does not support specifying proxies' raise ValueError('%s does not support specifying proxies'
% self.__class__.__name__) % self.__class__.__name__)
return UnixHTTPConnectionPool(socket_path, self.timeout) return UnixHTTPConnectionPool(socket_path, self.timeout)
...@@ -20,19 +20,25 @@ def test_unix_domain_adapter_ok(): ...@@ -20,19 +20,25 @@ def test_unix_domain_adapter_ok():
session = requests_unixsocket.Session('http+unix://') session = requests_unixsocket.Session('http+unix://')
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock) urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = 'http+unix://%s/path/to/page' % urlencoded_usock url = 'http+unix://%s/path/to/page' % urlencoded_usock
logger.debug('Calling session.get(%r) ...', url)
r = session.get(url) for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
logger.debug( 'options']:
'Received response: %r with text: %r and headers: %r', logger.debug('Calling session.%s(%r) ...', method, url)
r, r.text, r.headers) r = getattr(session, method)(url)
assert r.status_code == 200 logger.debug(
assert r.headers['server'] == 'waitress' 'Received response: %r with text: %r and headers: %r',
assert r.headers['X-Transport'] == 'unix domain socket' r, r.text, r.headers)
assert r.headers['X-Requested-Path'] == '/path/to/page' assert r.status_code == 200
assert r.headers['X-Socket-Path'] == usock_thread.usock assert r.headers['server'] == 'waitress'
assert isinstance(r.connection, requests_unixsocket.UnixAdapter) assert r.headers['X-Transport'] == 'unix domain socket'
assert r.url == url assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.text == 'Hello world!' assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url == url
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
def test_unix_domain_adapter_url_with_query_params(): def test_unix_domain_adapter_url_with_query_params():
...@@ -41,37 +47,47 @@ def test_unix_domain_adapter_url_with_query_params(): ...@@ -41,37 +47,47 @@ def test_unix_domain_adapter_url_with_query_params():
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock) urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = ('http+unix://%s' url = ('http+unix://%s'
'/containers/nginx/logs?timestamp=true' % urlencoded_usock) '/containers/nginx/logs?timestamp=true' % urlencoded_usock)
logger.debug('Calling session.get(%r) ...', url)
r = session.get(url) for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
logger.debug( 'options']:
'Received response: %r with text: %r and headers: %r', logger.debug('Calling session.%s(%r) ...', method, url)
r, r.text, r.headers) r = getattr(session, method)(url)
assert r.status_code == 200 logger.debug(
assert r.headers['server'] == 'waitress' 'Received response: %r with text: %r and headers: %r',
assert r.headers['X-Transport'] == 'unix domain socket' r, r.text, r.headers)
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs' assert r.status_code == 200
assert r.headers['X-Requested-Query-String'] == 'timestamp=true' assert r.headers['server'] == 'waitress'
assert r.headers['X-Socket-Path'] == usock_thread.usock assert r.headers['X-Transport'] == 'unix domain socket'
assert isinstance(r.connection, requests_unixsocket.UnixAdapter) assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
assert r.url == url assert r.headers['X-Requested-Query-String'] == 'timestamp=true'
assert r.text == 'Hello world!' assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url == url
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
def test_unix_domain_adapter_connection_error(): def test_unix_domain_adapter_connection_error():
session = requests_unixsocket.Session('http+unix://') session = requests_unixsocket.Session('http+unix://')
with pytest.raises(requests.ConnectionError): for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
session.get('http+unix://socket_does_not_exist/path/to/page') with pytest.raises(requests.ConnectionError):
getattr(session, method)(
'http+unix://socket_does_not_exist/path/to/page')
def test_unix_domain_adapter_connection_proxies_error(): def test_unix_domain_adapter_connection_proxies_error():
session = requests_unixsocket.Session('http+unix://') session = requests_unixsocket.Session('http+unix://')
with pytest.raises(ValueError) as excinfo: for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
session.get('http+unix://socket_does_not_exist/path/to/page', with pytest.raises(ValueError) as excinfo:
proxies={"http": "http://10.10.1.10:1080"}) getattr(session, method)(
assert ('UnixAdapter does not support specifying proxies' 'http+unix://socket_does_not_exist/path/to/page',
in str(excinfo.value)) proxies={"http+unix": "http://10.10.1.10:1080"})
assert ('UnixAdapter does not support specifying proxies'
in str(excinfo.value))
def test_unix_domain_adapter_monkeypatch(): def test_unix_domain_adapter_monkeypatch():
...@@ -79,19 +95,27 @@ def test_unix_domain_adapter_monkeypatch(): ...@@ -79,19 +95,27 @@ def test_unix_domain_adapter_monkeypatch():
with requests_unixsocket.monkeypatch('http+unix://'): with requests_unixsocket.monkeypatch('http+unix://'):
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock) urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = 'http+unix://%s/path/to/page' % urlencoded_usock url = 'http+unix://%s/path/to/page' % urlencoded_usock
logger.debug('Calling requests.get(%r) ...', url)
r = requests.get(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url == url
assert r.text == 'Hello world!'
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(requests, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection,
requests_unixsocket.UnixAdapter)
assert r.url == url
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(requests.exceptions.InvalidSchema): with pytest.raises(requests.exceptions.InvalidSchema):
requests.get(url) getattr(requests, method)(url)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment