# Copyright (C) 2012 Aleksey Lim
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import cgi
import json
import time
import types
import logging
import calendar
from email.utils import parsedate, formatdate
from urlparse import parse_qsl, urlsplit
from bisect import bisect_left
from os.path import join, isfile

import active_document as ad
from sugar_network import static
from sugar_network.resources.volume import Request
from active_toolkit.sockets import BUFFER_SIZE
from active_toolkit import coroutine, util, enforce
# Module-level logger; messages from this file go to the 'router' channel.
_logger = logging.getLogger('router')
class HTTPStatus(Exception):
    """Exception that carries an explicit HTTP response.

    Subclasses override the class attributes below; `Router.__call__`
    copies them into the response when such an exception is caught.
    """

    # Body to send back; `None` lets the router build a JSON error body
    result = None
    # Optional mapping of extra response headers
    headers = None
    # Full status line, e.g. '400 Bad Request'
    status = None
class BadRequest(HTTPStatus):
    """HTTP 400: the request could not be understood."""

    status = '400 Bad Request'
class Unauthorized(HTTPStatus):
    """HTTP 401: the caller has to authenticate as a known Sugar user."""

    # Sent with the 401 so clients know which auth scheme is expected
    headers = {'WWW-Authenticate': 'Sugar'}
    status = '401 Unauthorized'
def route(method, path):
    """Decorator factory that tags a callable as the handler of a route.

    The decorated callable gets a `route` attribute holding the
    `(method, top_level_path)` tuple; `Router.__init__` looks for that
    attribute to build its routing table.

    :param method:
        HTTP method name the handler serves
    :param path:
        URL path; surrounding slashes are ignored and exactly one path
        segment is accepted
    :returns:
        decorator that returns the same callable it was given

    """
    segments = path.strip('/').split('/')
    # Only top level paths for now
    enforce(len(segments) == 1)
    key = (method, segments[0])

    def decorate(func):
        func.route = key
        return func

    return decorate
class Router(object):
    """WSGI application that maps HTTP requests onto commands.

    Requests are served, in this order of precedence, from:

    * files below `static.PATH` for `/static/...` urls;
    * methods of this class (or subclasses) decorated with `route`;
    * the `commands` object passed to the constructor.

    """

    def __init__(self, commands):
        """
        :param commands:
            object whose `call(request, response)` method processes
            requests that no `route` handler claims

        """
        self.commands = commands
        # Users already confirmed to exist
        self._authenticated = set()
        # Caches of `Origin` values that were accepted or rejected
        self._valid_origins = set()
        self._invalid_origins = set()
        # Resolved address of this server, filled in lazily
        self._host = None
        self._routes = {}

        # `dir()` on a class already lists inherited names, so a single
        # pass is enough to pick up handlers declared in base classes
        for name in dir(self.__class__):
            attr = getattr(self, name)
            if hasattr(attr, 'route'):
                self._routes[attr.route] = attr

        # Otherwise ssh-keygen will popup auth dialogs on registeration
        os.environ.pop('SSH_ASKPASS', None)

    def authenticate(self, request):
        """Return the user named by the request, verifying it exists.

        The user comes from the `Sugar-User` HTTP header. Each user is
        checked once with an `exists` command and then remembered. The
        check is skipped for `POST /user` requests.

        :returns:
            the user identifier or `None` if the header is absent
        :raises Unauthorized:
            if the `exists` command returns a false value

        """
        user = request.environ.get('HTTP_SUGAR_USER')
        if user is None:
            return None

        is_registration = \
                request.path == ['user'] and request['method'] == 'POST'
        if user not in self._authenticated and not is_registration:
            _logger.debug('Logging %r user', user)
            exists_request = Request(method='GET', cmd='exists',
                    document='user', guid=user)
            enforce(self.commands.call(exists_request, ad.Response()),
                    Unauthorized, 'Principal user does not exist')
            self._authenticated.add(user)

        return user

    def call(self, request, response):
        """Process a request and return its not yet serialized result.

        :returns:
            `None`, a generator for stream-like results (anything with
            a `read` attribute is wrapped into a chunk generator), or
            whatever the handler returned
        :raises ad.Forbidden:
            if the `Origin` header names a disallowed origin
        :raises ad.NotFound:
            if a requested static file does not exist
        :raises ad.NotModified:
            if a static file is not newer than `If-Modified-Since`

        """
        environ = request.environ
        # The header is optional, it must never be accessed by index
        origin = environ.get('HTTP_ORIGIN')

        if 'HTTP_ORIGIN' in environ:
            enforce(self._assert_origin(environ), ad.Forbidden,
                    'Cross-site is not allowed for %r origin', origin)
            response['Access-Control-Allow-Origin'] = origin

        if request['method'] == 'OPTIONS':
            # TODO Process OPTIONS request per url?
            if origin:
                # Echo back only the preflight headers that were sent
                allow_methods = \
                        environ.get('HTTP_ACCESS_CONTROL_REQUEST_METHOD')
                if allow_methods is not None:
                    response['Access-Control-Allow-Methods'] = allow_methods
                allow_headers = \
                        environ.get('HTTP_ACCESS_CONTROL_REQUEST_HEADERS')
                if allow_headers is not None:
                    response['Access-Control-Allow-Headers'] = allow_headers
            else:
                response['Allow'] = 'GET, POST, PUT, DELETE'
            response.content_length = 0
            return None

        request.principal = self.authenticate(request)

        if request.path[:1] == ['static']:
            result = self._open_static(request, response)
        else:
            rout = None
            if request.path:
                rout = self._routes.get((request['method'], request.path[0]))
            if rout:
                result = rout(request, response)
            else:
                result = self.commands.call(request, response)

        if hasattr(result, 'read'):
            # pylint: disable-msg=E1103
            if hasattr(result, 'fileno'):
                response.content_length = os.fstat(result.fileno()).st_size
            elif hasattr(result, 'seek'):
                # No file descriptor, measure the stream by seeking
                result.seek(0, 2)
                response.content_length = result.tell()
                result.seek(0)
            result = _stream_reader(result)

        return result

    def __call__(self, environ, start_response):
        """WSGI entry point.

        Exceptions raised while processing are converted to HTTP
        statuses; unless the exception supplies its own body, a JSON
        object with `error` and `request` keys is sent. If the query
        contains a `callback` argument, the result is wrapped into a
        JSONP call of that name.

        """
        request = _Request(environ)
        # Building the representation is only worth it for debug logging
        request_repr = None
        if _logger.isEnabledFor(logging.DEBUG):
            request_repr = str(request)
        response = _Response()

        js_callback = None
        if 'callback' in request:
            js_callback = request.pop('callback')

        result = None
        try:
            result = self.call(request, response)
        except ad.Redirect as error:
            response.status = '303 See Other'
            response['Location'] = error.location
            response.content_type = None
        except ad.NotModified:
            response.status = '304 Not Modified'
            response.content_type = None
        except Exception as error:
            util.exception('Error while processing %r request', request.url)

            if isinstance(error, ad.NotFound):
                response.status = '404 Not Found'
            elif isinstance(error, ad.Forbidden):
                response.status = '403 Forbidden'
            elif isinstance(error, HTTPStatus):
                response.status = error.status
                response.update(error.headers or {})
                result = error.result
            else:
                response.status = '500 Internal Server Error'

            if result is None:
                result = {'error': str(error),
                          'request': request.url,
                          }
                response.content_type = 'application/json'

        result_streamed = isinstance(result, types.GeneratorType)

        if js_callback:
            if result_streamed:
                # JSONP needs the whole payload at once
                result = ''.join(result)
                result_streamed = False
            result = '%s(%s);' % (js_callback, json.dumps(result))
            response.content_length = len(result)
        elif not result_streamed and \
                response.content_type == 'application/json':
            result = json.dumps(result)
            response.content_length = len(result)

        _logger.debug('Called %s: response=%r result=%r streamed=%r',
                request_repr, response, result, result_streamed)

        # WSGI requires the headers to be a real list, not a generator
        start_response(response.status, list(response.items()))

        if result_streamed:
            for i in result:
                yield i
        elif result is not None:
            yield result

    def _open_static(self, request, response):
        """Open the file a `/static/...` request refers to."""
        parts = request.path[1:]
        # Parent directory references would escape `static.PATH`
        enforce(os.pardir not in parts, ad.NotFound, 'No such file')
        static_path = join(static.PATH, *parts)
        enforce(isfile(static_path), ad.NotFound, 'No such file')

        mtime = os.stat(static_path).st_mtime
        # HTTP dates have one second resolution, so drop the fraction
        # before comparing, otherwise the file always looks modified
        if request.if_modified_since and \
                int(mtime) <= request.if_modified_since:
            raise ad.NotModified()
        response.last_modified = mtime

        return open(static_path, 'rb')

    def _assert_origin(self, environ):
        """Check if the request's `Origin` may do cross-site requests.

        `null` and `file://` origins are always accepted. Any other
        origin is accepted only if its host resolves to the same
        address as the `Host` header of the first checked request.
        Verdicts are cached per origin string.

        """
        origin = environ['HTTP_ORIGIN']
        if origin in self._valid_origins:
            return True
        if origin in self._invalid_origins:
            return False

        if origin == 'null' or origin.startswith('file://'):
            # True all time for local apps
            valid = True
        else:
            hostname = urlsplit(origin).hostname
            if hostname is None:
                # Nothing to resolve, e.g., a malformed header value
                valid = False
            else:
                if self._host is None:
                    http_host = environ['HTTP_HOST'].split(':', 1)[0]
                    self._host = coroutine.gethostbyname(http_host)
                valid = (self._host == coroutine.gethostbyname(hostname))

        if valid:
            _logger.info('Allow cross-site for %r origin', origin)
            self._valid_origins.add(origin)
        else:
            _logger.info('Disallow cross-site for %r origin', origin)
            self._invalid_origins.add(origin)

        return valid
class _Request(Request):
    """Request populated from a WSGI environ.

    Query arguments become request items (repeated arguments are
    collected into lists) and up to three path segments are stored as
    the `document`, `guid` and `prop` items.
    """

    environ = None
    url = None
    path = None
    principal = None

    def __init__(self, environ=None):
        """
        :param environ:
            WSGI environ; if omitted, the request is left empty
        :raises BadRequest:
            if the url path has more than three segments

        """
        Request.__init__(self)

        if not environ:
            return

        self.access_level = ad.ACCESS_REMOTE
        self.environ = environ
        self.url = '/' + environ['PATH_INFO'].strip('/')
        self.path = [i for i in self.url[1:].split('/') if i]
        self['method'] = environ['REQUEST_METHOD']
        self.content = None
        self.content_stream = environ.get('wsgi.input')
        self.content_length = 0
        self.accept_language = _parse_accept_language(
                environ.get('HTTP_ACCEPT_LANGUAGE'))
        self.principal = None

        query = environ.get('QUERY_STRING') or ''
        for attr, value in parse_qsl(query):
            param = self.get(attr)
            if isinstance(param, list):
                param.append(value)
            elif param is not None:
                # Second occurrence of the argument, switch to a list
                self[str(attr)] = [param, value]
            else:
                self[str(attr)] = value
        if query:
            self.url += '?' + query

        content_length = environ.get('CONTENT_LENGTH')
        if content_length:
            self.content_length = int(content_length)
            ctype, __ = cgi.parse_header(environ.get('CONTENT_TYPE', ''))
            ctype = ctype.lower()
            if ctype == 'application/json':
                content = self.read()
                if content:
                    self.content = json.loads(content)
            elif ctype == 'multipart/form-data':
                files = cgi.FieldStorage(fp=environ['wsgi.input'],
                        environ=environ)
                enforce(len(files.list) == 1,
                        'Multipart request should contain only one file')
                self.content_stream = files.list[0].file

        if_modified_since = environ.get('HTTP_IF_MODIFIED_SINCE')
        if if_modified_since:
            if_modified_since = parsedate(if_modified_since)
            enforce(if_modified_since is not None,
                    'Failed to parse If-Modified-Since')
            # HTTP dates are in GMT; `time.mktime()` would treat the
            # tuple as local time and skew the value by the UTC offset
            self.if_modified_since = calendar.timegm(if_modified_since)

        scope = len(self.path)
        enforce(scope >= 0 and scope < 4, BadRequest,
                'Incorrect requested path')
        if scope == 3:
            self['document'], self['guid'], self['prop'] = self.path
        elif scope == 2:
            self['document'], self['guid'] = self.path
        elif scope == 1:
            self['document'], = self.path

    def clone(self):
        """Return a copy that also carries the HTTP specific attributes."""
        request = Request.clone(self)
        request.environ = self.environ
        request.url = self.url
        request.path = self.path
        request.principal = self.principal
        return request
class _Response(ad.Response):
    """Response whose items are the HTTP headers to send.

    The properties are shortcuts to the corresponding header items.
    """
    # pylint: disable-msg=E0202

    status = '200 OK'

    @property
    def content_length(self):
        return self.get('Content-Length')

    @content_length.setter
    def content_length(self, value):
        self['Content-Length'] = value

    @property
    def content_type(self):
        return self.get('Content-Type')

    @content_type.setter
    def content_type(self, value):
        # Assigning a false value removes the header
        if value:
            self['Content-Type'] = value
        elif 'Content-Type' in self:
            del self['Content-Type']

    @property
    def last_modified(self):
        return self.get('Last-Modified')

    @last_modified.setter
    def last_modified(self, value):
        # `value` is seconds since the epoch; stored as an HTTP date
        self['Last-Modified'] = formatdate(value, localtime=False, usegmt=True)

    def items(self):
        """Yield `(header, value)` pairs with values converted to strings.

        A list or tuple value produces one pair per element, i.e., a
        repeated header.

        """
        for key, value in dict.items(self):
            if isinstance(value, (list, tuple)):
                for i in value:
                    yield key, str(i)
            else:
                yield key, str(value)

    def __repr__(self):
        args = ['status=%r' % self.status,
                ] + ['%s=%r' % i for i in self.items()]
        return '<Response %s>' % ' '.join(args)
def _parse_accept_language(accept_language):
if not accept_language:
return []
langs = []
qualities = []
for chunk in accept_language.split(','):
lang, params = (chunk.split(';', 1) + [None])[:2]
lang = lang.strip()
if not lang:
continue
quality = 1
if params:
params = params.split('=', 1)
if len(params) > 1 and params[0].strip() == 'q':
quality = float(params[1])
index = bisect_left(qualities, quality)
qualities.insert(index, quality)
langs.insert(len(langs) - index, lang)
return langs
def _stream_reader(stream):
    """Yield the content of `stream` in chunks of up to `BUFFER_SIZE`.

    The stream is closed when it is exhausted, and also if reading
    fails or the generator is closed before the end.

    """
    try:
        chunk = stream.read(BUFFER_SIZE)
        while chunk:
            yield chunk
            chunk = stream.read(BUFFER_SIZE)
    finally:
        stream.close()