utils.py 2.97 KB
Newer Older
1 2 3 4 5 6 7
from twisted.web.http_headers import Headers


def header_value(headers, header_name):
    """
    Return a single header value from a header mapping as a string.

    @param headers: A mapping of header names to values. A value may be a
        list or tuple of values (twisted's getAllRawHeaders produces lists),
        in which case only the first one is used
    @param header_name: The key to look up in C{headers}
    @return: The (first) value, decoded from UTF-8 if it was bytes
    @raise KeyError: If C{header_name} is not a key of C{headers}
    @raise IndexError: If the stored value is an empty list or tuple
    """
    value = headers[header_name]
    if isinstance(value, (list, tuple)):
        value = value[0]

    # Decode to str (UTF-8) if the value is bytes
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    return value


def _normalize_header_name(name):
    """Return C{name} lowercased as UTF-8 bytes; accepts either str or bytes."""
    if isinstance(name, bytes):
        return name.lower()
    return name.lower().encode("utf-8")


def _first_entry(value):
    """Return the first comma-separated entry of a header value, stripped."""
    return value.split(",", 1)[0].strip()


def parse_x_forwarded_for(headers,
                          address_header_name='X-Forwarded-For',
                          port_header_name='X-Forwarded-Port',
                          proto_header_name='X-Forwarded-Proto',
                          original_addr=None,
                          original_scheme=None):
    """
    Parse proxy forwarding headers and return the client address and scheme.

    When a header holds a comma-separated list, only its first (left-most)
    entry is used. This rule applies to the address, port and protocol
    headers alike. The port and protocol headers are only consulted when
    the address header is present, so the results stay consistent.

    @param headers: The request's headers, either a twisted C{Headers} object
        or a mapping of header names (str or bytes) to values
    @param address_header_name: The name (str or bytes) of the expected host
        header; if falsy, nothing is parsed and the originals are returned
    @param port_header_name: The name (str or bytes) of the expected port
        header; if falsy, the port is not parsed
    @param proto_header_name: The name (str or bytes) of the expected
        protocol header; if falsy, the protocol is not parsed
    @param original_addr: A host/port pair that should be returned if the
        address header is not in the request
    @param original_scheme: A scheme that should be returned if the address
        or protocol header is not in the request
    @return: A tuple containing a list [host (string), port (int)] as the
        first entry and a proto (string) as the second
    """
    if not address_header_name:
        return (original_addr, original_scheme)

    if isinstance(headers, Headers):
        # Convert twisted-style headers into a dict of name -> list of values
        headers = dict(headers.getAllRawHeaders())

    # Header names are case-insensitive: key everything by lowercased bytes so
    # str and bytes names (both in the mapping and in the arguments) match.
    headers = {
        _normalize_header_name(name): values
        for name, values in headers.items()
    }

    address_header_name = _normalize_header_name(address_header_name)
    if address_header_name not in headers:
        return original_addr, original_scheme

    result_addr = [_first_entry(header_value(headers, address_header_name)), 0]
    result_scheme = original_scheme

    if port_header_name:
        # We only parse the X-Forwarded-Port header because we also parsed the
        # X-Forwarded-For header; this avoids inconsistent results.
        port_header_name = _normalize_header_name(port_header_name)
        if port_header_name in headers:
            port_value = _first_entry(header_value(headers, port_header_name))
            try:
                result_addr[1] = int(port_value)
            except ValueError:
                # A malformed port is deliberately ignored; the port stays 0.
                pass

    if proto_header_name:
        proto_header_name = _normalize_header_name(proto_header_name)
        if proto_header_name in headers:
            result_scheme = _first_entry(header_value(headers, proto_header_name))

    return result_addr, result_scheme