# Copyright (C) 2012-2013 Aleksey Lim # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . import os import sys import json import types import hashlib import logging from os.path import join, dirname, exists, expanduser, abspath from sugar_network import toolkit from sugar_network.toolkit import enforce _REDIRECT_CODES = frozenset([301, 302, 303, 307, 308]) _logger = logging.getLogger('http') class ConnectionError(Exception): pass class Status(Exception): status = None headers = None class StatusPass(Status): pass class NotModified(StatusPass): status = '304 Not Modified' status_code = 304 class Redirect(StatusPass): status = '303 See Other' status_code = 303 def __init__(self, location): StatusPass.__init__(self) self.headers = {'location': location} class BadRequest(Status): status = '400 Bad Request' status_code = 400 class Unauthorized(Status): status = '401 Unauthorized' status_code = 401 class Forbidden(Status): status = '403 Forbidden' status_code = 403 class NotFound(Status): status = '404 Not Found' status_code = 404 class BadGateway(Status): status = '502 Bad Gateway' status_code = 502 class ServiceUnavailable(Status): status = '503 Service Unavailable' status_code = 503 class GatewayTimeout(Status): status = '504 Gateway Timeout' status_code = 504 class _ConnectionError(Status): status = '999 For testing purpose only' status_code = 999 class Connection(object): _Session = None def __init__(self, api_url='', auth=None, max_retries=0, **session_args): self.api_url = api_url self.auth = auth self._max_retries = max_retries self._session_args = session_args self._session = None self._nonce = None def __repr__(self): return '' % self.api_url def __enter__(self): return self def __exit__(self, *args): self.close() def close(self): if self._session is not None: self._session.close() def exists(self, path): reply = self.request('HEAD', path, allowed=[NotFound.status_code]) return reply.status_code != NotFound.status_code def head(self, path_=None, **kwargs): from sugar_network.toolkit.router import Request, Response request = Request(method='HEAD', path=path_, **kwargs) response = Response() self.call(request, response) return response.meta def get(self, path_=None, query_=None, **kwargs): reply = self.request('GET', path_, params=query_ or kwargs) return self._decode_reply(reply) def post(self, path_=None, data_=None, query_=None, **kwargs): reply = self.request('POST', path_, json.dumps(data_), headers={'Content-Type': 'application/json'}, params=query_ or kwargs) return self._decode_reply(reply) def put(self, path_=None, data_=None, query_=None, **kwargs): reply = self.request('PUT', path_, json.dumps(data_), headers={'Content-Type': 'application/json'}, params=query_ or kwargs) return self._decode_reply(reply) def delete(self, path_=None, query_=None, **kwargs): reply = self.request('DELETE', path_, params=query_ or kwargs) return self._decode_reply(reply) def download(self, path, dst=None): reply = self.request('GET', path, allow_redirects=True) content_length = reply.headers.get('Content-Length') if content_length: chunk_size = min(int(content_length), toolkit.BUFFER_SIZE) else: chunk_size = toolkit.BUFFER_SIZE if dst is None: return reply.iter_content(chunk_size=chunk_size) f = file(dst, 'wb') if isinstance(dst, basestring) else dst try: for chunk in reply.iter_content(chunk_size=chunk_size): f.write(chunk) finally: if isinstance(dst, basestring): f.close() def upload(self, path, data, **kwargs): if isinstance(data, basestring): with file(data, 'rb') as f: reply = self.request('POST', path, f, params=kwargs) else: reply = self.request('POST', path, data, params=kwargs) if reply.headers.get('Content-Type') == 'application/json': return json.loads(reply.content) else: return reply.raw def request(self, method, path=None, data=None, headers=None, allowed=None, params=None, **kwargs): if self._session is None: self._init() if not path: path = [''] if not isinstance(path, basestring): path = '/'.join([i.strip('/') for i in [self.api_url] + path]) try_ = 0 while True: try_ += 1 reply = self._session.request(method, path, data=data, headers=headers, params=params, **kwargs) if reply.status_code == Unauthorized.status_code: enforce(self.auth is not None, Unauthorized, 'No credentials') self._authenticate(reply.headers.get('www-authenticate')) try_ = 0 elif reply.status_code == 200 or \ allowed and reply.status_code in allowed: break else: content = reply.content try: error = json.loads(content)['error'] except Exception: # On non-JSONified fail response, assume that the error # was not sent by the application level server code, i.e., # something happaned on low level, like connection abort. # If so, try to resend request. if try_ <= self._max_retries and method in ('GET', 'HEAD'): continue error = content or reply.headers.get('x-sn-error') or \ 'No error message provided' cls = _FORWARD_STATUSES.get(reply.status_code, RuntimeError) \ or ConnectionError raise cls(error) return reply def call(self, request, response=None): if request.content_type == 'application/json': request.content = json.dumps(request.content) headers = {} if request.content is not None: headers['content-type'] = \ request.content_type or 'application/octet-stream' headers['content-length'] = str(len(request.content)) elif request.content_stream is not None: headers['content-type'] = \ request.content_type or 'application/octet-stream' # TODO Avoid reading the full content at once if isinstance(request.content_stream, types.GeneratorType): request.content = ''.join([i for i in request.content_stream]) else: request.content = request.content_stream.read() headers['content-length'] = str(len(request.content)) for env_key, key in ( ('HTTP_IF_MODIFIED_SINCE', 'if-modified-since'), ('HTTP_ACCEPT_LANGUAGE', 'accept-language'), ('HTTP_ACCEPT_ENCODING', 'accept-encoding'), ): value = request.environ.get(env_key) if value is not None: headers[key] = value path = request.path while True: reply = self.request(request.method, path, data=request.content, params=request.query or request, headers=headers, allowed=_REDIRECT_CODES, allow_redirects=False) resend = reply.status_code in _REDIRECT_CODES if response is not None: if 'transfer-encoding' in reply.headers: # `requests` library handles encoding on its own del reply.headers['transfer-encoding'] for key, value in reply.headers.items(): if key.startswith('x-sn-'): response.meta[key[5:]] = json.loads(value) elif not resend: response[key] = value if resend: response.relocations += 1 if not resend: break path = reply.headers['location'] if path.startswith('/'): path = self.api_url + path if request.method != 'HEAD': if reply.headers.get('Content-Type') == 'application/json': return json.loads(reply.content) else: return reply.raw def subscribe(self, **condition): return _Subscription(self, condition) def _decode_reply(self, reply): if reply.headers.get('Content-Type') == 'application/json': return json.loads(reply.content) elif reply.headers.get('Content-Type') == 'text/event-stream': return _pull_events(reply.raw) else: return reply.content def _init(self): if Connection._Session is None: sys_path = join(dirname(dirname(__file__)), 'lib', 'requests') sys.path.insert(0, sys_path) from requests import Session, exceptions Connection._Session = Session # pylint: disable-msg=W0601 global ConnectionError ConnectionError = exceptions.ConnectionError self._session = Connection._Session() self._session.headers['accept-language'] = \ ','.join(toolkit.default_langs()) for arg, value in self._session_args.items(): setattr(self._session, arg, value) self._session.stream = True def _authenticate(self, challenge): from urllib2 import parse_http_list, parse_keqv_list nonce = None if challenge: challenge = challenge.split(' ', 1)[-1] nonce = parse_keqv_list(parse_http_list(challenge)).get('nonce') if self._nonce and nonce == self._nonce: enforce(self.auth.profile(), Unauthorized, 'Bad credentials') _logger.info('Register on the server') self.post(['user'], self.auth.profile()) self._session.headers['authorization'] = self.auth(nonce) self._nonce = nonce class SugarAuth(object): def __init__(self, key_path, profile=None): self._key_path = abspath(expanduser(key_path)) self._profile = profile or {} self._key = None self._pubkey = None self._login = None @property def pubkey(self): if self._pubkey is None: self.ensure_key() from M2Crypto.BIO import MemoryBuffer buf = MemoryBuffer() self._key.save_pub_key_bio(buf) self._pubkey = buf.getvalue() return self._pubkey @property def login(self): if self._login is None: self._login = str(hashlib.sha1(self.pubkey).hexdigest()) return self._login def profile(self): if 'name' not in self._profile: self._profile['name'] = self.login self._profile['pubkey'] = self.pubkey return self._profile def __call__(self, nonce): self.ensure_key() data = hashlib.sha1('%s:%s' % (self.login, nonce)).digest() signature = self._key.sign(data).encode('hex') return 'Sugar username="%s",nonce="%s",signature="%s"' % \ (self.login, nonce, signature) def ensure_key(self): from M2Crypto import RSA from base64 import b64encode if exists(self._key_path): self._key = RSA.load_key(self._key_path) return key_dir = dirname(self._key_path) if not exists(key_dir): os.makedirs(key_dir) os.chmod(key_dir, 0700) _logger.info('Generate RSA private key at %r', self._key_path) self._key = RSA.gen_key(1024, 65537, lambda *args: None) self._key.save_key(self._key_path, cipher=None) os.chmod(self._key_path, 0600) pub_key_path = self._key_path + '.pub' with file(pub_key_path, 'w') as f: f.write('ssh-rsa %s %s@%s' % ( b64encode('\x00\x00\x00\x07ssh-rsa%s%s' % self._key.pub()), self.login, os.uname()[1], )) _logger.info('Saved RSA public key at %r', pub_key_path) class _Subscription(object): def __init__(self, aclient, condition): self._client = aclient self._content = None self._condition = condition def __iter__(self): while True: event = self.pull() if event is not None: yield event def fileno(self): # pylint: disable-msg=W0212 return self._handshake(ping=True)._fp.fp.fileno() def pull(self): for try_ in (1, 0): stream = self._handshake() try: line = toolkit.readline(stream) enforce(line, 'Subscription aborted') break except Exception: if try_ == 0: raise toolkit.exception('Failed to read from %r subscription, ' 'will resubscribe', self._client.api_url) self._content = None return _parse_event(line) def _handshake(self, **params): if self._content is not None: return self._content params.update(self._condition) params['cmd'] = 'subscribe' _logger.debug('Subscribe to %r, %r', self._client.api_url, params) response = self._client.request('GET', params=params) self._content = response.raw return self._content def _pull_events(stream): while True: line = toolkit.readline(stream) if not line: break event = _parse_event(line) if event is not None: yield event def _parse_event(line): if line and line.startswith('data: '): try: return json.loads(line[6:]) except Exception: _logger.exception('Failed to parse %r event', line) _FORWARD_STATUSES = { BadRequest.status_code: BadRequest, Forbidden.status_code: Forbidden, NotFound.status_code: NotFound, BadGateway.status_code: BadGateway, ServiceUnavailable.status_code: ServiceUnavailable, GatewayTimeout.status_code: GatewayTimeout, _ConnectionError.status_code: None, }