diff --git a/src/pylast/__init__.py b/src/pylast/__init__.py index 02db07c..1341de0 100644 --- a/src/pylast/__init__.py +++ b/src/pylast/__init__.py @@ -23,6 +23,7 @@ import collections import hashlib import html.entities import logging +import os import shelve import ssl import tempfile @@ -421,7 +422,8 @@ class _Network: """ if not file_path: - file_path = tempfile.mktemp(prefix="pylast_tmp_") + self.cache_backend = _ShelfCacheBackend.create_shelf() + return self.cache_backend = _ShelfCacheBackend(file_path) @@ -782,8 +784,11 @@ class LibreFMNetwork(_Network): class _ShelfCacheBackend: """Used as a backend for caching cacheable requests.""" - def __init__(self, file_path=None): - self.shelf = shelve.open(file_path) + def __init__(self, file_path=None, flag=None): + if flag is not None: + self.shelf = shelve.open(file_path, flag=flag) + else: + self.shelf = shelve.open(file_path) self.cache_keys = set(self.shelf.keys()) def __contains__(self, key): @@ -799,6 +804,12 @@ class _ShelfCacheBackend: self.cache_keys.add(key) self.shelf[key] = xml_string + @classmethod + def create_shelf(cls): + file_descriptor, file_path = tempfile.mkstemp(prefix="pylast_tmp_") + os.close(file_descriptor) + return cls(file_path=file_path, flag="n") + class _Request: """Representing an abstract web service operation.""" @@ -916,7 +927,7 @@ class _Request: headers=headers, ) except Exception as e: - raise NetworkError(self.network, e) + raise NetworkError(self.network, e) from e else: conn = HTTPSConnection(context=SSL_CONTEXT, host=host_name) @@ -924,7 +935,7 @@ class _Request: try: conn.request(method="POST", url=host_subdir, body=data, headers=headers) except Exception as e: - raise NetworkError(self.network, e) + raise NetworkError(self.network, e) from e try: response = conn.getresponse() @@ -937,7 +948,7 @@ class _Request: ) response_text = _unicode(response.read()) except Exception as e: - raise MalformedResponseError(self.network, e) + raise MalformedResponseError(self.network, e) from e try: self._check_response_for_errors(response_text) @@ -961,7 +972,7 @@ class _Request: try: doc = minidom.parseString(_string(response).replace("opensearch:", "")) except Exception as e: - raise MalformedResponseError(self.network, e) + raise MalformedResponseError(self.network, e) from e e = doc.getElementsByTagName("lfm")[0] # logger.debug(doc.toprettyxml()) @@ -1042,9 +1053,6 @@ class SessionKeyGenerator: if url in self.web_auth_tokens.keys(): token = self.web_auth_tokens[url] - else: - # This will raise a WSError if token is blank or unauthorized - token = token request = _Request(self.network, "auth.getSession", {"token": token}) @@ -1397,7 +1405,13 @@ class _Taggable(_BaseObject): return seq -class WSError(Exception): +class PyLastError(Exception): + """Generic exception raised by PyLast""" + + pass + + +class WSError(PyLastError): """Exception related to the Network web service""" def __init__(self, network, status, details): @@ -1441,7 +1455,7 @@ class WSError(Exception): return self.status -class MalformedResponseError(Exception): +class MalformedResponseError(PyLastError): """Exception conveying a malformed response from the music network.""" def __init__(self, network, underlying_error): @@ -1454,7 +1468,7 @@ class MalformedResponseError(Exception): ) -class NetworkError(Exception): +class NetworkError(PyLastError): """Exception conveying a problem in sending a request to Last.fm""" def __init__(self, network, underlying_error): @@ -2778,7 +2792,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa break # success except Exception as e: if tries >= 3: - raise e + raise PyLastError() from e # Wait and try again time.sleep(1) tries += 1 @@ -2795,7 +2809,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa main.getAttribute("totalPages") or main.getAttribute("totalpages") ) else: - raise Exception("No total pages attribute") + raise PyLastError("No total pages attribute") for node in main.childNodes: if not node.nodeType == xml.dom.Node.TEXT_NODE and ( @@ -2910,9 +2924,6 @@ def _number(string): def _unescape_htmlentity(string): - - # string = _unicode(string) - mapping = html.entities.name2codepoint for key in mapping: string = string.replace("&%s;" % key, chr(mapping[key]))