Merge pull request #356 from kvanzuijlen/feature/unsafe_tempfile

This commit is contained in:
Hugo van Kemenade 2021-03-14 17:55:53 +02:00 committed by GitHub
commit 585da81d56
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -23,6 +23,7 @@ import collections
import hashlib
import html.entities
import logging
import os
import shelve
import ssl
import tempfile
@ -421,7 +422,8 @@ class _Network:
"""
if not file_path:
file_path = tempfile.mktemp(prefix="pylast_tmp_")
self.cache_backend = _ShelfCacheBackend.create_shelf()
return
self.cache_backend = _ShelfCacheBackend(file_path)
@ -782,8 +784,11 @@ class LibreFMNetwork(_Network):
class _ShelfCacheBackend:
"""Used as a backend for caching cacheable requests."""
def __init__(self, file_path=None):
self.shelf = shelve.open(file_path)
def __init__(self, file_path=None, flag=None):
if flag is not None:
self.shelf = shelve.open(file_path, flag=flag)
else:
self.shelf = shelve.open(file_path)
self.cache_keys = set(self.shelf.keys())
def __contains__(self, key):
@ -799,6 +804,12 @@ class _ShelfCacheBackend:
self.cache_keys.add(key)
self.shelf[key] = xml_string
@classmethod
def create_shelf(cls):
file_descriptor, file_path = tempfile.mkstemp(prefix="pylast_tmp_")
os.close(file_descriptor)
return cls(file_path=file_path, flag="n")
class _Request:
"""Representing an abstract web service operation."""
@ -916,7 +927,7 @@ class _Request:
headers=headers,
)
except Exception as e:
raise NetworkError(self.network, e)
raise NetworkError(self.network, e) from e
else:
conn = HTTPSConnection(context=SSL_CONTEXT, host=host_name)
@ -924,7 +935,7 @@ class _Request:
try:
conn.request(method="POST", url=host_subdir, body=data, headers=headers)
except Exception as e:
raise NetworkError(self.network, e)
raise NetworkError(self.network, e) from e
try:
response = conn.getresponse()
@ -937,7 +948,7 @@ class _Request:
)
response_text = _unicode(response.read())
except Exception as e:
raise MalformedResponseError(self.network, e)
raise MalformedResponseError(self.network, e) from e
try:
self._check_response_for_errors(response_text)
@ -961,7 +972,7 @@ class _Request:
try:
doc = minidom.parseString(_string(response).replace("opensearch:", ""))
except Exception as e:
raise MalformedResponseError(self.network, e)
raise MalformedResponseError(self.network, e) from e
e = doc.getElementsByTagName("lfm")[0]
# logger.debug(doc.toprettyxml())
@ -1042,9 +1053,6 @@ class SessionKeyGenerator:
if url in self.web_auth_tokens.keys():
token = self.web_auth_tokens[url]
else:
# This will raise a WSError if token is blank or unauthorized
token = token
request = _Request(self.network, "auth.getSession", {"token": token})
@ -1397,7 +1405,13 @@ class _Taggable(_BaseObject):
return seq
class WSError(Exception):
class PyLastError(Exception):
"""Generic exception raised by PyLast"""
pass
class WSError(PyLastError):
"""Exception related to the Network web service"""
def __init__(self, network, status, details):
@ -1441,7 +1455,7 @@ class WSError(Exception):
return self.status
class MalformedResponseError(Exception):
class MalformedResponseError(PyLastError):
"""Exception conveying a malformed response from the music network."""
def __init__(self, network, underlying_error):
@ -1454,7 +1468,7 @@ class MalformedResponseError(Exception):
)
class NetworkError(Exception):
class NetworkError(PyLastError):
"""Exception conveying a problem in sending a request to Last.fm"""
def __init__(self, network, underlying_error):
@ -2778,7 +2792,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa
break # success
except Exception as e:
if tries >= 3:
raise e
raise PyLastError() from e
# Wait and try again
time.sleep(1)
tries += 1
@ -2795,7 +2809,7 @@ def _collect_nodes(limit, sender, method_name, cacheable, params=None, stream=Fa
main.getAttribute("totalPages") or main.getAttribute("totalpages")
)
else:
raise Exception("No total pages attribute")
raise PyLastError("No total pages attribute")
for node in main.childNodes:
if not node.nodeType == xml.dom.Node.TEXT_NODE and (
@ -2910,9 +2924,6 @@ def _number(string):
def _unescape_htmlentity(string):
# string = _unicode(string)
mapping = html.entities.name2codepoint
for key in mapping:
string = string.replace("&%s;" % key, chr(mapping[key]))