Web   ·   Wiki   ·   Activities   ·   Blog   ·   Lists   ·   Chat   ·   Meeting   ·   Bugs   ·   Git   ·   Translate   ·   Archive   ·   People   ·   Donate
summaryrefslogtreecommitdiffstats
path: root/sugar/p2p
diff options
context:
space:
mode:
authorDan Williams <dcbw@redhat.com>2006-06-29 17:30:41 (GMT)
committer Dan Williams <dcbw@redhat.com>2006-06-29 17:30:41 (GMT)
commit29984ace331b041706f454222fac412cae21fd99 (patch)
tree8cab4ff0330574b412bea1b847d0ebd98c4d00ac /sugar/p2p
parent9ef8013a6b5c334292598c4b00e82047fdf441ef (diff)
Add positive acknowledgements to work around 802.11 + multicast unreliabilities
Diffstat (limited to 'sugar/p2p')
-rw-r--r--sugar/p2p/MostlyReliablePipe.py250
1 files changed, 241 insertions, 9 deletions
diff --git a/sugar/p2p/MostlyReliablePipe.py b/sugar/p2p/MostlyReliablePipe.py
index 5b503c2..6523bc0 100644
--- a/sugar/p2p/MostlyReliablePipe.py
+++ b/sugar/p2p/MostlyReliablePipe.py
@@ -44,6 +44,7 @@ class SegmentBase(object):
# Message segment packet types
_SEGMENT_TYPE_DATA = 0
_SEGMENT_TYPE_RETRANSMIT = 1
+ _SEGMENT_TYPE_ACK = 2
def magic():
return SegmentBase._MAGIC
@@ -61,6 +62,10 @@ class SegmentBase(object):
return SegmentBase._SEGMENT_TYPE_RETRANSMIT
type_retransmit = staticmethod(type_retransmit)
+ def type_ack():
+ return SegmentBase._SEGMENT_TYPE_ACK
+ type_ack = staticmethod(type_ack)
+
def header_len():
"""Return the header size of SegmentBase packets."""
return SegmentBase._HEADER_LEN
@@ -156,6 +161,8 @@ class SegmentBase(object):
segment = DataSegment(segno, total_segs, msg_seq_num, master_sha)
elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT:
segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha)
+ elif set_type == SegmentBase._SEGMENT_TYPE_ACK:
+ segment = AckSegment(segno, total_segs, msg_seq_num, master_sha)
else:
raise ValueError("Segment has invalid type.")
@@ -319,6 +326,96 @@ class RetransmitSegment(SegmentBase):
return self._rt_segment_number
+class AckSegment(SegmentBase):
+ """A message segment that encapsulates a message acknowledgement."""
+
+ # Ack data format:
+ # 2: acked message sequence number
+ # 20: acked message total data sha1
+ # 4: acked message source IP address
+ _ACK_DATA_TEMPLATE = "! H20sI"
+ _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE)
+
+ def data_template():
+ return AckSegment._ACK_DATA_TEMPLATE
+ data_template = staticmethod(data_template)
+
+ def __init__(self, segno, total_segs, msg_seq_num, master_sha):
+ """Should not be called directly."""
+ if segno != 1 or total_segs != 1:
+ raise ValueError("Acknowledgement messages must have only one segment.")
+
+ SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha)
+ self._type = SegmentBase._SEGMENT_TYPE_ACK
+
+ def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr):
+ # Sanity checks on the message attributes
+ if not ack_msg_seq_num or type(ack_msg_seq_num) != type(1):
+ raise ValueError("Ack message sequnce number must be an integer.")
+ if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535:
+ raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.")
+ if not ack_master_sha or type(ack_master_sha) != type("") or len(ack_master_sha) != 20:
+ raise ValueError("Ack message SHA1 checksum invalid.")
+ if type(ack_addr) != type(""):
+ raise ValueError("Ack message invalid address.")
+ try:
+ foo = socket.inet_aton(ack_addr)
+ except socket.error:
+ raise ValueError("Ack message invalid address.")
+ _verify_data = staticmethod(_verify_data)
+
+ def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr):
+ """Pack an ack payload."""
+ addr_data = socket.inet_aton(ack_addr)
+ data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num,
+ ack_master_sha, addr_data)
+ return (data, _sha_data(data))
+ _make_ack_data = staticmethod(_make_ack_data)
+
+ def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr):
+ """Static constructor for creation from individual attributes."""
+
+ AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
+ (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num,
+ ack_master_sha, ack_addr)
+ segment = AckSegment(1, 1, msg_seq_num, data_sha)
+ segment._data_len = AckSegment._ACK_DATA_LEN
+ segment._data = data
+ SegmentBase._validate_address(addr)
+ segment._addr = addr
+
+ segment._ack_msg_seq_num = ack_msg_seq_num
+ segment._ack_master_sha = ack_master_sha
+ segment._ack_addr = ack_addr
+ return segment
+ new_from_parts = staticmethod(new_from_parts)
+
+ def _unpack_data(self, stream, data_len):
+ if data_len != self._ACK_DATA_LEN:
+ raise ValueError("Ack segment data had invalid length.")
+ data = stream.read(data_len)
+ (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data)
+ try:
+ ack_addr = socket.inet_ntoa(ack_addr_data)
+ except socket.error:
+ raise ValueError("Ack segment data had invalid address.")
+ AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr)
+
+ self._data = data
+ self._data_len = data_len
+ self._ack_msg_seq_num = ack_msg_seq_num
+ self._ack_master_sha = ack_master_sha
+ self._ack_addr = ack_addr
+
+ def ack_msg_seq_num(self):
+ return self._ack_msg_seq_num
+
+ def ack_master_sha(self):
+ return self._ack_master_sha
+
+ def ack_addr(self):
+ return self._ack_addr
+
class Message(object):
"""Tracks an entire message object, which is composed of a number
of individual segments."""
@@ -429,6 +526,53 @@ class Message(object):
return 0
+def _get_local_interfaces():
+ import array
+ import struct
+ import fcntl
+ import socket
+
+ max_possible = 4
+ bytes = max_possible * 32
+ SIOCGIFCONF = 0x8912
+ names = array.array('B', '\0' * bytes)
+
+ sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ ifreq = struct.pack('iL', bytes, names.buffer_info()[0])
+ result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq)
+ sockfd.close()
+
+ outbytes = struct.unpack('iL', result)[0]
+ namestr = names.tostring()
+
+ return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)]
+
+def _get_local_ip_addresses():
+ """Call Linux specific bits to retrieve our own IP address."""
+ import socket
+ import sys
+ import fcntl
+ import struct
+
+ intfs = _get_local_interfaces()
+
+ ips = []
+ SIOCGIFADDR = 0x8915
+ sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ for intf in intfs:
+ if intf == "lo":
+ continue
+ try:
+ ifreq = (intf + '\0'*32)[:32]
+ result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq)
+ addr = socket.inet_ntoa(result[20:24])
+ ips.append(addr)
+ except IOError, exc:
+ print "Error getting IP address: %s" % exc
+ sockfd.close()
+ return ips
+
+
class MostlyReliablePipe(object):
"""Implement Mostly-Reliable UDP. We don't actually care about guaranteeing
delivery or receipt, just a better effort than no effort at all."""
@@ -446,13 +590,17 @@ class MostlyReliablePipe(object):
self._send_worker = 0
self._seq_counter = 0
self._drop_prob = 0
- self._rt_check_worker = 0
+ self._rt_check_worker_id = 0
self._outgoing = []
self._sent = {}
self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...]
self._dispatched = {}
+ self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp
+ self._ack_check_worker_id = 0
+
+ self._local_ips = _get_local_ip_addresses()
self._setup_listener()
self._setup_sender()
@@ -461,9 +609,12 @@ class MostlyReliablePipe(object):
if self._send_worker > 0:
gobject.source_remove(self._send_worker)
self._send_worker = 0
- if self._rt_check_worker > 0:
- gobject.source_remove(self._rt_check_worker)
- self._rt_check_worker = 0
+ if self._rt_check_worker_id > 0:
+ gobject.source_remove(self._rt_check_worker_id)
+ self._rt_check_worker_id = 0
+ if self._ack_check_worker_id > 0:
+ gobject.source_remove(self._ack_check_worker_id)
+ self._ack_check_worker_id = 0
def _setup_sender(self):
"""Setup the send socket for multicast."""
@@ -495,7 +646,8 @@ class MostlyReliablePipe(object):
# Watch the listener socket for data
gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data)
gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker)
- gobject.timeout_add(50, self._retransmit_check_worker)
+ self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker)
+ self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker)
self._started = True
@@ -516,12 +668,18 @@ class MostlyReliablePipe(object):
if message.last_incoming_time() < now - self._SEGMENT_TTL:
del self._incoming[msg_key]
- # Remove already dispatched messages after a while
+ # Remove already received and dispatched messages after a while
for msg_key in self._dispatched.keys()[:]:
message = self._dispatched[msg_key]
if message.dispatch_time() < now - (self._SEGMENT_TTL*2):
del self._dispatched[msg_key]
+ # Remove received acks after a while
+ for ack_key in self._acks.keys()[:]:
+ ack_time = self._acks[ack_key]
+ if ack_time < now - (self._SEGMENT_TTL*2):
+ del self._acks[ack_key]
+
return True
_MAX_SEGMENT_RETRIES = 10
@@ -541,6 +699,8 @@ class MostlyReliablePipe(object):
return False
def _retransmit_check_worker(self):
+ """Periodically check for and send retransmit requests for message
+ segments that got lost."""
try:
now = time.time()
for key in self._incoming.keys()[:]:
@@ -582,6 +742,12 @@ class MostlyReliablePipe(object):
# First segment in the message
if not self._incoming.has_key(msg_key):
self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs)
+ # Acknowledge the message if it didn't come from us
+ if addr[0] not in self._local_ips:
+ print "Sending ack for msg (%s %s) from %s)" % (msg_seq_num, msg_sha, addr[0])
+ ack_key = (msg_seq_num, msg_sha, addr[0])
+ if not self._acks.has_key(ack_key):
+ self._send_ack_for_message(msg_seq_num, msg_sha, addr[0])
message = self._incoming[msg_key]
# Look for a dupe, and if so, drop the new segment
@@ -636,10 +802,74 @@ class MostlyReliablePipe(object):
next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL)
self._schedule_segment_retransmit(key, segment, next_transmit, now)
+ def _ack_check_worker(self):
+ """Periodically check for messages that haven't received an ack
+ yet, and retransmit them."""
+ try:
+ now = time.time()
+ for key in self._sent.keys()[:]:
+ segment = self._sent[key]
+ # We only care about retransmitting the first segment
+ # of a message, since if other machines don't have the
+ # rest of the segments, they'll issue retransmit requests
+ if segment.segment_number() != 1:
+ continue
+ if segment.last_transmit() > now - 0.150: # 150ms
+ # Was just retransmitted recently, wait longer
+ # before retransmitting it
+ continue
+ ack_key = None
+ for ip in self._local_ips:
+ ack_key = (segment.message_sequence_number(), segment.master_sha(), ip)
+ if self._acks.has_key(ack_key):
+ break
+ ack_key = None
+ # If the segment already has been acked, don't send it
+ # again unless somebody explicitly requests a retransmit
+ if ack_key is not None:
+ continue
+
+ del self._sent[key]
+ self._outgoing.append(segment)
+ self._schedule_send_worker()
+ except KeyboardInterrupt:
+ return False
+ return True
+
+ def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr):
+ """Send an ack segment for a message."""
+ msg_seq_num = self._next_msg_seq()
+ ack = AckSegment.new_from_parts(self._remote_addr, msg_seq_num,
+ ack_msg_seq_num, ack_msg_sha, ack_addr)
+ self._outgoing.append(ack)
+ self._schedule_send_worker()
+ self._process_incoming_ack(ack)
+
+ def _process_incoming_ack(self, segment):
+ """Save the ack so that we don't send an ack when we start getting the segments
+ the ack was acknowledging."""
+ # If the ack is supposed to be for a message we sent, only accept it if
+ # we actually sent the message to which it refers
+ ack_addr = segment.ack_addr()
+ ack_master_sha = segment.ack_master_sha()
+ ack_msg_seq_num = segment.ack_msg_seq_num()
+ if ack_addr in self._local_ips:
+ sent_key = (ack_msg_seq_num, ack_master_sha, 1)
+ if not self._sent.has_key(sent_key):
+ return
+ ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr)
+ if not self._acks.has_key(ack_key):
+ print "Got ack for msg (%s %s) originally from %s" % (ack_msg_seq_num, ack_master_sha, ack_addr)
+ self._acks[ack_key] = time.time()
+
def set_drop_probability(self, prob=4):
"""Debugging function to randomly drop incoming packets.
The prob argument should be an integer between 1 and 10 to drop,
or 0 to drop none. Higher numbers drop more packets."""
+ if type(prob) != type(1):
+ raise ValueError("Drop probability must be an integer.")
+ if prob < 1 or prob > 10:
+ raise ValueError("Drop probability must be between 1 and 10 inclusive.")
self._drop_prob = prob
def _handle_incoming_data(self, source, condition):
@@ -665,6 +895,8 @@ class MostlyReliablePipe(object):
self._process_incoming_data(segment)
elif stype == SegmentBase.type_retransmit():
self._process_retransmit_request(segment)
+ elif stype == SegmentBase.type_ack():
+ self._process_incoming_ack(segment)
except ValueError, exc:
print "(MRP): Bad segment: %s" % exc
return True
@@ -693,12 +925,12 @@ class MostlyReliablePipe(object):
nmessages = length / mtu
if length % mtu > 0:
nmessages = nmessages + 1
- msg_num = 1
+ seg_num = 1
while left > 0:
- seg = DataSegment.new_from_parts(msg_num, nmessages,
+ seg = DataSegment.new_from_parts(seg_num, nmessages,
msg_seq, master_sha, data[:mtu])
self._outgoing.append(seg)
- msg_num = msg_num + 1
+ seg_num = seg_num + 1
data = data[mtu:]
left = left - mtu
self._schedule_send_worker()