Web   ·   Wiki   ·   Activities   ·   Blog   ·   Lists   ·   Chat   ·   Meeting   ·   Bugs   ·   Git   ·   Translate   ·   Archive   ·   People   ·   Donate
summaryrefslogtreecommitdiffstats
path: root/websdk/mercurial/wireproto.py
diff options
context:
space:
mode:
Diffstat (limited to 'websdk/mercurial/wireproto.py')
-rw-r--r--websdk/mercurial/wireproto.py607
1 files changed, 607 insertions, 0 deletions
diff --git a/websdk/mercurial/wireproto.py b/websdk/mercurial/wireproto.py
new file mode 100644
index 0000000..d189004
--- /dev/null
+++ b/websdk/mercurial/wireproto.py
@@ -0,0 +1,607 @@
+# wireproto.py - generic wire protocol support functions
+#
+# Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
+#
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+
+import urllib, tempfile, os, sys
+from i18n import _
+from node import bin, hex
+import changegroup as changegroupmod
+import repo, error, encoding, util, store
+
+# abstract batching support
+
+class future(object):
+ '''placeholder for a value to be set later'''
+ def set(self, value):
+ if util.safehasattr(self, 'value'):
+ raise error.RepoError("future is already set")
+ self.value = value
+
+class batcher(object):
+ '''base class for batches of commands submittable in a single request
+
+ All methods invoked on instances of this class are simply queued and return a
+ a future for the result. Once you call submit(), all the queued calls are
+ performed and the results set in their respective futures.
+ '''
+ def __init__(self):
+ self.calls = []
+ def __getattr__(self, name):
+ def call(*args, **opts):
+ resref = future()
+ self.calls.append((name, args, opts, resref,))
+ return resref
+ return call
+ def submit(self):
+ pass
+
+class localbatch(batcher):
+ '''performs the queued calls directly'''
+ def __init__(self, local):
+ batcher.__init__(self)
+ self.local = local
+ def submit(self):
+ for name, args, opts, resref in self.calls:
+ resref.set(getattr(self.local, name)(*args, **opts))
+
+class remotebatch(batcher):
+ '''batches the queued calls; uses as few roundtrips as possible'''
+ def __init__(self, remote):
+ '''remote must support _submitbatch(encbatch) and _submitone(op, encargs)'''
+ batcher.__init__(self)
+ self.remote = remote
+ def submit(self):
+ req, rsp = [], []
+ for name, args, opts, resref in self.calls:
+ mtd = getattr(self.remote, name)
+ batchablefn = getattr(mtd, 'batchable', None)
+ if batchablefn is not None:
+ batchable = batchablefn(mtd.im_self, *args, **opts)
+ encargsorres, encresref = batchable.next()
+ if encresref:
+ req.append((name, encargsorres,))
+ rsp.append((batchable, encresref, resref,))
+ else:
+ resref.set(encargsorres)
+ else:
+ if req:
+ self._submitreq(req, rsp)
+ req, rsp = [], []
+ resref.set(mtd(*args, **opts))
+ if req:
+ self._submitreq(req, rsp)
+ def _submitreq(self, req, rsp):
+ encresults = self.remote._submitbatch(req)
+ for encres, r in zip(encresults, rsp):
+ batchable, encresref, resref = r
+ encresref.set(encres)
+ resref.set(batchable.next())
+
+def batchable(f):
+ '''annotation for batchable methods
+
+ Such methods must implement a coroutine as follows:
+
+ @batchable
+ def sample(self, one, two=None):
+ # Handle locally computable results first:
+ if not one:
+ yield "a local result", None
+ # Build list of encoded arguments suitable for your wire protocol:
+ encargs = [('one', encode(one),), ('two', encode(two),)]
+ # Create future for injection of encoded result:
+ encresref = future()
+ # Return encoded arguments and future:
+ yield encargs, encresref
+ # Assuming the future to be filled with the result from the batched request
+ # now. Decode it:
+ yield decode(encresref.value)
+
+ The decorator returns a function which wraps this coroutine as a plain method,
+ but adds the original method as an attribute called "batchable", which is
+ used by remotebatch to split the call into separate encoding and decoding
+ phases.
+ '''
+ def plain(*args, **opts):
+ batchable = f(*args, **opts)
+ encargsorres, encresref = batchable.next()
+ if not encresref:
+ return encargsorres # a local result in this case
+ self = args[0]
+ encresref.set(self._submitone(f.func_name, encargsorres))
+ return batchable.next()
+ setattr(plain, 'batchable', f)
+ return plain
+
+# list of nodes encoding / decoding
+
+def decodelist(l, sep=' '):
+ if l:
+ return map(bin, l.split(sep))
+ return []
+
+def encodelist(l, sep=' '):
+ return sep.join(map(hex, l))
+
+# batched call argument encoding
+
+def escapearg(plain):
+ return (plain
+ .replace(':', '::')
+ .replace(',', ':,')
+ .replace(';', ':;')
+ .replace('=', ':='))
+
+def unescapearg(escaped):
+ return (escaped
+ .replace(':=', '=')
+ .replace(':;', ';')
+ .replace(':,', ',')
+ .replace('::', ':'))
+
+# client side
+
+def todict(**args):
+ return args
+
+class wirerepository(repo.repository):
+
+ def batch(self):
+ return remotebatch(self)
+ def _submitbatch(self, req):
+ cmds = []
+ for op, argsdict in req:
+ args = ','.join('%s=%s' % p for p in argsdict.iteritems())
+ cmds.append('%s %s' % (op, args))
+ rsp = self._call("batch", cmds=';'.join(cmds))
+ return rsp.split(';')
+ def _submitone(self, op, args):
+ return self._call(op, **args)
+
+ @batchable
+ def lookup(self, key):
+ self.requirecap('lookup', _('look up remote revision'))
+ f = future()
+ yield todict(key=encoding.fromlocal(key)), f
+ d = f.value
+ success, data = d[:-1].split(" ", 1)
+ if int(success):
+ yield bin(data)
+ self._abort(error.RepoError(data))
+
+ @batchable
+ def heads(self):
+ f = future()
+ yield {}, f
+ d = f.value
+ try:
+ yield decodelist(d[:-1])
+ except ValueError:
+ self._abort(error.ResponseError(_("unexpected response:"), d))
+
+ @batchable
+ def known(self, nodes):
+ f = future()
+ yield todict(nodes=encodelist(nodes)), f
+ d = f.value
+ try:
+ yield [bool(int(f)) for f in d]
+ except ValueError:
+ self._abort(error.ResponseError(_("unexpected response:"), d))
+
+ @batchable
+ def branchmap(self):
+ f = future()
+ yield {}, f
+ d = f.value
+ try:
+ branchmap = {}
+ for branchpart in d.splitlines():
+ branchname, branchheads = branchpart.split(' ', 1)
+ branchname = encoding.tolocal(urllib.unquote(branchname))
+ branchheads = decodelist(branchheads)
+ branchmap[branchname] = branchheads
+ yield branchmap
+ except TypeError:
+ self._abort(error.ResponseError(_("unexpected response:"), d))
+
+ def branches(self, nodes):
+ n = encodelist(nodes)
+ d = self._call("branches", nodes=n)
+ try:
+ br = [tuple(decodelist(b)) for b in d.splitlines()]
+ return br
+ except ValueError:
+ self._abort(error.ResponseError(_("unexpected response:"), d))
+
+ def between(self, pairs):
+ batch = 8 # avoid giant requests
+ r = []
+ for i in xrange(0, len(pairs), batch):
+ n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
+ d = self._call("between", pairs=n)
+ try:
+ r.extend(l and decodelist(l) or [] for l in d.splitlines())
+ except ValueError:
+ self._abort(error.ResponseError(_("unexpected response:"), d))
+ return r
+
+ @batchable
+ def pushkey(self, namespace, key, old, new):
+ if not self.capable('pushkey'):
+ yield False, None
+ f = future()
+ yield todict(namespace=encoding.fromlocal(namespace),
+ key=encoding.fromlocal(key),
+ old=encoding.fromlocal(old),
+ new=encoding.fromlocal(new)), f
+ d = f.value
+ try:
+ d = bool(int(d))
+ except ValueError:
+ raise error.ResponseError(
+ _('push failed (unexpected response):'), d)
+ yield d
+
+ @batchable
+ def listkeys(self, namespace):
+ if not self.capable('pushkey'):
+ yield {}, None
+ f = future()
+ yield todict(namespace=encoding.fromlocal(namespace)), f
+ d = f.value
+ r = {}
+ for l in d.splitlines():
+ k, v = l.split('\t')
+ r[encoding.tolocal(k)] = encoding.tolocal(v)
+ yield r
+
+ def stream_out(self):
+ return self._callstream('stream_out')
+
+ def changegroup(self, nodes, kind):
+ n = encodelist(nodes)
+ f = self._callstream("changegroup", roots=n)
+ return changegroupmod.unbundle10(self._decompress(f), 'UN')
+
+ def changegroupsubset(self, bases, heads, kind):
+ self.requirecap('changegroupsubset', _('look up remote changes'))
+ bases = encodelist(bases)
+ heads = encodelist(heads)
+ f = self._callstream("changegroupsubset",
+ bases=bases, heads=heads)
+ return changegroupmod.unbundle10(self._decompress(f), 'UN')
+
+ def getbundle(self, source, heads=None, common=None):
+ self.requirecap('getbundle', _('look up remote changes'))
+ opts = {}
+ if heads is not None:
+ opts['heads'] = encodelist(heads)
+ if common is not None:
+ opts['common'] = encodelist(common)
+ f = self._callstream("getbundle", **opts)
+ return changegroupmod.unbundle10(self._decompress(f), 'UN')
+
+ def unbundle(self, cg, heads, source):
+ '''Send cg (a readable file-like object representing the
+ changegroup to push, typically a chunkbuffer object) to the
+ remote server as a bundle. Return an integer indicating the
+ result of the push (see localrepository.addchangegroup()).'''
+
+ if heads != ['force'] and self.capable('unbundlehash'):
+ heads = encodelist(['hashed',
+ util.sha1(''.join(sorted(heads))).digest()])
+ else:
+ heads = encodelist(heads)
+
+ ret, output = self._callpush("unbundle", cg, heads=heads)
+ if ret == "":
+ raise error.ResponseError(
+ _('push failed:'), output)
+ try:
+ ret = int(ret)
+ except ValueError:
+ raise error.ResponseError(
+ _('push failed (unexpected response):'), ret)
+
+ for l in output.splitlines(True):
+ self.ui.status(_('remote: '), l)
+ return ret
+
+ def debugwireargs(self, one, two, three=None, four=None, five=None):
+ # don't pass optional arguments left at their default value
+ opts = {}
+ if three is not None:
+ opts['three'] = three
+ if four is not None:
+ opts['four'] = four
+ return self._call('debugwireargs', one=one, two=two, **opts)
+
+# server side
+
+class streamres(object):
+ def __init__(self, gen):
+ self.gen = gen
+
+class pushres(object):
+ def __init__(self, res):
+ self.res = res
+
+class pusherr(object):
+ def __init__(self, res):
+ self.res = res
+
+class ooberror(object):
+ def __init__(self, message):
+ self.message = message
+
+def dispatch(repo, proto, command):
+ func, spec = commands[command]
+ args = proto.getargs(spec)
+ return func(repo, proto, *args)
+
+def options(cmd, keys, others):
+ opts = {}
+ for k in keys:
+ if k in others:
+ opts[k] = others[k]
+ del others[k]
+ if others:
+ sys.stderr.write("abort: %s got unexpected arguments %s\n"
+ % (cmd, ",".join(others)))
+ return opts
+
+def batch(repo, proto, cmds, others):
+ res = []
+ for pair in cmds.split(';'):
+ op, args = pair.split(' ', 1)
+ vals = {}
+ for a in args.split(','):
+ if a:
+ n, v = a.split('=')
+ vals[n] = unescapearg(v)
+ func, spec = commands[op]
+ if spec:
+ keys = spec.split()
+ data = {}
+ for k in keys:
+ if k == '*':
+ star = {}
+ for key in vals.keys():
+ if key not in keys:
+ star[key] = vals[key]
+ data['*'] = star
+ else:
+ data[k] = vals[k]
+ result = func(repo, proto, *[data[k] for k in keys])
+ else:
+ result = func(repo, proto)
+ if isinstance(result, ooberror):
+ return result
+ res.append(escapearg(result))
+ return ';'.join(res)
+
+def between(repo, proto, pairs):
+ pairs = [decodelist(p, '-') for p in pairs.split(" ")]
+ r = []
+ for b in repo.between(pairs):
+ r.append(encodelist(b) + "\n")
+ return "".join(r)
+
+def branchmap(repo, proto):
+ branchmap = repo.branchmap()
+ heads = []
+ for branch, nodes in branchmap.iteritems():
+ branchname = urllib.quote(encoding.fromlocal(branch))
+ branchnodes = encodelist(nodes)
+ heads.append('%s %s' % (branchname, branchnodes))
+ return '\n'.join(heads)
+
+def branches(repo, proto, nodes):
+ nodes = decodelist(nodes)
+ r = []
+ for b in repo.branches(nodes):
+ r.append(encodelist(b) + "\n")
+ return "".join(r)
+
+def capabilities(repo, proto):
+ caps = ('lookup changegroupsubset branchmap pushkey known getbundle '
+ 'unbundlehash batch').split()
+ if _allowstream(repo.ui):
+ requiredformats = repo.requirements & repo.supportedformats
+ # if our local revlogs are just revlogv1, add 'stream' cap
+ if not requiredformats - set(('revlogv1',)):
+ caps.append('stream')
+ # otherwise, add 'streamreqs' detailing our local revlog format
+ else:
+ caps.append('streamreqs=%s' % ','.join(requiredformats))
+ caps.append('unbundle=%s' % ','.join(changegroupmod.bundlepriority))
+ caps.append('httpheader=1024')
+ return ' '.join(caps)
+
+def changegroup(repo, proto, roots):
+ nodes = decodelist(roots)
+ cg = repo.changegroup(nodes, 'serve')
+ return streamres(proto.groupchunks(cg))
+
+def changegroupsubset(repo, proto, bases, heads):
+ bases = decodelist(bases)
+ heads = decodelist(heads)
+ cg = repo.changegroupsubset(bases, heads, 'serve')
+ return streamres(proto.groupchunks(cg))
+
+def debugwireargs(repo, proto, one, two, others):
+ # only accept optional args from the known set
+ opts = options('debugwireargs', ['three', 'four'], others)
+ return repo.debugwireargs(one, two, **opts)
+
+def getbundle(repo, proto, others):
+ opts = options('getbundle', ['heads', 'common'], others)
+ for k, v in opts.iteritems():
+ opts[k] = decodelist(v)
+ cg = repo.getbundle('serve', **opts)
+ return streamres(proto.groupchunks(cg))
+
+def heads(repo, proto):
+ h = repo.heads()
+ return encodelist(h) + "\n"
+
+def hello(repo, proto):
+ '''the hello command returns a set of lines describing various
+ interesting things about the server, in an RFC822-like format.
+ Currently the only one defined is "capabilities", which
+ consists of a line in the form:
+
+ capabilities: space separated list of tokens
+ '''
+ return "capabilities: %s\n" % (capabilities(repo, proto))
+
+def listkeys(repo, proto, namespace):
+ d = repo.listkeys(encoding.tolocal(namespace)).items()
+ t = '\n'.join(['%s\t%s' % (encoding.fromlocal(k), encoding.fromlocal(v))
+ for k, v in d])
+ return t
+
+def lookup(repo, proto, key):
+ try:
+ r = hex(repo.lookup(encoding.tolocal(key)))
+ success = 1
+ except Exception, inst:
+ r = str(inst)
+ success = 0
+ return "%s %s\n" % (success, r)
+
+def known(repo, proto, nodes, others):
+ return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
+
+def pushkey(repo, proto, namespace, key, old, new):
+ # compatibility with pre-1.8 clients which were accidentally
+ # sending raw binary nodes rather than utf-8-encoded hex
+ if len(new) == 20 and new.encode('string-escape') != new:
+ # looks like it could be a binary node
+ try:
+ new.decode('utf-8')
+ new = encoding.tolocal(new) # but cleanly decodes as UTF-8
+ except UnicodeDecodeError:
+ pass # binary, leave unmodified
+ else:
+ new = encoding.tolocal(new) # normal path
+
+ r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
+ encoding.tolocal(old), new)
+ return '%s\n' % int(r)
+
+def _allowstream(ui):
+ return ui.configbool('server', 'uncompressed', True, untrusted=True)
+
+def stream(repo, proto):
+ '''If the server supports streaming clone, it advertises the "stream"
+ capability with a value representing the version and flags of the repo
+ it is serving. Client checks to see if it understands the format.
+
+ The format is simple: the server writes out a line with the amount
+ of files, then the total amount of bytes to be transfered (separated
+ by a space). Then, for each file, the server first writes the filename
+ and filesize (separated by the null character), then the file contents.
+ '''
+
+ if not _allowstream(repo.ui):
+ return '1\n'
+
+ entries = []
+ total_bytes = 0
+ try:
+ # get consistent snapshot of repo, lock during scan
+ lock = repo.lock()
+ try:
+ repo.ui.debug('scanning\n')
+ for name, ename, size in repo.store.walk():
+ entries.append((name, size))
+ total_bytes += size
+ finally:
+ lock.release()
+ except error.LockError:
+ return '2\n' # error: 2
+
+ def streamer(repo, entries, total):
+ '''stream out all metadata files in repository.'''
+ yield '0\n' # success
+ repo.ui.debug('%d files, %d bytes to transfer\n' %
+ (len(entries), total_bytes))
+ yield '%d %d\n' % (len(entries), total_bytes)
+ for name, size in entries:
+ repo.ui.debug('sending %s (%d bytes)\n' % (name, size))
+ # partially encode name over the wire for backwards compat
+ yield '%s\0%d\n' % (store.encodedir(name), size)
+ for chunk in util.filechunkiter(repo.sopener(name), limit=size):
+ yield chunk
+
+ return streamres(streamer(repo, entries, total_bytes))
+
+def unbundle(repo, proto, heads):
+ their_heads = decodelist(heads)
+
+ def check_heads():
+ heads = repo.heads()
+ heads_hash = util.sha1(''.join(sorted(heads))).digest()
+ return (their_heads == ['force'] or their_heads == heads or
+ their_heads == ['hashed', heads_hash])
+
+ proto.redirect()
+
+ # fail early if possible
+ if not check_heads():
+ return pusherr('unsynced changes')
+
+ # write bundle data to temporary file because it can be big
+ fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
+ fp = os.fdopen(fd, 'wb+')
+ r = 0
+ try:
+ proto.getfile(fp)
+ lock = repo.lock()
+ try:
+ if not check_heads():
+ # someone else committed/pushed/unbundled while we
+ # were transferring data
+ return pusherr('unsynced changes')
+
+ # push can proceed
+ fp.seek(0)
+ gen = changegroupmod.readbundle(fp, None)
+
+ try:
+ r = repo.addchangegroup(gen, 'serve', proto._client(),
+ lock=lock)
+ except util.Abort, inst:
+ sys.stderr.write("abort: %s\n" % inst)
+ finally:
+ lock.release()
+ return pushres(r)
+
+ finally:
+ fp.close()
+ os.unlink(tempname)
+
+commands = {
+ 'batch': (batch, 'cmds *'),
+ 'between': (between, 'pairs'),
+ 'branchmap': (branchmap, ''),
+ 'branches': (branches, 'nodes'),
+ 'capabilities': (capabilities, ''),
+ 'changegroup': (changegroup, 'roots'),
+ 'changegroupsubset': (changegroupsubset, 'bases heads'),
+ 'debugwireargs': (debugwireargs, 'one two *'),
+ 'getbundle': (getbundle, '*'),
+ 'heads': (heads, ''),
+ 'hello': (hello, ''),
+ 'known': (known, 'nodes *'),
+ 'listkeys': (listkeys, 'namespace'),
+ 'lookup': (lookup, 'key'),
+ 'pushkey': (pushkey, 'namespace key old new'),
+ 'stream_out': (stream, ''),
+ 'unbundle': (unbundle, 'heads'),
+}