diff options
Diffstat (limited to 'websdk/mercurial/wireproto.py')
-rw-r--r-- | websdk/mercurial/wireproto.py | 607 |
1 files changed, 607 insertions, 0 deletions
diff --git a/websdk/mercurial/wireproto.py b/websdk/mercurial/wireproto.py new file mode 100644 index 0000000..d189004 --- /dev/null +++ b/websdk/mercurial/wireproto.py @@ -0,0 +1,607 @@ +# wireproto.py - generic wire protocol support functions +# +# Copyright 2005-2010 Matt Mackall <mpm@selenic.com> +# +# This software may be used and distributed according to the terms of the +# GNU General Public License version 2 or any later version. + +import urllib, tempfile, os, sys +from i18n import _ +from node import bin, hex +import changegroup as changegroupmod +import repo, error, encoding, util, store + +# abstract batching support + +class future(object): + '''placeholder for a value to be set later''' + def set(self, value): + if util.safehasattr(self, 'value'): + raise error.RepoError("future is already set") + self.value = value + +class batcher(object): + '''base class for batches of commands submittable in a single request + + All methods invoked on instances of this class are simply queued and return a + a future for the result. Once you call submit(), all the queued calls are + performed and the results set in their respective futures. + ''' + def __init__(self): + self.calls = [] + def __getattr__(self, name): + def call(*args, **opts): + resref = future() + self.calls.append((name, args, opts, resref,)) + return resref + return call + def submit(self): + pass + +class localbatch(batcher): + '''performs the queued calls directly''' + def __init__(self, local): + batcher.__init__(self) + self.local = local + def submit(self): + for name, args, opts, resref in self.calls: + resref.set(getattr(self.local, name)(*args, **opts)) + +class remotebatch(batcher): + '''batches the queued calls; uses as few roundtrips as possible''' + def __init__(self, remote): + '''remote must support _submitbatch(encbatch) and _submitone(op, encargs)''' + batcher.__init__(self) + self.remote = remote + def submit(self): + req, rsp = [], [] + for name, args, opts, resref in self.calls: + mtd = getattr(self.remote, name) + batchablefn = getattr(mtd, 'batchable', None) + if batchablefn is not None: + batchable = batchablefn(mtd.im_self, *args, **opts) + encargsorres, encresref = batchable.next() + if encresref: + req.append((name, encargsorres,)) + rsp.append((batchable, encresref, resref,)) + else: + resref.set(encargsorres) + else: + if req: + self._submitreq(req, rsp) + req, rsp = [], [] + resref.set(mtd(*args, **opts)) + if req: + self._submitreq(req, rsp) + def _submitreq(self, req, rsp): + encresults = self.remote._submitbatch(req) + for encres, r in zip(encresults, rsp): + batchable, encresref, resref = r + encresref.set(encres) + resref.set(batchable.next()) + +def batchable(f): + '''annotation for batchable methods + + Such methods must implement a coroutine as follows: + + @batchable + def sample(self, one, two=None): + # Handle locally computable results first: + if not one: + yield "a local result", None + # Build list of encoded arguments suitable for your wire protocol: + encargs = [('one', encode(one),), ('two', encode(two),)] + # Create future for injection of encoded result: + encresref = future() + # Return encoded arguments and future: + yield encargs, encresref + # Assuming the future to be filled with the result from the batched request + # now. Decode it: + yield decode(encresref.value) + + The decorator returns a function which wraps this coroutine as a plain method, + but adds the original method as an attribute called "batchable", which is + used by remotebatch to split the call into separate encoding and decoding + phases. + ''' + def plain(*args, **opts): + batchable = f(*args, **opts) + encargsorres, encresref = batchable.next() + if not encresref: + return encargsorres # a local result in this case + self = args[0] + encresref.set(self._submitone(f.func_name, encargsorres)) + return batchable.next() + setattr(plain, 'batchable', f) + return plain + +# list of nodes encoding / decoding + +def decodelist(l, sep=' '): + if l: + return map(bin, l.split(sep)) + return [] + +def encodelist(l, sep=' '): + return sep.join(map(hex, l)) + +# batched call argument encoding + +def escapearg(plain): + return (plain + .replace(':', '::') + .replace(',', ':,') + .replace(';', ':;') + .replace('=', ':=')) + +def unescapearg(escaped): + return (escaped + .replace(':=', '=') + .replace(':;', ';') + .replace(':,', ',') + .replace('::', ':')) + +# client side + +def todict(**args): + return args + +class wirerepository(repo.repository): + + def batch(self): + return remotebatch(self) + def _submitbatch(self, req): + cmds = [] + for op, argsdict in req: + args = ','.join('%s=%s' % p for p in argsdict.iteritems()) + cmds.append('%s %s' % (op, args)) + rsp = self._call("batch", cmds=';'.join(cmds)) + return rsp.split(';') + def _submitone(self, op, args): + return self._call(op, **args) + + @batchable + def lookup(self, key): + self.requirecap('lookup', _('look up remote revision')) + f = future() + yield todict(key=encoding.fromlocal(key)), f + d = f.value + success, data = d[:-1].split(" ", 1) + if int(success): + yield bin(data) + self._abort(error.RepoError(data)) + + @batchable + def heads(self): + f = future() + yield {}, f + d = f.value + try: + yield decodelist(d[:-1]) + except ValueError: + self._abort(error.ResponseError(_("unexpected response:"), d)) + + @batchable + def known(self, nodes): + f = future() + yield todict(nodes=encodelist(nodes)), f + d = f.value + try: + yield [bool(int(f)) for f in d] + except ValueError: + self._abort(error.ResponseError(_("unexpected response:"), d)) + + @batchable + def branchmap(self): + f = future() + yield {}, f + d = f.value + try: + branchmap = {} + for branchpart in d.splitlines(): + branchname, branchheads = branchpart.split(' ', 1) + branchname = encoding.tolocal(urllib.unquote(branchname)) + branchheads = decodelist(branchheads) + branchmap[branchname] = branchheads + yield branchmap + except TypeError: + self._abort(error.ResponseError(_("unexpected response:"), d)) + + def branches(self, nodes): + n = encodelist(nodes) + d = self._call("branches", nodes=n) + try: + br = [tuple(decodelist(b)) for b in d.splitlines()] + return br + except ValueError: + self._abort(error.ResponseError(_("unexpected response:"), d)) + + def between(self, pairs): + batch = 8 # avoid giant requests + r = [] + for i in xrange(0, len(pairs), batch): + n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]]) + d = self._call("between", pairs=n) + try: + r.extend(l and decodelist(l) or [] for l in d.splitlines()) + except ValueError: + self._abort(error.ResponseError(_("unexpected response:"), d)) + return r + + @batchable + def pushkey(self, namespace, key, old, new): + if not self.capable('pushkey'): + yield False, None + f = future() + yield todict(namespace=encoding.fromlocal(namespace), + key=encoding.fromlocal(key), + old=encoding.fromlocal(old), + new=encoding.fromlocal(new)), f + d = f.value + try: + d = bool(int(d)) + except ValueError: + raise error.ResponseError( + _('push failed (unexpected response):'), d) + yield d + + @batchable + def listkeys(self, namespace): + if not self.capable('pushkey'): + yield {}, None + f = future() + yield todict(namespace=encoding.fromlocal(namespace)), f + d = f.value + r = {} + for l in d.splitlines(): + k, v = l.split('\t') + r[encoding.tolocal(k)] = encoding.tolocal(v) + yield r + + def stream_out(self): + return self._callstream('stream_out') + + def changegroup(self, nodes, kind): + n = encodelist(nodes) + f = self._callstream("changegroup", roots=n) + return changegroupmod.unbundle10(self._decompress(f), 'UN') + + def changegroupsubset(self, bases, heads, kind): + self.requirecap('changegroupsubset', _('look up remote changes')) + bases = encodelist(bases) + heads = encodelist(heads) + f = self._callstream("changegroupsubset", + bases=bases, heads=heads) + return changegroupmod.unbundle10(self._decompress(f), 'UN') + + def getbundle(self, source, heads=None, common=None): + self.requirecap('getbundle', _('look up remote changes')) + opts = {} + if heads is not None: + opts['heads'] = encodelist(heads) + if common is not None: + opts['common'] = encodelist(common) + f = self._callstream("getbundle", **opts) + return changegroupmod.unbundle10(self._decompress(f), 'UN') + + def unbundle(self, cg, heads, source): + '''Send cg (a readable file-like object representing the + changegroup to push, typically a chunkbuffer object) to the + remote server as a bundle. Return an integer indicating the + result of the push (see localrepository.addchangegroup()).''' + + if heads != ['force'] and self.capable('unbundlehash'): + heads = encodelist(['hashed', + util.sha1(''.join(sorted(heads))).digest()]) + else: + heads = encodelist(heads) + + ret, output = self._callpush("unbundle", cg, heads=heads) + if ret == "": + raise error.ResponseError( + _('push failed:'), output) + try: + ret = int(ret) + except ValueError: + raise error.ResponseError( + _('push failed (unexpected response):'), ret) + + for l in output.splitlines(True): + self.ui.status(_('remote: '), l) + return ret + + def debugwireargs(self, one, two, three=None, four=None, five=None): + # don't pass optional arguments left at their default value + opts = {} + if three is not None: + opts['three'] = three + if four is not None: + opts['four'] = four + return self._call('debugwireargs', one=one, two=two, **opts) + +# server side + +class streamres(object): + def __init__(self, gen): + self.gen = gen + +class pushres(object): + def __init__(self, res): + self.res = res + +class pusherr(object): + def __init__(self, res): + self.res = res + +class ooberror(object): + def __init__(self, message): + self.message = message + +def dispatch(repo, proto, command): + func, spec = commands[command] + args = proto.getargs(spec) + return func(repo, proto, *args) + +def options(cmd, keys, others): + opts = {} + for k in keys: + if k in others: + opts[k] = others[k] + del others[k] + if others: + sys.stderr.write("abort: %s got unexpected arguments %s\n" + % (cmd, ",".join(others))) + return opts + +def batch(repo, proto, cmds, others): + res = [] + for pair in cmds.split(';'): + op, args = pair.split(' ', 1) + vals = {} + for a in args.split(','): + if a: + n, v = a.split('=') + vals[n] = unescapearg(v) + func, spec = commands[op] + if spec: + keys = spec.split() + data = {} + for k in keys: + if k == '*': + star = {} + for key in vals.keys(): + if key not in keys: + star[key] = vals[key] + data['*'] = star + else: + data[k] = vals[k] + result = func(repo, proto, *[data[k] for k in keys]) + else: + result = func(repo, proto) + if isinstance(result, ooberror): + return result + res.append(escapearg(result)) + return ';'.join(res) + +def between(repo, proto, pairs): + pairs = [decodelist(p, '-') for p in pairs.split(" ")] + r = [] + for b in repo.between(pairs): + r.append(encodelist(b) + "\n") + return "".join(r) + +def branchmap(repo, proto): + branchmap = repo.branchmap() + heads = [] + for branch, nodes in branchmap.iteritems(): + branchname = urllib.quote(encoding.fromlocal(branch)) + branchnodes = encodelist(nodes) + heads.append('%s %s' % (branchname, branchnodes)) + return '\n'.join(heads) + +def branches(repo, proto, nodes): + nodes = decodelist(nodes) + r = [] + for b in repo.branches(nodes): + r.append(encodelist(b) + "\n") + return "".join(r) + +def capabilities(repo, proto): + caps = ('lookup changegroupsubset branchmap pushkey known getbundle ' + 'unbundlehash batch').split() + if _allowstream(repo.ui): + requiredformats = repo.requirements & repo.supportedformats + # if our local revlogs are just revlogv1, add 'stream' cap + if not requiredformats - set(('revlogv1',)): + caps.append('stream') + # otherwise, add 'streamreqs' detailing our local revlog format + else: + caps.append('streamreqs=%s' % ','.join(requiredformats)) + caps.append('unbundle=%s' % ','.join(changegroupmod.bundlepriority)) + caps.append('httpheader=1024') + return ' '.join(caps) + +def changegroup(repo, proto, roots): + nodes = decodelist(roots) + cg = repo.changegroup(nodes, 'serve') + return streamres(proto.groupchunks(cg)) + +def changegroupsubset(repo, proto, bases, heads): + bases = decodelist(bases) + heads = decodelist(heads) + cg = repo.changegroupsubset(bases, heads, 'serve') + return streamres(proto.groupchunks(cg)) + +def debugwireargs(repo, proto, one, two, others): + # only accept optional args from the known set + opts = options('debugwireargs', ['three', 'four'], others) + return repo.debugwireargs(one, two, **opts) + +def getbundle(repo, proto, others): + opts = options('getbundle', ['heads', 'common'], others) + for k, v in opts.iteritems(): + opts[k] = decodelist(v) + cg = repo.getbundle('serve', **opts) + return streamres(proto.groupchunks(cg)) + +def heads(repo, proto): + h = repo.heads() + return encodelist(h) + "\n" + +def hello(repo, proto): + '''the hello command returns a set of lines describing various + interesting things about the server, in an RFC822-like format. + Currently the only one defined is "capabilities", which + consists of a line in the form: + + capabilities: space separated list of tokens + ''' + return "capabilities: %s\n" % (capabilities(repo, proto)) + +def listkeys(repo, proto, namespace): + d = repo.listkeys(encoding.tolocal(namespace)).items() + t = '\n'.join(['%s\t%s' % (encoding.fromlocal(k), encoding.fromlocal(v)) + for k, v in d]) + return t + +def lookup(repo, proto, key): + try: + r = hex(repo.lookup(encoding.tolocal(key))) + success = 1 + except Exception, inst: + r = str(inst) + success = 0 + return "%s %s\n" % (success, r) + +def known(repo, proto, nodes, others): + return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes))) + +def pushkey(repo, proto, namespace, key, old, new): + # compatibility with pre-1.8 clients which were accidentally + # sending raw binary nodes rather than utf-8-encoded hex + if len(new) == 20 and new.encode('string-escape') != new: + # looks like it could be a binary node + try: + new.decode('utf-8') + new = encoding.tolocal(new) # but cleanly decodes as UTF-8 + except UnicodeDecodeError: + pass # binary, leave unmodified + else: + new = encoding.tolocal(new) # normal path + + r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key), + encoding.tolocal(old), new) + return '%s\n' % int(r) + +def _allowstream(ui): + return ui.configbool('server', 'uncompressed', True, untrusted=True) + +def stream(repo, proto): + '''If the server supports streaming clone, it advertises the "stream" + capability with a value representing the version and flags of the repo + it is serving. Client checks to see if it understands the format. + + The format is simple: the server writes out a line with the amount + of files, then the total amount of bytes to be transfered (separated + by a space). Then, for each file, the server first writes the filename + and filesize (separated by the null character), then the file contents. + ''' + + if not _allowstream(repo.ui): + return '1\n' + + entries = [] + total_bytes = 0 + try: + # get consistent snapshot of repo, lock during scan + lock = repo.lock() + try: + repo.ui.debug('scanning\n') + for name, ename, size in repo.store.walk(): + entries.append((name, size)) + total_bytes += size + finally: + lock.release() + except error.LockError: + return '2\n' # error: 2 + + def streamer(repo, entries, total): + '''stream out all metadata files in repository.''' + yield '0\n' # success + repo.ui.debug('%d files, %d bytes to transfer\n' % + (len(entries), total_bytes)) + yield '%d %d\n' % (len(entries), total_bytes) + for name, size in entries: + repo.ui.debug('sending %s (%d bytes)\n' % (name, size)) + # partially encode name over the wire for backwards compat + yield '%s\0%d\n' % (store.encodedir(name), size) + for chunk in util.filechunkiter(repo.sopener(name), limit=size): + yield chunk + + return streamres(streamer(repo, entries, total_bytes)) + +def unbundle(repo, proto, heads): + their_heads = decodelist(heads) + + def check_heads(): + heads = repo.heads() + heads_hash = util.sha1(''.join(sorted(heads))).digest() + return (their_heads == ['force'] or their_heads == heads or + their_heads == ['hashed', heads_hash]) + + proto.redirect() + + # fail early if possible + if not check_heads(): + return pusherr('unsynced changes') + + # write bundle data to temporary file because it can be big + fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-') + fp = os.fdopen(fd, 'wb+') + r = 0 + try: + proto.getfile(fp) + lock = repo.lock() + try: + if not check_heads(): + # someone else committed/pushed/unbundled while we + # were transferring data + return pusherr('unsynced changes') + + # push can proceed + fp.seek(0) + gen = changegroupmod.readbundle(fp, None) + + try: + r = repo.addchangegroup(gen, 'serve', proto._client(), + lock=lock) + except util.Abort, inst: + sys.stderr.write("abort: %s\n" % inst) + finally: + lock.release() + return pushres(r) + + finally: + fp.close() + os.unlink(tempname) + +commands = { + 'batch': (batch, 'cmds *'), + 'between': (between, 'pairs'), + 'branchmap': (branchmap, ''), + 'branches': (branches, 'nodes'), + 'capabilities': (capabilities, ''), + 'changegroup': (changegroup, 'roots'), + 'changegroupsubset': (changegroupsubset, 'bases heads'), + 'debugwireargs': (debugwireargs, 'one two *'), + 'getbundle': (getbundle, '*'), + 'heads': (heads, ''), + 'hello': (hello, ''), + 'known': (known, 'nodes *'), + 'listkeys': (listkeys, 'namespace'), + 'lookup': (lookup, 'key'), + 'pushkey': (pushkey, 'namespace key old new'), + 'stream_out': (stream, ''), + 'unbundle': (unbundle, 'heads'), +} |