[networking] Add keep_header_casing extension (#11652)

Authored by: coletdjnz, Grub4K

Co-authored-by: coletdjnz <coletdjnz@protonmail.com>
This commit is contained in:
Simon Sawicki
2025-03-03 00:10:01 +01:00
committed by GitHub
parent 79ec2fdff7
commit 7d18fed8f1
9 changed files with 230 additions and 41 deletions

View File

@@ -296,6 +296,7 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
extensions.pop('keep_header_casing', None)
def _create_instance(self, cookiejar, legacy_ssl_support=None):
session = RequestsSession()
@@ -312,11 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
session.trust_env = False # no need, we already load proxies from env
return session
def _send(self, request):
headers = self._merge_headers(request.headers)
def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
def _send(self, request):
headers = self._get_headers(request)
max_redirects_exceeded = False
session = self._get_instance(

View File

@@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
opener.addheaders = []
return opener
def _send(self, request):
headers = self._merge_headers(request.headers)
def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
def _send(self, request):
headers = self._get_headers(request)
urllib_req = urllib.request.Request(
url=request.url,
data=request.data,
headers=dict(headers),
headers=headers,
method=request.method,
)

View File

@@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler):
extensions.pop('timeout', None)
extensions.pop('cookiejar', None)
extensions.pop('legacy_ssl', None)
extensions.pop('keep_header_casing', None)
def close(self):
# Remove the logging handler that contains a reference to our logger
@@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler):
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._merge_headers(request.headers)
def _prepare_headers(self, request, headers):
if 'cookie' not in headers:
cookiejar = self._get_cookiejar(request)
cookie_header = cookiejar.get_cookie_header(request.url)
if cookie_header:
headers['cookie'] = cookie_header
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._get_headers(request)
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,

View File

@@ -206,6 +206,7 @@ class RequestHandler(abc.ABC):
- `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request.
- `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
- `keep_header_casing`: Keep the casing of headers when sending the request.
To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys:
@@ -259,6 +260,23 @@ class RequestHandler(abc.ABC):
def _merge_headers(self, request_headers):
return HTTPHeaderDict(self.headers, request_headers)
def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
"""Additional operations to prepare headers before building. To be extended by subclasses.
@param request: Request object
@param headers: Merged headers to prepare
"""
def _get_headers(self, request: Request) -> dict[str, str]:
"""
Get headers for external use.
Subclasses may define a _prepare_headers method to modify headers after merge but before building.
"""
headers = self._merge_headers(request.headers)
self._prepare_headers(request, headers)
if request.extensions.get('keep_header_casing'):
return headers.sensitive()
return dict(headers)
def _calculate_timeout(self, request):
return float(request.extensions.get('timeout') or self.timeout)
@@ -317,6 +335,7 @@ class RequestHandler(abc.ABC):
assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
def _validate(self, request):
self._check_url_scheme(request)

View File

@@ -5,11 +5,11 @@ from abc import ABC
from dataclasses import dataclass
from typing import Any
from .common import RequestHandler, register_preference
from .common import RequestHandler, register_preference, Request
from .exceptions import UnsupportedRequest
from ..compat.types import NoneType
from ..utils import classproperty, join_nonempty
from ..utils.networking import std_headers
from ..utils.networking import std_headers, HTTPHeaderDict
@dataclass(order=True, frozen=True)
@@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
"""Get the requested target for the request"""
return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
def _get_impersonate_headers(self, request):
def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
"""Additional operations to prepare headers before building. To be extended by subclasses.
@param request: Request object
@param headers: Merged headers to prepare
"""
def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
"""
Get headers for external impersonation use.
Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building.
"""
headers = self._merge_headers(request.headers)
if self._get_request_target(request) is not None:
# remove all headers present in std_headers
@@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
for k, v in std_headers.items():
if headers.get(k) == v:
headers.pop(k)
return headers
self._prepare_impersonate_headers(request, headers)
if request.extensions.get('keep_header_casing'):
return headers.sensitive()
return dict(headers)
@register_preference(ImpersonateRequestHandler)