diff options
Diffstat (limited to 'sugar_network/toolkit/http.py')
-rw-r--r-- | sugar_network/toolkit/http.py | 185 |
1 files changed, 119 insertions, 66 deletions
diff --git a/sugar_network/toolkit/http.py b/sugar_network/toolkit/http.py index 2780c1b..0a25c57 100644 --- a/sugar_network/toolkit/http.py +++ b/sugar_network/toolkit/http.py @@ -13,13 +13,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +import os import sys import json import types +import hashlib import logging -from os.path import join, dirname +from os.path import join, dirname, exists, expanduser, abspath -from sugar_network import client, toolkit +from sugar_network import toolkit from sugar_network.toolkit import enforce @@ -67,7 +69,6 @@ class BadRequest(Status): class Unauthorized(Status): status = '401 Unauthorized' - headers = {'WWW-Authenticate': 'Sugar'} status_code = 401 @@ -104,16 +105,15 @@ class GatewayTimeout(Status): class Connection(object): _Session = None - _SSLError = None _ConnectionError = None - def __init__(self, api_url='', creds=None, trust_env=True, max_retries=0): + def __init__(self, api_url='', auth=None, max_retries=0, **session_args): self.api_url = api_url - self._get_profile = None - self._session = None - self._creds = creds - self._trust_env = trust_env + self.auth = auth self._max_retries = max_retries + self._session_args = session_args + self._session = None + self._nonce = None def __repr__(self): return '<Connection api_url=%s>' % self.api_url @@ -199,35 +199,24 @@ class Connection(object): path = [''] if not isinstance(path, basestring): path = '/'.join([i.strip('/') for i in [self.api_url] + path]) - if isinstance(params, basestring): - path += '?' + params - params = None - a_try = 0 + try_ = 0 while True: - a_try += 1 + try_ += 1 try: reply = self._session.request(method, path, data=data, headers=headers, params=params, **kwargs) - except Connection._SSLError: - _logger.warning('Use --no-check-certificate to avoid checks') - raise except Connection._ConnectionError, error: raise ConnectionError, error, sys.exc_info()[2] - if reply.status_code != 200: - if reply.status_code == 401: - enforce(method not in ('PUT', 'POST') or - not hasattr(data, 'read'), - 'Cannot resend data after authentication') - enforce(self._get_profile is not None, - 'Operation is not available in anonymous mode') - _logger.info('User is not registered on the server, ' - 'registering') - self.post(['user'], self._get_profile()) - a_try = 0 - continue - if allowed and reply.status_code in allowed: - break + + if reply.status_code == Unauthorized.status_code: + enforce(self.auth is not None, Unauthorized, 'No credentials') + self._authenticate(reply.headers.get('www-authenticate')) + try_ = 0 + elif reply.status_code == 200 or \ + allowed and reply.status_code in allowed: + break + else: content = reply.content try: error = json.loads(content)['error'] @@ -236,13 +225,12 @@ class Connection(object): # was not sent by the application level server code, i.e., # something happaned on low level, like connection abort. # If so, try to resend request. - if a_try <= self._max_retries and method == 'GET': + if try_ <= self._max_retries and method in ('GET', 'HEAD'): continue error = content or reply.headers.get('x-sn-error') or \ 'No error message provided' cls = _FORWARD_STATUSES.get(reply.status_code, RuntimeError) - raise cls, error, sys.exc_info()[2] - break + raise cls(error) return reply @@ -264,14 +252,12 @@ class Connection(object): else: request.content = request.content_stream.read() headers['content-length'] = str(len(request.content)) - for env_key, key, value in ( - ('HTTP_IF_MODIFIED_SINCE', 'if-modified-since', None), - ('HTTP_ACCEPT_LANGUAGE', 'accept-language', ','.join( - client.accept_language.value or toolkit.default_langs())), - ('HTTP_ACCEPT_ENCODING', 'accept-encoding', None), + for env_key, key in ( + ('HTTP_IF_MODIFIED_SINCE', 'if-modified-since'), + ('HTTP_ACCEPT_LANGUAGE', 'accept-language'), + ('HTTP_ACCEPT_ENCODING', 'accept-encoding'), ): - if value is None: - value = request.environ.get(env_key) + value = request.environ.get(env_key) if value is not None: headers[key] = value @@ -316,25 +302,100 @@ class Connection(object): def _init(self): if Connection._Session is None: - sys.path.insert(0, - join(dirname(__file__), '..', 'lib', 'requests')) - from requests import Session - from requests.exceptions import SSLError - from requests.exceptions import ConnectionError as _ConnectionError + sys_path = join(dirname(dirname(__file__)), 'lib', 'requests') + sys.path.insert(0, sys_path) + from requests import Session, exceptions Connection._Session = Session - Connection._SSLError = SSLError - Connection._ConnectionError = _ConnectionError + Connection._ConnectionError = exceptions.ConnectionError self._session = Connection._Session() + self._session.headers['accept-language'] = \ + ','.join(toolkit.default_langs()) + for arg, value in self._session_args.items(): + setattr(self._session, arg, value) self._session.stream = True - self._session.trust_env = self._trust_env - if client.no_check_certificate.value: - self._session.verify = False - if self._creds: - uid, keyfile, self._get_profile = self._creds - self._session.headers['X-SN-login'] = uid - self._session.headers['X-SN-signature'] = _sign(keyfile, uid) - self._session.headers['accept-language'] = toolkit.default_lang() + + def _authenticate(self, challenge): + from urllib2 import parse_http_list, parse_keqv_list + + nonce = None + if challenge: + challenge = challenge.split(' ', 1)[-1] + nonce = parse_keqv_list(parse_http_list(challenge)).get('nonce') + + if self._nonce and nonce == self._nonce: + enforce(self.auth.profile(), Unauthorized, 'Bad credentials') + _logger.info('Register on the server') + self.post(['user'], self.auth.profile()) + + self._session.headers['authorization'] = self.auth(nonce) + self._nonce = nonce + + +class SugarAuth(object): + + def __init__(self, key_path, profile=None): + self._key_path = abspath(expanduser(key_path)) + self._profile = profile or {'color': '#000000,#000000'} + self._key = None + self._pubkey = None + self._login = None + + @property + def pubkey(self): + if self._pubkey is None: + self.ensure_key() + from M2Crypto.BIO import MemoryBuffer + buf = MemoryBuffer() + self._key.save_pub_key_bio(buf) + self._pubkey = buf.getvalue() + return self._pubkey + + @property + def login(self): + if self._login is None: + self._login = str(hashlib.sha1(self.pubkey).hexdigest()) + return self._login + + def profile(self): + if 'name' not in self._profile: + self._profile['name'] = self.login + self._profile['pubkey'] = self.pubkey + return self._profile + + def __call__(self, nonce): + self.ensure_key() + data = hashlib.sha1('%s:%s' % (self.login, nonce)).digest() + signature = self._key.sign(data).encode('hex') + return 'Sugar username="%s",nonce="%s",signature="%s"' % \ + (self.login, nonce, signature) + + def ensure_key(self): + from M2Crypto import RSA + from base64 import b64encode + + if exists(self._key_path): + self._key = RSA.load_key(self._key_path) + return + + key_dir = dirname(self._key_path) + if not exists(key_dir): + os.makedirs(key_dir) + os.chmod(key_dir, 0700) + + _logger.info('Generate RSA private key at %r', self._key_path) + self._key = RSA.gen_key(2048, 65537, lambda *args: None) + self._key.save_key(self._key_path, cipher=None) + os.chmod(self._key_path, 0600) + + pub_key_path = self._key_path + '.pub' + with file(pub_key_path, 'w') as f: + f.write('ssh-rsa %s %s@%s' % ( + b64encode('\x00\x00\x00\x07ssh-rsa%s%s' % self._key.pub()), + self.login, + os.uname()[1], + )) + _logger.info('Saved RSA public key at %r', pub_key_path) class _Subscription(object): @@ -355,14 +416,14 @@ class _Subscription(object): return self._handshake(ping=True)._fp.fp.fileno() def pull(self): - for a_try in (1, 0): + for try_ in (1, 0): stream = self._handshake() try: line = toolkit.readline(stream) enforce(line, 'Subscription aborted') break except Exception: - if a_try == 0: + if try_ == 0: raise toolkit.exception('Failed to read from %r subscription, ' 'will resubscribe', self._client.api_url) @@ -398,14 +459,6 @@ def _parse_event(line): _logger.exception('Failed to parse %r event', line) -def _sign(key_path, data): - import hashlib - from M2Crypto import DSA - key = DSA.load_key(key_path) - # pylint: disable-msg=E1121 - return key.sign_asn1(hashlib.sha1(data).digest()).encode('hex') - - _FORWARD_STATUSES = { BadRequest.status_code: BadRequest, Forbidden.status_code: Forbidden, |