Compare commits

...

8 commits

Author SHA1 Message Date
Hugo van Kemenade ec2772fc16 Merge branch 'feature/unsafe_tempfile' of https://github.com/kvanzuijlen/pylast into kvanzuijlen-feature/unsafe_tempfile 2021-03-14 17:48:09 +02:00
Koen van Zuijlen ea1f2b42f8 Merge branch 'master' into feature/unsafe_tempfile 2021-01-12 10:19:29 +01:00
Koen van Zuijlen a41f2e0f36 Merge remote-tracking branch 'pylast/master' 2021-01-12 10:18:59 +01:00
Koen van Zuijlen 36b2eeb297 Code improvement 2020-12-30 17:12:32 +01:00
Koen van Zuijlen e9bef6db68 Bugfix for caching between sessions 2020-12-30 17:11:38 +01:00
pre-commit-ci[bot] eca1db8622 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2020-12-30 14:59:18 +00:00
Koen van Zuijlen 2d2e73c1bc Fixed unsafe tempfile and fixed some basic problems 2020-12-30 15:56:35 +01:00
Koen van Zuijlen 123a00c5e3 Merge remote-tracking branch 'pylast/master' 2020-12-30 14:10:16 +01:00

View file

@ -23,6 +23,7 @@ import collections
import hashlib import hashlib
import html.entities import html.entities
import logging import logging
import os
import shelve import shelve
import ssl import ssl
import tempfile import tempfile
@ -421,7 +422,8 @@ class _Network:
""" """
if not file_path: if not file_path:
file_path = tempfile.mktemp(prefix="pylast_tmp_") self.cache_backend = _ShelfCacheBackend.create_shelf()
return
self.cache_backend = _ShelfCacheBackend(file_path) self.cache_backend = _ShelfCacheBackend(file_path)
@ -782,8 +784,11 @@ class LibreFMNetwork(_Network):
class _ShelfCacheBackend: class _ShelfCacheBackend:
"""Used as a backend for caching cacheable requests.""" """Used as a backend for caching cacheable requests."""
def __init__(self, file_path=None): def __init__(self, file_path=None, flag=None):
self.shelf = shelve.open(file_path) if flag is not None:
self.shelf = shelve.open(file_path, flag=flag)
else:
self.shelf = shelve.open(file_path)
self.cache_keys = set(self.shelf.keys()) self.cache_keys = set(self.shelf.keys())
def __contains__(self, key): def __contains__(self, key):
@ -799,6 +804,12 @@ class _ShelfCacheBackend:
self.cache_keys.add(key) self.cache_keys.add(key)
self.shelf[key] = xml_string self.shelf[key] = xml_string
@classmethod
def create_shelf(cls):
file_descriptor, file_path = tempfile.mkstemp(prefix="pylast_tmp_")
os.close(file_descriptor)
return cls(file_path=file_path, flag="n")
class _Request: class _Request:
"""Representing an abstract web service operation.""" """Representing an abstract web service operation."""
@ -916,7 +927,7 @@ class _Request:
headers=headers, headers=headers,
) )
except Exception as e: except Exception as e:
raise NetworkError(self.network, e) raise NetworkError(self.network, e) from e
else: else:
conn = HTTPSConnection(context=SSL_CONTEXT, host=host_name) conn = HTTPSConnection(context=SSL_CONTEXT, host=host_name)
@ -924,7 +935,7 @@ class _Request:
try: try:
conn.request(method="POST", url=host_subdir, body=data, headers=headers) conn.request(method="POST", url=host_subdir, body=data, headers=headers)
except Exception as e: except Exception as e:
raise NetworkError(self.network, e) raise NetworkError(self.network, e) from e
try: try:
response = conn.getresponse() response = conn.getresponse()
@ -937,7 +948,7 @@ class _Request:
) )
response_text = _unicode(response.read()) response_text = _unicode(response.read())
except Exception as e: except Exception as e:
raise MalformedResponseError(self.network, e) raise MalformedResponseError(self.network, e) from e
try: try:
self._check_response_for_errors(response_text) self._check_response_for_errors(response_text)
@ -961,7 +972,7 @@ class _Request:
try: try:
doc = minidom.parseString(_string(response).replace("opensearch:", "")) doc = minidom.parseString(_string(response).replace("opensearch:", ""))
except Exception as e: except Exception as e:
raise MalformedResponseError(self.network, e) raise MalformedResponseError(self.network, e) from e
e = doc.getElementsByTagName("lfm")[0] e = doc.getElementsByTagName("lfm")[0]
# logger.debug(doc.toprettyxml()) # logger.debug(doc.toprettyxml())
@ -1042,9 +1053,6 @@ class SessionKeyGenerator:
if url in self.web_auth_tokens.keys(): if url in self.web_auth_tokens.keys():
token = self.web_auth_tokens[url] token = self.web_auth_tokens[url]
else:
# This will raise a WSError if token is blank or unauthorized
token = token
request = _Request(self.network, "auth.getSession", {"token": token}) request = _Request(self.network, "auth.getSession", {"token": token})
@ -1397,7 +1405,13 @@ class _Taggable(_BaseObject):
return seq return seq
class WSError(Exception): class PyLastError(Exception):
"""Generic exception raised by PyLast"""
pass
class WSError(PyLastError):
"""Exception related to the Network web service""" """Exception related to the Network web service"""
def __init__(self, network, status, details): def __init__(self, network, status, details):
@ -1441,7 +1455,7 @@ class WSError(Exception):
return self.status return self.status
class MalformedResponseError(Exception): class MalformedResponseError(PyLastError):
"""Exception conveying a malformed response from the music network.""" """Exception conveying a malformed response from the music network."""
def __init__(self, network, underlying_error): def __init__(self, network, underlying_error):
@ -1454,7 +1468,7 @@ class MalformedResponseError(Exception):
) )
class NetworkError(Exception): class NetworkError(PyLastError):
"""Exception conveying a problem in sending a request to Last.fm""" """Exception conveying a problem in sending a request to Last.fm"""
def __init__(self, network, underlying_error): def __init__(self, network, underlying_error):
@ -2778,7 +2792,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa
break # success break # success
except Exception as e: except Exception as e:
if tries >= 3: if tries >= 3:
raise e raise PyLastError() from e
# Wait and try again # Wait and try again
time.sleep(1) time.sleep(1)
tries += 1 tries += 1
@ -2795,7 +2809,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa
main.getAttribute("totalPages") or main.getAttribute("totalpages") main.getAttribute("totalPages") or main.getAttribute("totalpages")
) )
else: else:
raise Exception("No total pages attribute") raise PyLastError("No total pages attribute")
for node in main.childNodes: for node in main.childNodes:
if not node.nodeType == xml.dom.Node.TEXT_NODE and ( if not node.nodeType == xml.dom.Node.TEXT_NODE and (
@ -2910,9 +2924,6 @@ def _number(string):
def _unescape_htmlentity(string): def _unescape_htmlentity(string):
# string = _unicode(string)
mapping = html.entities.name2codepoint mapping = html.entities.name2codepoint
for key in mapping: for key in mapping:
string = string.replace("&%s;" % key, chr(mapping[key])) string = string.replace("&%s;" % key, chr(mapping[key]))