Refactor helper functions into a utils package

parent ece37c4659
commit 620323eab0
@@ -22,17 +22,13 @@ from __future__ import annotations
import collections
import hashlib
import html.entities
import logging
import os
import re
import shelve
import ssl
import tempfile
import time
import xml.dom
from urllib.parse import quote_plus
from xml.dom import Node, minidom
from xml.dom import minidom

import httpx
@@ -43,12 +39,29 @@ except ImportError:
    # Python 3.7 and lower
    import importlib_metadata  # type: ignore

from .utils import (
    _collect_nodes,
    _number,
    _parse_response,
    _string_output,
    _unescape_htmlentity,
    _unicode,
    _url_safe,
    cleanup_nodes,
    md5,
)

__author__ = "Amr Hassan, hugovk, Mice Pápai"
__copyright__ = "Copyright (C) 2008-2010 Amr Hassan, 2013-2021 hugovk, 2017 Mice Pápai"
__license__ = "apache2"
__email__ = "amr.hassan@gmail.com"
__version__ = importlib_metadata.version(__name__)

__all__ = [
    # Utils
    "cleanup_nodes",
    "md5",
]

# 1 : This error does not exist
STATUS_INVALID_SERVICE = 2
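Note: the nine helpers are re-imported into the package root, so callers that reached them as top-level attributes keep working after the move. A minimal sketch of what the re-export preserves (the hashed string is illustrative):

import pylast

# The names pulled in via "from .utils import ..." are the same objects
# as the ones in their new home, so both spellings stay equivalent.
assert pylast.md5 is pylast.utils.md5
assert pylast.md5("test") == pylast.utils.md5("test")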
@@ -938,7 +951,7 @@ class _Request:
        client.close()
        return response_text

    def execute(self, cacheable: bool = False) -> xml.dom.minidom.Document:
    def execute(self, cacheable: bool = False) -> minidom.Document:
        """Returns the XML DOM response of the POST Request from the server"""

        if self.network.is_caching_enabled() and cacheable:
@@ -1089,13 +1102,6 @@ Image = collections.namedtuple(
)


def _string_output(func):
    def r(*args):
        return str(func(*args))

    return r


class _BaseObject:
    """An abstract webservices object."""
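_string_output, removed here, is a small decorator that coerces a wrapped callable's return value to str. A quick illustration, assuming it is imported from its new private location in utils (for understanding only; it is not public API):

from pylast.utils import _string_output

@_string_output
def answer():
    return 42

assert answer() == "42"  # the wrapper stringifies whatever is returned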
@@ -2720,89 +2726,6 @@ class TrackSearch(_Search):
        return seq


def md5(text):
    """Returns the md5 hash of a string."""

    h = hashlib.md5()
    h.update(_unicode(text).encode("utf-8"))

    return h.hexdigest()


def _unicode(text):
    if isinstance(text, bytes):
        return str(text, "utf-8")
    else:
        return str(text)


def cleanup_nodes(doc):
    """
    Remove text nodes containing only whitespace
    """
    for node in doc.documentElement.childNodes:
        if node.nodeType == Node.TEXT_NODE and node.nodeValue.isspace():
            doc.documentElement.removeChild(node)
    return doc


def _collect_nodes(
    limit, sender, method_name, cacheable, params=None, stream: bool = False
):
    """
    Returns a sequence of dom.Node objects about as close to limit as possible
    """
    if not params:
        params = sender._get_params()

    def _stream_collect_nodes():
        node_count = 0
        page = 1
        end_of_pages = False

        while not end_of_pages and (not limit or (limit and node_count < limit)):
            params["page"] = str(page)

            tries = 1
            while True:
                try:
                    doc = sender._request(method_name, cacheable, params)
                    break  # success
                except Exception as e:
                    if tries >= 3:
                        raise PyLastError() from e
                    # Wait and try again
                    time.sleep(1)
                    tries += 1

            doc = cleanup_nodes(doc)

            # break if there are no child nodes
            if not doc.documentElement.childNodes:
                break
            main = doc.documentElement.childNodes[0]

            if main.hasAttribute("totalPages") or main.hasAttribute("totalpages"):
                total_pages = _number(
                    main.getAttribute("totalPages") or main.getAttribute("totalpages")
                )
            else:
                raise PyLastError("No total pages attribute")

            for node in main.childNodes:
                if not node.nodeType == xml.dom.Node.TEXT_NODE and (
                    not limit or (node_count < limit)
                ):
                    node_count += 1
                    yield node

            end_of_pages = page >= total_pages

            page += 1

    return _stream_collect_nodes() if stream else list(_stream_collect_nodes())


def _extract(node, name, index: int = 0):
    """Extracts a value from the xml string"""
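The pagination loop above guards each page request with a small retry: up to three attempts, a one-second sleep between failures, and the final exception chained into PyLastError. The same shape in isolation (fetch_with_retries, fetch_page, and the RuntimeError are illustrative stand-ins, not part of this commit):

import time

def fetch_with_retries(fetch_page, max_tries: int = 3):
    tries = 1
    while True:
        try:
            return fetch_page()  # success
        except Exception as e:
            if tries >= max_tries:
                raise RuntimeError("giving up after retries") from e
            # Wait and try again
            time.sleep(1)
            tries += 1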
@@ -2878,51 +2801,3 @@ def _extract_tracks(doc, network):
        artist = _extract(node, "name", 1)
        seq.append(Track(artist, name, network))
    return seq


def _url_safe(text):
    """Does all kinds of tricks on a text to make it safe to use in a URL."""

    return quote_plus(quote_plus(str(text))).lower()


def _number(string):
    """
    Extracts an int from a string.
    Returns a 0 if None or an empty string was passed.
    """

    if not string:
        return 0
    else:
        try:
            return int(string)
        except ValueError:
            return float(string)


def _unescape_htmlentity(string):
    mapping = html.entities.name2codepoint
    for key in mapping:
        string = string.replace(f"&{key};", chr(mapping[key]))

    return string


def _parse_response(response: str) -> xml.dom.minidom.Document:
    response = str(response).replace("opensearch:", "")
    try:
        doc = minidom.parseString(response)
    except xml.parsers.expat.ExpatError:
        # Try again. For performance, we only remove when needed in rare cases.
        doc = minidom.parseString(_remove_invalid_xml_chars(response))
    return doc


def _remove_invalid_xml_chars(string: str) -> str:
    return re.sub(
        r"[^\u0009\u000A\u000D\u0020-\uD7FF\uE000-\uFFFD\u10000-\u10FFF]+", "", string
    )


# End of file
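_url_safe quotes twice on purpose: the first quote_plus escapes reserved characters, and the second escapes the resulting plus and percent signs so the value survives a further round of decoding. For example (the artist name is illustrative):

from urllib.parse import quote_plus

text = "Sigur Rós"
print(quote_plus(text))                      # Sigur+R%C3%B3s
print(quote_plus(quote_plus(text)).lower())  # sigur%2br%25c3%25b3s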
src/pylast/utils.py (new file, 159 lines)

@@ -0,0 +1,159 @@
from __future__ import annotations

import hashlib
import html
import re
import time
import warnings
import xml
from urllib.parse import quote_plus
from xml.dom import Node, minidom

import pylast


def cleanup_nodes(doc: minidom.Document) -> minidom.Document:
    """
    cleanup_nodes is deprecated and will be removed in pylast 6.0
    """
    warnings.warn(
        "cleanup_nodes is deprecated and will be removed in pylast 6.0",
        DeprecationWarning,
        stacklevel=2,
    )
    return _cleanup_nodes(doc)


def md5(text: str) -> str:
    """Returns the md5 hash of a string."""

    h = hashlib.md5()
    h.update(_unicode(text).encode("utf-8"))

    return h.hexdigest()


def _collect_nodes(
    limit, sender, method_name, cacheable, params=None, stream: bool = False
):
    """
    Returns a sequence of dom.Node objects about as close to limit as possible
    """
    if not params:
        params = sender._get_params()

    def _stream_collect_nodes():
        node_count = 0
        page = 1
        end_of_pages = False

        while not end_of_pages and (not limit or (limit and node_count < limit)):
            params["page"] = str(page)

            tries = 1
            while True:
                try:
                    doc = sender._request(method_name, cacheable, params)
                    break  # success
                except Exception as e:
                    if tries >= 3:
                        raise pylast.PyLastError() from e
                    # Wait and try again
                    time.sleep(1)
                    tries += 1

            doc = _cleanup_nodes(doc)

            # break if there are no child nodes
            if not doc.documentElement.childNodes:
                break
            main = doc.documentElement.childNodes[0]

            if main.hasAttribute("totalPages") or main.hasAttribute("totalpages"):
                total_pages = _number(
                    main.getAttribute("totalPages") or main.getAttribute("totalpages")
                )
            else:
                raise pylast.PyLastError("No total pages attribute")

            for node in main.childNodes:
                if not node.nodeType == xml.dom.Node.TEXT_NODE and (
                    not limit or (node_count < limit)
                ):
                    node_count += 1
                    yield node

            end_of_pages = page >= total_pages

            page += 1

    return _stream_collect_nodes() if stream else list(_stream_collect_nodes())


def _cleanup_nodes(doc: minidom.Document) -> minidom.Document:
    """
    Remove text nodes containing only whitespace
    """
    for node in doc.documentElement.childNodes:
        if node.nodeType == Node.TEXT_NODE and node.nodeValue.isspace():
            doc.documentElement.removeChild(node)
    return doc


def _number(string: str | None) -> float:
    """
    Extracts an int from a string.
    Returns a 0 if None or an empty string was passed.
    """

    if not string:
        return 0
    else:
        try:
            return int(string)
        except ValueError:
            return float(string)


def _parse_response(response: str) -> xml.dom.minidom.Document:
    response = str(response).replace("opensearch:", "")
    try:
        doc = minidom.parseString(response)
    except xml.parsers.expat.ExpatError:
        # Try again. For performance, we only remove when needed in rare cases.
        doc = minidom.parseString(_remove_invalid_xml_chars(response))
    return doc


def _remove_invalid_xml_chars(string: str) -> str:
    return re.sub(
        r"[^\u0009\u000A\u000D\u0020-\uD7FF\uE000-\uFFFD\u10000-\u10FFF]+", "", string
    )


def _string_output(func):
    def r(*args):
        return str(func(*args))

    return r


def _unescape_htmlentity(string: str) -> str:
    mapping = html.entities.name2codepoint
    for key in mapping:
        string = string.replace(f"&{key};", chr(mapping[key]))

    return string


def _unicode(text: bytes | str) -> str:
    if isinstance(text, bytes):
        return str(text, "utf-8")
    else:
        return str(text)


def _url_safe(text: str) -> str:
    """Does all kinds of tricks on a text to make it safe to use in a URL."""

    return quote_plus(quote_plus(str(text))).lower()
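In the new module, the public cleanup_nodes is only a deprecation shim that warns and delegates to the private _cleanup_nodes, which _collect_nodes now calls directly so internal use stays silent. A sketch of observing the warning (the XML snippet is illustrative):

import warnings
from xml.dom import minidom

import pylast

doc = minidom.parseString("<lfm>\n  <artist/>\n</lfm>")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pylast.utils.cleanup_nodes(doc)  # warns, then strips whitespace-only text nodes
assert any(w.category is DeprecationWarning for w in caught)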
@@ -45,7 +45,7 @@ def test_cast_and_hash(obj) -> None:
    ],
)
def test__remove_invalid_xml_chars(test_input: str, expected: str) -> None:
    assert pylast._remove_invalid_xml_chars(test_input) == expected
    assert pylast.utils._remove_invalid_xml_chars(test_input) == expected


@pytest.mark.parametrize(
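The updated assertion reaches _remove_invalid_xml_chars through its new utils path. The helper drops characters outside the XML 1.0 range; for instance (inputs are illustrative):

from pylast.utils import _remove_invalid_xml_chars

# NUL and vertical tab are not legal XML characters, so they are removed;
# the horizontal tab (\u0009) is legal and survives.
assert _remove_invalid_xml_chars("a\x00b\x0bc\td") == "abc\td"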