diff options
author | Justin Gallardo <jirwin@suzy.(none)> | 2006-12-04 19:12:24 (GMT) |
---|---|---|
committer | Justin Gallardo <jirwin@suzy.(none)> | 2006-12-04 19:12:24 (GMT) |
commit | b9f9ef0fe9e36cf6e5de59700154b16f2dae15cd (patch) | |
tree | 3d5403ec73e993a78c5e92f8b14a5b86e8b6ae60 /sugar/p2p | |
parent | f5ae0662482de14f9d3812ddc4aba9be61024887 (diff) |
Changed all tabs to 4 spaces for python style
Diffstat (limited to 'sugar/p2p')
-rw-r--r-- | sugar/p2p/MostlyReliablePipe.py | 2546 | ||||
-rw-r--r-- | sugar/p2p/NotificationListener.py | 30 | ||||
-rw-r--r-- | sugar/p2p/Notifier.py | 14 | ||||
-rw-r--r-- | sugar/p2p/Stream.py | 242 | ||||
-rw-r--r-- | sugar/p2p/network.py | 574 |
5 files changed, 1703 insertions, 1703 deletions
diff --git a/sugar/p2p/MostlyReliablePipe.py b/sugar/p2p/MostlyReliablePipe.py index 4218181..604eada 100644 --- a/sugar/p2p/MostlyReliablePipe.py +++ b/sugar/p2p/MostlyReliablePipe.py @@ -32,952 +32,952 @@ import gobject def _stringify_sha(sha_hash): - print_sha = "" - for char in sha_hash: - print_sha = print_sha + binascii.b2a_hex(char) - return print_sha + print_sha = "" + for char in sha_hash: + print_sha = print_sha + binascii.b2a_hex(char) + return print_sha def _sha_data(data): - sha_hash = sha.new() - sha_hash.update(data) - return sha_hash.digest() + sha_hash = sha.new() + sha_hash.update(data) + return sha_hash.digest() _UDP_DATAGRAM_SIZE = 512 class SegmentBase(object): - _MAGIC = 0xbaea4304 - - # 4: magic (0xbaea4304) - # 1: type - # 2: segment number - # 2: total segments - # 2: message sequence number - #20: total data sha1 - _HEADER_TEMPLATE = "! IbHHH20s" - _HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE) - _MTU = _UDP_DATAGRAM_SIZE - _HEADER_LEN - - # Message segment packet types - _SEGMENT_TYPE_DATA = 0 - _SEGMENT_TYPE_RETRANSMIT = 1 - _SEGMENT_TYPE_ACK = 2 - - def magic(): - return SegmentBase._MAGIC - magic = staticmethod(magic) - - def header_template(): - return SegmentBase._HEADER_TEMPLATE - header_template = staticmethod(header_template) - - def type_data(): - return SegmentBase._SEGMENT_TYPE_DATA - type_data = staticmethod(type_data) - - def type_retransmit(): - return SegmentBase._SEGMENT_TYPE_RETRANSMIT - type_retransmit = staticmethod(type_retransmit) - - def type_ack(): - return SegmentBase._SEGMENT_TYPE_ACK - type_ack = staticmethod(type_ack) - - def header_len(): - """Return the header size of SegmentBase packets.""" - return SegmentBase._HEADER_LEN - header_len = staticmethod(header_len) - - def mtu(): - """Return the SegmentBase packet MTU.""" - return SegmentBase._MTU - mtu = staticmethod(mtu) - - def __init__(self, segno, total_segs, msg_seq_num, master_sha): - self._type = None - self._transmits = 0 - self._last_transmit = 0 - self._data = None - self._data_len = 0 - self.userdata = None - self._stime = time.time() - self._addr = None - - # Sanity checks on the message attributes - if not segno or not isinstance(segno, int): - raise ValueError("Segment number must be in integer.") - if segno < 1 or segno > 65535: - raise ValueError("Segment number must be between 1 and 65535 inclusive.") - if not total_segs or not isinstance(total_segs, int): - raise ValueError("Message segment total must be an integer.") - if total_segs < 1 or total_segs > 65535: - raise ValueError("Message must have between 1 and 65535 segments inclusive.") - if segno > total_segs: - raise ValueError("Segment number cannot be larger than message segment total.") - if not msg_seq_num or not isinstance(msg_seq_num, int): - raise ValueError("Message sequnce number must be an integer.") - if msg_seq_num < 1 or msg_seq_num > 65535: - raise ValueError("Message sequence number must be between 1 and 65535 inclusive.") - if not master_sha or not isinstance(master_sha, str) or len(master_sha) != 20: - raise ValueError("Message SHA1 checksum invalid.") - - self._segno = segno - self._total_segs = total_segs - self._msg_seq_num = msg_seq_num - self._master_sha = master_sha - - def _validate_address(addr): - if not addr or not isinstance(addr, tuple): - raise ValueError("Address must be a tuple.") - if len(addr) != 2 or not isinstance(addr[0], str) or not isinstance(addr[1], int): - raise ValueError("Address format was invalid.") - if addr[1] < 1 or addr[1] > 65535: - raise ValueError("Address port was invalid.") - _validate_address = staticmethod(_validate_address) - - def new_from_data(addr, data): - """Static constructor for creation from a packed data stream.""" - SegmentBase._validate_address(addr) - - # Verify minimum length - if not data: - raise ValueError("Segment data is invalid.") - data_len = len(data) - if data_len < SegmentBase.header_len() + 1: - raise ValueError("Segment is less then minimum required length") - if data_len > _UDP_DATAGRAM_SIZE: - raise ValueError("Segment data is larger than allowed.") - stream = StringIO.StringIO(data) - - # Determine and verify the length of included data - stream.seek(0, 2) - data_len = stream.tell() - SegmentBase._HEADER_LEN - stream.seek(0) - - if data_len < 1: - raise ValueError("Segment must have some data.") - if data_len > SegmentBase._MTU: - raise ValueError("Data length must not be larger than the MTU (%s)." % SegmentBase._MTU) - - # Read the first header attributes - (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(SegmentBase._HEADER_TEMPLATE, - stream.read(SegmentBase._HEADER_LEN)) - - # Sanity checks on the message attributes - if magic != SegmentBase._MAGIC: - raise ValueError("Segment does not have the correct magic.") - - # if the segment is the only one in the message, validate the data - if segno == 1 and total_segs == 1: - data_sha = _sha_data(stream.read(data_len)) - if data_sha != master_sha: - raise ValueError("Single segment message SHA checksums didn't match.") - stream.seek(SegmentBase._HEADER_LEN) - - if seg_type == SegmentBase._SEGMENT_TYPE_DATA: - segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) - elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: - segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) - elif seg_type == SegmentBase._SEGMENT_TYPE_ACK: - segment = AckSegment(segno, total_segs, msg_seq_num, master_sha) - else: - raise ValueError("Segment has invalid type.") - - # Segment specific data interpretation - segment._addr = addr - segment._unpack_data(stream, data_len) - - return segment - new_from_data = staticmethod(new_from_data) - - def stime(self): - return self._stime - - def address(self): - return self._addr - - def segment_number(self): - return self._segno - - def total_segments(self): - return self._total_segs - - def message_sequence_number(self): - return self._msg_seq_num - - def data(self): - return self._data - - def master_sha(self): - return self._master_sha - - def segment_type(self): - return self._type - - def packetize(self): - """Return a correctly formatted message that can be immediately sent.""" - header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type, - self._segno, self._total_segs, self._msg_seq_num, self._master_sha) - return header + self._data - - def transmits(self): - return self._transmits - - def inc_transmits(self): - self._transmits = self._transmits + 1 - self._last_transmit = time.time() - - def last_transmit(self): - return self._last_transmit + _MAGIC = 0xbaea4304 + + # 4: magic (0xbaea4304) + # 1: type + # 2: segment number + # 2: total segments + # 2: message sequence number + #20: total data sha1 + _HEADER_TEMPLATE = "! IbHHH20s" + _HEADER_LEN = struct.calcsize(_HEADER_TEMPLATE) + _MTU = _UDP_DATAGRAM_SIZE - _HEADER_LEN + + # Message segment packet types + _SEGMENT_TYPE_DATA = 0 + _SEGMENT_TYPE_RETRANSMIT = 1 + _SEGMENT_TYPE_ACK = 2 + + def magic(): + return SegmentBase._MAGIC + magic = staticmethod(magic) + + def header_template(): + return SegmentBase._HEADER_TEMPLATE + header_template = staticmethod(header_template) + + def type_data(): + return SegmentBase._SEGMENT_TYPE_DATA + type_data = staticmethod(type_data) + + def type_retransmit(): + return SegmentBase._SEGMENT_TYPE_RETRANSMIT + type_retransmit = staticmethod(type_retransmit) + + def type_ack(): + return SegmentBase._SEGMENT_TYPE_ACK + type_ack = staticmethod(type_ack) + + def header_len(): + """Return the header size of SegmentBase packets.""" + return SegmentBase._HEADER_LEN + header_len = staticmethod(header_len) + + def mtu(): + """Return the SegmentBase packet MTU.""" + return SegmentBase._MTU + mtu = staticmethod(mtu) + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + self._type = None + self._transmits = 0 + self._last_transmit = 0 + self._data = None + self._data_len = 0 + self.userdata = None + self._stime = time.time() + self._addr = None + + # Sanity checks on the message attributes + if not segno or not isinstance(segno, int): + raise ValueError("Segment number must be in integer.") + if segno < 1 or segno > 65535: + raise ValueError("Segment number must be between 1 and 65535 inclusive.") + if not total_segs or not isinstance(total_segs, int): + raise ValueError("Message segment total must be an integer.") + if total_segs < 1 or total_segs > 65535: + raise ValueError("Message must have between 1 and 65535 segments inclusive.") + if segno > total_segs: + raise ValueError("Segment number cannot be larger than message segment total.") + if not msg_seq_num or not isinstance(msg_seq_num, int): + raise ValueError("Message sequnce number must be an integer.") + if msg_seq_num < 1 or msg_seq_num > 65535: + raise ValueError("Message sequence number must be between 1 and 65535 inclusive.") + if not master_sha or not isinstance(master_sha, str) or len(master_sha) != 20: + raise ValueError("Message SHA1 checksum invalid.") + + self._segno = segno + self._total_segs = total_segs + self._msg_seq_num = msg_seq_num + self._master_sha = master_sha + + def _validate_address(addr): + if not addr or not isinstance(addr, tuple): + raise ValueError("Address must be a tuple.") + if len(addr) != 2 or not isinstance(addr[0], str) or not isinstance(addr[1], int): + raise ValueError("Address format was invalid.") + if addr[1] < 1 or addr[1] > 65535: + raise ValueError("Address port was invalid.") + _validate_address = staticmethod(_validate_address) + + def new_from_data(addr, data): + """Static constructor for creation from a packed data stream.""" + SegmentBase._validate_address(addr) + + # Verify minimum length + if not data: + raise ValueError("Segment data is invalid.") + data_len = len(data) + if data_len < SegmentBase.header_len() + 1: + raise ValueError("Segment is less then minimum required length") + if data_len > _UDP_DATAGRAM_SIZE: + raise ValueError("Segment data is larger than allowed.") + stream = StringIO.StringIO(data) + + # Determine and verify the length of included data + stream.seek(0, 2) + data_len = stream.tell() - SegmentBase._HEADER_LEN + stream.seek(0) + + if data_len < 1: + raise ValueError("Segment must have some data.") + if data_len > SegmentBase._MTU: + raise ValueError("Data length must not be larger than the MTU (%s)." % SegmentBase._MTU) + + # Read the first header attributes + (magic, seg_type, segno, total_segs, msg_seq_num, master_sha) = struct.unpack(SegmentBase._HEADER_TEMPLATE, + stream.read(SegmentBase._HEADER_LEN)) + + # Sanity checks on the message attributes + if magic != SegmentBase._MAGIC: + raise ValueError("Segment does not have the correct magic.") + + # if the segment is the only one in the message, validate the data + if segno == 1 and total_segs == 1: + data_sha = _sha_data(stream.read(data_len)) + if data_sha != master_sha: + raise ValueError("Single segment message SHA checksums didn't match.") + stream.seek(SegmentBase._HEADER_LEN) + + if seg_type == SegmentBase._SEGMENT_TYPE_DATA: + segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) + elif seg_type == SegmentBase._SEGMENT_TYPE_RETRANSMIT: + segment = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) + elif seg_type == SegmentBase._SEGMENT_TYPE_ACK: + segment = AckSegment(segno, total_segs, msg_seq_num, master_sha) + else: + raise ValueError("Segment has invalid type.") + + # Segment specific data interpretation + segment._addr = addr + segment._unpack_data(stream, data_len) + + return segment + new_from_data = staticmethod(new_from_data) + + def stime(self): + return self._stime + + def address(self): + return self._addr + + def segment_number(self): + return self._segno + + def total_segments(self): + return self._total_segs + + def message_sequence_number(self): + return self._msg_seq_num + + def data(self): + return self._data + + def master_sha(self): + return self._master_sha + + def segment_type(self): + return self._type + + def packetize(self): + """Return a correctly formatted message that can be immediately sent.""" + header = struct.pack(self._HEADER_TEMPLATE, self._MAGIC, self._type, + self._segno, self._total_segs, self._msg_seq_num, self._master_sha) + return header + self._data + + def transmits(self): + return self._transmits + + def inc_transmits(self): + self._transmits = self._transmits + 1 + self._last_transmit = time.time() + + def last_transmit(self): + return self._last_transmit class DataSegment(SegmentBase): - """A message segment that encapsulates random data.""" - - def __init__(self, segno, total_segs, msg_seq_num, master_sha): - SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) - self._type = SegmentBase._SEGMENT_TYPE_DATA - - def _get_template_for_len(length): - return "! %ds" % length - _get_template_for_len = staticmethod(_get_template_for_len) - - def _unpack_data(self, stream, data_len): - """Unpack the data stream, called by constructor.""" - self._data_len = data_len - template = DataSegment._get_template_for_len(self._data_len) - self._data = struct.unpack(template, stream.read(self._data_len))[0] - - def new_from_parts(segno, total_segs, msg_seq_num, master_sha, data): - """Construct a new message segment from individual attributes.""" - if not data: - raise ValueError("Must have valid data.") - segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) - segment._data_len = len(data) - template = DataSegment._get_template_for_len(segment._data_len) - segment._data = struct.pack(template, data) - return segment - new_from_parts = staticmethod(new_from_parts) + """A message segment that encapsulates random data.""" + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_DATA + + def _get_template_for_len(length): + return "! %ds" % length + _get_template_for_len = staticmethod(_get_template_for_len) + + def _unpack_data(self, stream, data_len): + """Unpack the data stream, called by constructor.""" + self._data_len = data_len + template = DataSegment._get_template_for_len(self._data_len) + self._data = struct.unpack(template, stream.read(self._data_len))[0] + + def new_from_parts(segno, total_segs, msg_seq_num, master_sha, data): + """Construct a new message segment from individual attributes.""" + if not data: + raise ValueError("Must have valid data.") + segment = DataSegment(segno, total_segs, msg_seq_num, master_sha) + segment._data_len = len(data) + template = DataSegment._get_template_for_len(segment._data_len) + segment._data = struct.pack(template, data) + return segment + new_from_parts = staticmethod(new_from_parts) class RetransmitSegment(SegmentBase): - """A message segment that encapsulates a retransmission request.""" - - # Retransmission data format: - # 2: message sequence number - # 20: total data sha1 - # 2: segment number - _RT_DATA_TEMPLATE = "! H20sH" - _RT_DATA_LEN = struct.calcsize(_RT_DATA_TEMPLATE) - - def data_template(): - return RetransmitSegment._RT_DATA_TEMPLATE - data_template = staticmethod(data_template) - - def __init__(self, segno, total_segs, msg_seq_num, master_sha): - """Should not be called directly.""" - if segno != 1 or total_segs != 1: - raise ValueError("Retransmission request messages must have only one segment.") - - SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) - self._type = SegmentBase._SEGMENT_TYPE_RETRANSMIT - - def _verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number): - # Sanity checks on the message attributes - if not rt_segment_number or not isinstance(rt_segment_number, int): - raise ValueError("RT Segment number must be in integer.") - if rt_segment_number < 1 or rt_segment_number > 65535: - raise ValueError("RT Segment number must be between 1 and 65535 inclusive.") - if not rt_msg_seq_num or not isinstance(rt_msg_seq_num, int): - raise ValueError("RT Message sequnce number must be an integer.") - if rt_msg_seq_num < 1 or rt_msg_seq_num > 65535: - raise ValueError("RT Message sequence number must be between 1 and 65535 inclusive.") - if not rt_master_sha or not isinstance(rt_master_sha, str) or len(rt_master_sha) != 20: - raise ValueError("RT Message SHA1 checksum invalid.") - _verify_data = staticmethod(_verify_data) - - def _make_rtms_data(rt_msg_seq_num, rt_master_sha, rt_segment_number): - """Pack retransmission request payload.""" - data = struct.pack(RetransmitSegment._RT_DATA_TEMPLATE, rt_msg_seq_num, - rt_master_sha, rt_segment_number) - return (data, _sha_data(data)) - _make_rtms_data = staticmethod(_make_rtms_data) - - def new_from_parts(addr, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number): - """Static constructor for creation from individual attributes.""" - - RetransmitSegment._verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number) - (data, data_sha) = RetransmitSegment._make_rtms_data(rt_msg_seq_num, - rt_master_sha, rt_segment_number) - segment = RetransmitSegment(1, 1, msg_seq_num, data_sha) - segment._data_len = RetransmitSegment._RT_DATA_LEN - segment._data = data - SegmentBase._validate_address(addr) - segment._addr = addr - - segment._rt_msg_seq_num = rt_msg_seq_num - segment._rt_master_sha = rt_master_sha - segment._rt_segment_number = rt_segment_number - return segment - new_from_parts = staticmethod(new_from_parts) - - def _unpack_data(self, stream, data_len): - if data_len != self._RT_DATA_LEN: - raise ValueError("Retransmission request data had invalid length.") - data = stream.read(data_len) - (rt_msg_seq_num, rt_master_sha, rt_seg_no) = struct.unpack(self._RT_DATA_TEMPLATE, data) - RetransmitSegment._verify_data(rt_msg_seq_num, rt_master_sha, rt_seg_no) - - self._data = data - self._data_len = data_len - self._rt_msg_seq_num = rt_msg_seq_num - self._rt_master_sha = rt_master_sha - self._rt_segment_number = rt_seg_no - - def rt_msg_seq_num(self): - return self._rt_msg_seq_num - - def rt_master_sha(self): - return self._rt_master_sha - - def rt_segment_number(self): - return self._rt_segment_number + """A message segment that encapsulates a retransmission request.""" + + # Retransmission data format: + # 2: message sequence number + # 20: total data sha1 + # 2: segment number + _RT_DATA_TEMPLATE = "! H20sH" + _RT_DATA_LEN = struct.calcsize(_RT_DATA_TEMPLATE) + + def data_template(): + return RetransmitSegment._RT_DATA_TEMPLATE + data_template = staticmethod(data_template) + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + """Should not be called directly.""" + if segno != 1 or total_segs != 1: + raise ValueError("Retransmission request messages must have only one segment.") + + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_RETRANSMIT + + def _verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number): + # Sanity checks on the message attributes + if not rt_segment_number or not isinstance(rt_segment_number, int): + raise ValueError("RT Segment number must be in integer.") + if rt_segment_number < 1 or rt_segment_number > 65535: + raise ValueError("RT Segment number must be between 1 and 65535 inclusive.") + if not rt_msg_seq_num or not isinstance(rt_msg_seq_num, int): + raise ValueError("RT Message sequnce number must be an integer.") + if rt_msg_seq_num < 1 or rt_msg_seq_num > 65535: + raise ValueError("RT Message sequence number must be between 1 and 65535 inclusive.") + if not rt_master_sha or not isinstance(rt_master_sha, str) or len(rt_master_sha) != 20: + raise ValueError("RT Message SHA1 checksum invalid.") + _verify_data = staticmethod(_verify_data) + + def _make_rtms_data(rt_msg_seq_num, rt_master_sha, rt_segment_number): + """Pack retransmission request payload.""" + data = struct.pack(RetransmitSegment._RT_DATA_TEMPLATE, rt_msg_seq_num, + rt_master_sha, rt_segment_number) + return (data, _sha_data(data)) + _make_rtms_data = staticmethod(_make_rtms_data) + + def new_from_parts(addr, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number): + """Static constructor for creation from individual attributes.""" + + RetransmitSegment._verify_data(rt_msg_seq_num, rt_master_sha, rt_segment_number) + (data, data_sha) = RetransmitSegment._make_rtms_data(rt_msg_seq_num, + rt_master_sha, rt_segment_number) + segment = RetransmitSegment(1, 1, msg_seq_num, data_sha) + segment._data_len = RetransmitSegment._RT_DATA_LEN + segment._data = data + SegmentBase._validate_address(addr) + segment._addr = addr + + segment._rt_msg_seq_num = rt_msg_seq_num + segment._rt_master_sha = rt_master_sha + segment._rt_segment_number = rt_segment_number + return segment + new_from_parts = staticmethod(new_from_parts) + + def _unpack_data(self, stream, data_len): + if data_len != self._RT_DATA_LEN: + raise ValueError("Retransmission request data had invalid length.") + data = stream.read(data_len) + (rt_msg_seq_num, rt_master_sha, rt_seg_no) = struct.unpack(self._RT_DATA_TEMPLATE, data) + RetransmitSegment._verify_data(rt_msg_seq_num, rt_master_sha, rt_seg_no) + + self._data = data + self._data_len = data_len + self._rt_msg_seq_num = rt_msg_seq_num + self._rt_master_sha = rt_master_sha + self._rt_segment_number = rt_seg_no + + def rt_msg_seq_num(self): + return self._rt_msg_seq_num + + def rt_master_sha(self): + return self._rt_master_sha + + def rt_segment_number(self): + return self._rt_segment_number class AckSegment(SegmentBase): - """A message segment that encapsulates a message acknowledgement.""" - - # Ack data format: - # 2: acked message sequence number - # 20: acked message total data sha1 - # 4: acked message source IP address - _ACK_DATA_TEMPLATE = "! H20s4s" - _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE) - - def data_template(): - return AckSegment._ACK_DATA_TEMPLATE - data_template = staticmethod(data_template) - - def __init__(self, segno, total_segs, msg_seq_num, master_sha): - """Should not be called directly.""" - if segno != 1 or total_segs != 1: - raise ValueError("Acknowledgement messages must have only one segment.") - - SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) - self._type = SegmentBase._SEGMENT_TYPE_ACK - - def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr): - # Sanity checks on the message attributes - if not ack_msg_seq_num or not isinstance(ack_msg_seq_num, int): - raise ValueError("Ack message sequnce number must be an integer.") - if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535: - raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.") - if not ack_master_sha or not isinstance(ack_master_sha, str) or len(ack_master_sha) != 20: - raise ValueError("Ack message SHA1 checksum invalid.") - if not isinstance(ack_addr, str): - raise ValueError("Ack message invalid address type.") - try: - foo = socket.inet_aton(ack_addr) - except socket.error: - raise ValueError("Ack message invalid address.") - _verify_data = staticmethod(_verify_data) - - def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr): - """Pack an ack payload.""" - addr_data = socket.inet_aton(ack_addr) - data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num, - ack_master_sha, addr_data) - return (data, _sha_data(data)) - _make_ack_data = staticmethod(_make_ack_data) - - def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr): - """Static constructor for creation from individual attributes.""" - - AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) - (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num, - ack_master_sha, ack_addr) - segment = AckSegment(1, 1, msg_seq_num, data_sha) - segment._data_len = AckSegment._ACK_DATA_LEN - segment._data = data - SegmentBase._validate_address(addr) - segment._addr = addr - - segment._ack_msg_seq_num = ack_msg_seq_num - segment._ack_master_sha = ack_master_sha - segment._ack_addr = ack_addr - return segment - new_from_parts = staticmethod(new_from_parts) - - def _unpack_data(self, stream, data_len): - if data_len != self._ACK_DATA_LEN: - raise ValueError("Ack segment data had invalid length.") - data = stream.read(data_len) - (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data) - try: - ack_addr = socket.inet_ntoa(ack_addr_data) - except socket.error: - raise ValueError("Ack segment data had invalid address.") - AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) - - self._data = data - self._data_len = data_len - self._ack_msg_seq_num = ack_msg_seq_num - self._ack_master_sha = ack_master_sha - self._ack_addr = ack_addr - - def ack_msg_seq_num(self): - return self._ack_msg_seq_num - - def ack_master_sha(self): - return self._ack_master_sha - - def ack_addr(self): - return self._ack_addr + """A message segment that encapsulates a message acknowledgement.""" + + # Ack data format: + # 2: acked message sequence number + # 20: acked message total data sha1 + # 4: acked message source IP address + _ACK_DATA_TEMPLATE = "! H20s4s" + _ACK_DATA_LEN = struct.calcsize(_ACK_DATA_TEMPLATE) + + def data_template(): + return AckSegment._ACK_DATA_TEMPLATE + data_template = staticmethod(data_template) + + def __init__(self, segno, total_segs, msg_seq_num, master_sha): + """Should not be called directly.""" + if segno != 1 or total_segs != 1: + raise ValueError("Acknowledgement messages must have only one segment.") + + SegmentBase.__init__(self, segno, total_segs, msg_seq_num, master_sha) + self._type = SegmentBase._SEGMENT_TYPE_ACK + + def _verify_data(ack_msg_seq_num, ack_master_sha, ack_addr): + # Sanity checks on the message attributes + if not ack_msg_seq_num or not isinstance(ack_msg_seq_num, int): + raise ValueError("Ack message sequnce number must be an integer.") + if ack_msg_seq_num < 1 or ack_msg_seq_num > 65535: + raise ValueError("Ack message sequence number must be between 1 and 65535 inclusive.") + if not ack_master_sha or not isinstance(ack_master_sha, str) or len(ack_master_sha) != 20: + raise ValueError("Ack message SHA1 checksum invalid.") + if not isinstance(ack_addr, str): + raise ValueError("Ack message invalid address type.") + try: + foo = socket.inet_aton(ack_addr) + except socket.error: + raise ValueError("Ack message invalid address.") + _verify_data = staticmethod(_verify_data) + + def _make_ack_data(ack_msg_seq_num, ack_master_sha, ack_addr): + """Pack an ack payload.""" + addr_data = socket.inet_aton(ack_addr) + data = struct.pack(AckSegment._ACK_DATA_TEMPLATE, ack_msg_seq_num, + ack_master_sha, addr_data) + return (data, _sha_data(data)) + _make_ack_data = staticmethod(_make_ack_data) + + def new_from_parts(addr, msg_seq_num, ack_msg_seq_num, ack_master_sha, ack_addr): + """Static constructor for creation from individual attributes.""" + + AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) + (data, data_sha) = AckSegment._make_ack_data(ack_msg_seq_num, + ack_master_sha, ack_addr) + segment = AckSegment(1, 1, msg_seq_num, data_sha) + segment._data_len = AckSegment._ACK_DATA_LEN + segment._data = data + SegmentBase._validate_address(addr) + segment._addr = addr + + segment._ack_msg_seq_num = ack_msg_seq_num + segment._ack_master_sha = ack_master_sha + segment._ack_addr = ack_addr + return segment + new_from_parts = staticmethod(new_from_parts) + + def _unpack_data(self, stream, data_len): + if data_len != self._ACK_DATA_LEN: + raise ValueError("Ack segment data had invalid length.") + data = stream.read(data_len) + (ack_msg_seq_num, ack_master_sha, ack_addr_data) = struct.unpack(self._ACK_DATA_TEMPLATE, data) + try: + ack_addr = socket.inet_ntoa(ack_addr_data) + except socket.error: + raise ValueError("Ack segment data had invalid address.") + AckSegment._verify_data(ack_msg_seq_num, ack_master_sha, ack_addr) + + self._data = data + self._data_len = data_len + self._ack_msg_seq_num = ack_msg_seq_num + self._ack_master_sha = ack_master_sha + self._ack_addr = ack_addr + + def ack_msg_seq_num(self): + return self._ack_msg_seq_num + + def ack_master_sha(self): + return self._ack_master_sha + + def ack_addr(self): + return self._ack_addr class Message(object): - """Tracks an entire message object, which is composed of a number - of individual segments.""" - def __init__(self, src_addr, msg_seq_num, msg_sha, total_segments): - self._rt_target = 0 - self._next_rt_time = 0 - self._last_incoming_time = 0 - self._segments = {} - self._complete = False - self._dispatched_time = 0 - self._data = None - self._data_sha = None - self._src_addr = src_addr - self._msg_seq_num = msg_seq_num - self._msg_sha = msg_sha - self._total_segments = total_segments - self._rt_tries = {} - for i in range(1, self._total_segments + 1): - self._rt_tries[i] = 0 - - def __del__(self): - self.clear() - - def sha(self): - return self._msg_sha - - def source_address(self): - return self._src_addr - - def clear(self): - for key in self._segments.keys()[:]: - del self._segments[key] - del self._rt_tries[key] - self._segments = {} - self._rt_tries = {} - - def has_segment(self, segno): - return self._segments.has_key(segno) - - def first_missing(self): - for i in range(1, self._total_segments + 1): - if not self._segments.has_key(i): - return i - return 0 - - _DEF_RT_REQUEST_INTERVAL = 0.09 # 70ms (in seconds) - def update_rt_wait(self, now): - """now argument should be in seconds.""" - wait = self._DEF_RT_REQUEST_INTERVAL - if self._last_incoming_time > now - 0.02: - msg_completeness = float(len(self._segments)) / float(self._total_segments) - wait = wait + (self._DEF_RT_REQUEST_INTERVAL * (1.0 - msg_completeness)) - self._next_rt_time = now + wait - - def add_segment(self, segment): - if self.complete(): - return - segno = segment.segment_number() - if self._segments.has_key(segno): - return - self._segments[segno] = segment - self._rt_tries[segno] = 0 - now = time.time() - self._last_incoming_time = now - - num_segs = len(self._segments) - if num_segs == self._total_segments: - self._complete = True - self._next_rt_time = 0 - self._data = '' - for seg in self._segments.values(): - self._data = self._data + seg.data() - self._data_sha = _sha_data(self._data) - elif segno == num_segs or num_segs == 1: - # If we're not missing segments, push back retransmit request - self.update_rt_wait(now) - - def get_retransmit_message(self, msg_seq_num, segno): - if segno < 1 or segno > self._total_segments: - return None - seg = RetransmitSegment.new_from_parts(self._src_addr, msg_seq_num, - self._msg_seq_num, self._msg_sha, segno) - self._rt_tries[segno] = self._rt_tries[segno] + 1 - self.update_rt_wait(time.time()) - return seg - - def complete(self): - return self._complete - - def dispatch_time(self): - return self._dispatch_time - - def set_dispatch_time(self): - self._dispatch_time = time.time() - - def data(self): - return (self._data, self._data_sha) - - def last_incoming_time(self): - return self._last_incoming_time - - def next_rt_time(self): - return self._next_rt_time - - def rt_tries(self, segno): - if self._rt_tries.has_key(segno): - return self._rt_tries[segno] - return 0 + """Tracks an entire message object, which is composed of a number + of individual segments.""" + def __init__(self, src_addr, msg_seq_num, msg_sha, total_segments): + self._rt_target = 0 + self._next_rt_time = 0 + self._last_incoming_time = 0 + self._segments = {} + self._complete = False + self._dispatched_time = 0 + self._data = None + self._data_sha = None + self._src_addr = src_addr + self._msg_seq_num = msg_seq_num + self._msg_sha = msg_sha + self._total_segments = total_segments + self._rt_tries = {} + for i in range(1, self._total_segments + 1): + self._rt_tries[i] = 0 + + def __del__(self): + self.clear() + + def sha(self): + return self._msg_sha + + def source_address(self): + return self._src_addr + + def clear(self): + for key in self._segments.keys()[:]: + del self._segments[key] + del self._rt_tries[key] + self._segments = {} + self._rt_tries = {} + + def has_segment(self, segno): + return self._segments.has_key(segno) + + def first_missing(self): + for i in range(1, self._total_segments + 1): + if not self._segments.has_key(i): + return i + return 0 + + _DEF_RT_REQUEST_INTERVAL = 0.09 # 70ms (in seconds) + def update_rt_wait(self, now): + """now argument should be in seconds.""" + wait = self._DEF_RT_REQUEST_INTERVAL + if self._last_incoming_time > now - 0.02: + msg_completeness = float(len(self._segments)) / float(self._total_segments) + wait = wait + (self._DEF_RT_REQUEST_INTERVAL * (1.0 - msg_completeness)) + self._next_rt_time = now + wait + + def add_segment(self, segment): + if self.complete(): + return + segno = segment.segment_number() + if self._segments.has_key(segno): + return + self._segments[segno] = segment + self._rt_tries[segno] = 0 + now = time.time() + self._last_incoming_time = now + + num_segs = len(self._segments) + if num_segs == self._total_segments: + self._complete = True + self._next_rt_time = 0 + self._data = '' + for seg in self._segments.values(): + self._data = self._data + seg.data() + self._data_sha = _sha_data(self._data) + elif segno == num_segs or num_segs == 1: + # If we're not missing segments, push back retransmit request + self.update_rt_wait(now) + + def get_retransmit_message(self, msg_seq_num, segno): + if segno < 1 or segno > self._total_segments: + return None + seg = RetransmitSegment.new_from_parts(self._src_addr, msg_seq_num, + self._msg_seq_num, self._msg_sha, segno) + self._rt_tries[segno] = self._rt_tries[segno] + 1 + self.update_rt_wait(time.time()) + return seg + + def complete(self): + return self._complete + + def dispatch_time(self): + return self._dispatch_time + + def set_dispatch_time(self): + self._dispatch_time = time.time() + + def data(self): + return (self._data, self._data_sha) + + def last_incoming_time(self): + return self._last_incoming_time + + def next_rt_time(self): + return self._next_rt_time + + def rt_tries(self, segno): + if self._rt_tries.has_key(segno): + return self._rt_tries[segno] + return 0 def _get_local_interfaces(): - import array - import struct - import fcntl - import socket + import array + import struct + import fcntl + import socket - max_possible = 4 - bytes = max_possible * 32 - SIOCGIFCONF = 0x8912 - names = array.array('B', '\0' * bytes) + max_possible = 4 + bytes = max_possible * 32 + SIOCGIFCONF = 0x8912 + names = array.array('B', '\0' * bytes) - sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - ifreq = struct.pack('iL', bytes, names.buffer_info()[0]) - result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq) - sockfd.close() + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + ifreq = struct.pack('iL', bytes, names.buffer_info()[0]) + result = fcntl.ioctl(sockfd.fileno(), SIOCGIFCONF, ifreq) + sockfd.close() - outbytes = struct.unpack('iL', result)[0] - namestr = names.tostring() + outbytes = struct.unpack('iL', result)[0] + namestr = names.tostring() - return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)] + return [namestr[i:i+32].split('\0', 1)[0] for i in range(0, outbytes, 32)] def _get_local_ip_addresses(): - """Call Linux specific bits to retrieve our own IP address.""" - import socket - import sys - import fcntl - import struct - - intfs = _get_local_interfaces() - - ips = [] - SIOCGIFADDR = 0x8915 - sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - for intf in intfs: - if intf == "lo": - continue - try: - ifreq = (intf + '\0'*32)[:32] - result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq) - addr = socket.inet_ntoa(result[20:24]) - ips.append(addr) - except IOError, exc: - print "Error getting IP address: %s" % exc - sockfd.close() - return ips + """Call Linux specific bits to retrieve our own IP address.""" + import socket + import sys + import fcntl + import struct + + intfs = _get_local_interfaces() + + ips = [] + SIOCGIFADDR = 0x8915 + sockfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + for intf in intfs: + if intf == "lo": + continue + try: + ifreq = (intf + '\0'*32)[:32] + result = fcntl.ioctl(sockfd.fileno(), SIOCGIFADDR, ifreq) + addr = socket.inet_ntoa(result[20:24]) + ips.append(addr) + except IOError, exc: + print "Error getting IP address: %s" % exc + sockfd.close() + return ips class MostlyReliablePipe(object): - """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing - delivery or receipt, just a better effort than no effort at all.""" - - _UDP_MSG_SIZE = SegmentBase.mtu() + SegmentBase.header_len() - _SEGMENT_TTL = 120 # 2 minutes - - def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None): - self._local_addr = local_addr - self._remote_addr = remote_addr - self._port = port - self._data_cb = data_cb - self._user_data = user_data - self._started = False - self._send_worker = 0 - self._seq_counter = 0 - self._drop_prob = 0 - self._rt_check_worker_id = 0 - - self._outgoing = [] - self._sent = {} - - self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...] - self._dispatched = {} - self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp - self._ack_check_worker_id = 0 - - self._local_ips = _get_local_ip_addresses() - - self._setup_listener() - self._setup_sender() - - def __del__(self): - if self._send_worker > 0: - gobject.source_remove(self._send_worker) - self._send_worker = 0 - if self._rt_check_worker_id > 0: - gobject.source_remove(self._rt_check_worker_id) - self._rt_check_worker_id = 0 - if self._ack_check_worker_id > 0: - gobject.source_remove(self._ack_check_worker_id) - self._ack_check_worker_id = 0 - - def _setup_sender(self): - """Setup the send socket for multicast.""" - self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # Make the socket multicast-aware, and set TTL. - self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit - - def _setup_listener(self): - """Set up the listener socket for multicast traffic.""" - # Listener socket - self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - - # Set some options to make it multicast-friendly - self._listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 20) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) - - def start(self): - """Let the listener socket start listening for network data.""" - # Set some more multicast options - self._listen_sock.bind((self._local_addr, self._port)) - self._listen_sock.settimeout(2) + """Implement Mostly-Reliable UDP. We don't actually care about guaranteeing + delivery or receipt, just a better effort than no effort at all.""" + + _UDP_MSG_SIZE = SegmentBase.mtu() + SegmentBase.header_len() + _SEGMENT_TTL = 120 # 2 minutes + + def __init__(self, local_addr, remote_addr, port, data_cb, user_data=None): + self._local_addr = local_addr + self._remote_addr = remote_addr + self._port = port + self._data_cb = data_cb + self._user_data = user_data + self._started = False + self._send_worker = 0 + self._seq_counter = 0 + self._drop_prob = 0 + self._rt_check_worker_id = 0 + + self._outgoing = [] + self._sent = {} + + self._incoming = {} # (message sha, # of segments) -> [segment1, segment2, ...] + self._dispatched = {} + self._acks = {} # (message sequence #, master sha, source addr) -> received timestamp + self._ack_check_worker_id = 0 + + self._local_ips = _get_local_ip_addresses() + + self._setup_listener() + self._setup_sender() + + def __del__(self): + if self._send_worker > 0: + gobject.source_remove(self._send_worker) + self._send_worker = 0 + if self._rt_check_worker_id > 0: + gobject.source_remove(self._rt_check_worker_id) + self._rt_check_worker_id = 0 + if self._ack_check_worker_id > 0: + gobject.source_remove(self._ack_check_worker_id) + self._ack_check_worker_id = 0 + + def _setup_sender(self): + """Setup the send socket for multicast.""" + self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # Make the socket multicast-aware, and set TTL. + self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit + + def _setup_listener(self): + """Set up the listener socket for multicast traffic.""" + # Listener socket + self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # Set some options to make it multicast-friendly + self._listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 20) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) + + def start(self): + """Let the listener socket start listening for network data.""" + # Set some more multicast options + self._listen_sock.bind((self._local_addr, self._port)) + self._listen_sock.settimeout(2) # Disable for now to try to fix "cannot assign requested address" errors -# intf = socket.gethostbyname(socket.gethostname()) -# self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, -# socket.inet_aton(intf) + socket.inet_aton('0.0.0.0')) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, - socket.inet_aton(self._remote_addr) + socket.inet_aton('0.0.0.0')) - - # Watch the listener socket for data - gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) - gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker) - self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker) - self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker) - - self._started = True - - def _segment_ttl_worker(self): - """Cull already-sent message segments that are past their TTL.""" - now = time.time() - for key in self._sent.keys()[:]: - segment = self._sent[key] - if segment.stime() < now - self._SEGMENT_TTL: - if segment.userdata: - gobject.source_remove(segment.userdata) - del self._sent[key] - - # Cull incomplete incoming segment chains that haven't gotten any data - # for a long time either - for msg_key in self._incoming.keys()[:]: - message = self._incoming[msg_key] - if message.last_incoming_time() < now - self._SEGMENT_TTL: - del self._incoming[msg_key] - - # Remove already received and dispatched messages after a while - for msg_key in self._dispatched.keys()[:]: - message = self._dispatched[msg_key] - if message.dispatch_time() < now - (self._SEGMENT_TTL*2): - del self._dispatched[msg_key] - - # Remove received acks after a while - for ack_key in self._acks.keys()[:]: - ack_time = self._acks[ack_key] - if ack_time < now - (self._SEGMENT_TTL*2): - del self._acks[ack_key] - - return True - - _MAX_SEGMENT_RETRIES = 10 - def _retransmit_request(self, message): - """Returns true if the message has exceeded it's retry limit.""" - first_missing = message.first_missing() - if first_missing > 0: - num_retries = message.rt_tries(first_missing) - if num_retries > self._MAX_SEGMENT_RETRIES: - return True - msg_seq = self._next_msg_seq() - seg = message.get_retransmit_message(msg_seq, first_missing) - if seg: - print "(MRP): Requesting retransmit of %d by %s" % (first_missing, message.source_address()) - self._outgoing.append(seg) - self._schedule_send_worker() - return False - - def _retransmit_check_worker(self): - """Periodically check for and send retransmit requests for message - segments that got lost.""" - try: - now = time.time() - for key in self._incoming.keys()[:]: - message = self._incoming[key] - if message.complete(): - continue - next_rt = message.next_rt_time() - if next_rt == 0 or next_rt > now: - continue - if self._retransmit_request(message): - # Kill the message, too many retries - print "(MRP): Dropped message %s, exceeded retries." % _stringify_sha(message.sha()) - self._dispatched[key] = message - message.set_dispatch_time() - del self._incoming[key] - except KeyboardInterrupt: - return False - return True - - def _process_incoming_data(self, segment): - """Handle a new message segment. First checks if there is only one - segment to the message, and if the checksum from the header matches - that computed from the data, dispatches it. Otherwise, it adds the - new segment to the list of other segments for that message, and - checks to see if the message is complete. If all segments are present, - the message is reassembled and dispatched.""" - - msg_sha = segment.master_sha() - nsegs = segment.total_segments() - addr = segment.address() - segno = segment.segment_number() - - msg_seq_num = segment.message_sequence_number() - msg_key = (addr[0], msg_seq_num, msg_sha, nsegs) - - if self._dispatched.has_key(msg_key): - # We already dispatched this message, this segment is useless - return - # First segment in the message - if not self._incoming.has_key(msg_key): - self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs) - # Acknowledge the message if it didn't come from us - if addr[0] not in self._local_ips: - ack_key = (msg_seq_num, msg_sha, addr[0]) - if not self._acks.has_key(ack_key): - self._send_ack_for_message(msg_seq_num, msg_sha, addr[0]) - - message = self._incoming[msg_key] - # Look for a dupe, and if so, drop the new segment - if message.has_segment(segno): - return - message.add_segment(segment) - - # Dispatch the message if all segments are present and the sha is correct - if message.complete(): - (msg_data, complete_data_sha) = message.data() - if msg_sha == complete_data_sha: - self._data_cb(addr, msg_data, self._user_data) - self._dispatched[msg_key] = message - message.set_dispatch_time() - del self._incoming[msg_key] - return - - def _segment_retransmit_cb(self, key, segment): - """Add a segment ot the outgoing queue and schedule its transmission.""" - del self._sent[key] - self._outgoing.append(segment) - self._schedule_send_worker() - return False - - def _schedule_segment_retransmit(self, key, segment, when, now): - """Schedule retransmission of a segment if one is not already scheduled.""" - if segment.userdata: - # Already scheduled for retransmit - return - - if when <= now: - # Immediate retransmission - self._segment_retransmit_cb(key, segment) - else: - # convert time to milliseconds - timeout = int((when - now) * 1000) - segment.userdata = gobject.timeout_add(timeout, self._segment_retransmit_cb, - key, segment) - - _STD_RETRANSMIT_INTERVAL = 0.05 # 50ms (in seconds) - def _process_retransmit_request(self, segment): - """Validate and process a retransmission request.""" - key = (segment.rt_msg_seq_num(), segment.rt_master_sha(), segment.rt_segment_number()) - if not self._sent.has_key(key): - # Either we don't know about the segment, or it was already culled - return - - # Calculate next retransmission time and schedule packet for retransmit - segment = self._sent[key] - # only retransmit segments every 150ms or more - now = time.time() - next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL) - self._schedule_segment_retransmit(key, segment, next_transmit, now) - - def _ack_check_worker(self): - """Periodically check for messages that haven't received an ack - yet, and retransmit them.""" - try: - now = time.time() - for key in self._sent.keys()[:]: - segment = self._sent[key] - # We only care about retransmitting the first segment - # of a message, since if other machines don't have the - # rest of the segments, they'll issue retransmit requests - if segment.segment_number() != 1: - continue - if segment.last_transmit() > now - 0.150: # 150ms - # Was just retransmitted recently, wait longer - # before retransmitting it - continue - ack_key = None - for ip in self._local_ips: - ack_key = (segment.message_sequence_number(), segment.master_sha(), ip) - if self._acks.has_key(ack_key): - break - ack_key = None - # If the segment already has been acked, don't send it - # again unless somebody explicitly requests a retransmit - if ack_key is not None: - continue - - del self._sent[key] - self._outgoing.append(segment) - self._schedule_send_worker() - except KeyboardInterrupt: - return False - return True - - def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr): - """Send an ack segment for a message.""" - msg_seq_num = self._next_msg_seq() - full_remote_addr = (self._remote_addr, self._port) - ack = AckSegment.new_from_parts(full_remote_addr, msg_seq_num, - ack_msg_seq_num, ack_msg_sha, ack_addr) - self._outgoing.append(ack) - self._schedule_send_worker() - self._process_incoming_ack(ack) - - def _process_incoming_ack(self, segment): - """Save the ack so that we don't send an ack when we start getting the segments - the ack was acknowledging.""" - # If the ack is supposed to be for a message we sent, only accept it if - # we actually sent the message to which it refers - ack_addr = segment.ack_addr() - ack_master_sha = segment.ack_master_sha() - ack_msg_seq_num = segment.ack_msg_seq_num() - if ack_addr in self._local_ips: - sent_key = (ack_msg_seq_num, ack_master_sha, 1) - if not self._sent.has_key(sent_key): - return - ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr) - if not self._acks.has_key(ack_key): - self._acks[ack_key] = time.time() - - def set_drop_probability(self, prob=4): - """Debugging function to randomly drop incoming packets. - The prob argument should be an integer between 1 and 10 to drop, - or 0 to drop none. Higher numbers drop more packets.""" - if not isinstance(prob, int): - raise ValueError("Drop probability must be an integer.") - if prob < 1 or prob > 10: - raise ValueError("Drop probability must be between 1 and 10 inclusive.") - self._drop_prob = prob - - def _handle_incoming_data(self, source, condition): - """Handle incoming network data by making a message segment out of it - sending it off to the processing function.""" - if not (condition & gobject.IO_IN): - return True - msg = {} - data, addr = source.recvfrom(self._UDP_MSG_SIZE) - - should_drop = False - p = random.random() * 10.0 - if self._drop_prob > 0 and p <= self._drop_prob: - should_drop = True - - try: - segment = SegmentBase.new_from_data(addr, data) - if should_drop: - print "(MRP): Dropped segment %d." % segment.segment_number() - else: - stype = segment.segment_type() - if stype == SegmentBase.type_data(): - self._process_incoming_data(segment) - elif stype == SegmentBase.type_retransmit(): - self._process_retransmit_request(segment) - elif stype == SegmentBase.type_ack(): - self._process_incoming_ack(segment) - except ValueError, exc: - print "(MRP): Bad segment: %s" % exc - return True - - def _next_msg_seq(self): - self._seq_counter = self._seq_counter + 1 - if self._seq_counter > 65535: - self._seq_counter = 1 - return self._seq_counter - - def send(self, data): - """Break data up into chunks and queue for later transmission.""" - if not self._started: - raise Exception("Can't send anything until started!") - - msg_seq = self._next_msg_seq() - - # Pack the data into network byte order - template = "! %ds" % len(str(data)) - data = struct.pack(template, str(data)) - master_sha = _sha_data(data) - - # Split up the data into segments - left = length = len(data) - mtu = SegmentBase.mtu() - nmessages = length / mtu - if length % mtu > 0: - nmessages = nmessages + 1 - seg_num = 1 - while left > 0: - seg = DataSegment.new_from_parts(seg_num, nmessages, - msg_seq, master_sha, data[:mtu]) - self._outgoing.append(seg) - seg_num = seg_num + 1 - data = data[mtu:] - left = left - mtu - self._schedule_send_worker() - - def _schedule_send_worker(self): - if len(self._outgoing) > 0 and self._send_worker == 0: - self._send_worker = gobject.timeout_add(50, self._send_worker_cb) - - def _send_worker_cb(self): - """Send all queued segments that have yet to be transmitted.""" - self._send_worker = 0 - nsent = 0 - for segment in self._outgoing: - packet = segment.packetize() - segment.inc_transmits() - addr = (self._remote_addr, self._port) - if segment.address(): - addr = segment.address() - self._send_sock.sendto(packet, addr) - if segment.userdata: - gobject.source_remove(segment.userdata) - segment.userdata = None # Retransmission GSource - key = (segment.message_sequence_number(), segment.master_sha(), segment.segment_number()) - self._sent[key] = segment - nsent = nsent + 1 - if nsent > 10: - break - self._outgoing = self._outgoing[nsent:] - if len(self._outgoing): - self._schedule_send_worker() - return False +# intf = socket.gethostbyname(socket.gethostname()) +# self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, +# socket.inet_aton(intf) + socket.inet_aton('0.0.0.0')) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, + socket.inet_aton(self._remote_addr) + socket.inet_aton('0.0.0.0')) + + # Watch the listener socket for data + gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) + gobject.timeout_add(self._SEGMENT_TTL * 1000, self._segment_ttl_worker) + self._rt_check_worker_id = gobject.timeout_add(50, self._retransmit_check_worker) + self._ack_check_worker_id = gobject.timeout_add(50, self._ack_check_worker) + + self._started = True + + def _segment_ttl_worker(self): + """Cull already-sent message segments that are past their TTL.""" + now = time.time() + for key in self._sent.keys()[:]: + segment = self._sent[key] + if segment.stime() < now - self._SEGMENT_TTL: + if segment.userdata: + gobject.source_remove(segment.userdata) + del self._sent[key] + + # Cull incomplete incoming segment chains that haven't gotten any data + # for a long time either + for msg_key in self._incoming.keys()[:]: + message = self._incoming[msg_key] + if message.last_incoming_time() < now - self._SEGMENT_TTL: + del self._incoming[msg_key] + + # Remove already received and dispatched messages after a while + for msg_key in self._dispatched.keys()[:]: + message = self._dispatched[msg_key] + if message.dispatch_time() < now - (self._SEGMENT_TTL*2): + del self._dispatched[msg_key] + + # Remove received acks after a while + for ack_key in self._acks.keys()[:]: + ack_time = self._acks[ack_key] + if ack_time < now - (self._SEGMENT_TTL*2): + del self._acks[ack_key] + + return True + + _MAX_SEGMENT_RETRIES = 10 + def _retransmit_request(self, message): + """Returns true if the message has exceeded it's retry limit.""" + first_missing = message.first_missing() + if first_missing > 0: + num_retries = message.rt_tries(first_missing) + if num_retries > self._MAX_SEGMENT_RETRIES: + return True + msg_seq = self._next_msg_seq() + seg = message.get_retransmit_message(msg_seq, first_missing) + if seg: + print "(MRP): Requesting retransmit of %d by %s" % (first_missing, message.source_address()) + self._outgoing.append(seg) + self._schedule_send_worker() + return False + + def _retransmit_check_worker(self): + """Periodically check for and send retransmit requests for message + segments that got lost.""" + try: + now = time.time() + for key in self._incoming.keys()[:]: + message = self._incoming[key] + if message.complete(): + continue + next_rt = message.next_rt_time() + if next_rt == 0 or next_rt > now: + continue + if self._retransmit_request(message): + # Kill the message, too many retries + print "(MRP): Dropped message %s, exceeded retries." % _stringify_sha(message.sha()) + self._dispatched[key] = message + message.set_dispatch_time() + del self._incoming[key] + except KeyboardInterrupt: + return False + return True + + def _process_incoming_data(self, segment): + """Handle a new message segment. First checks if there is only one + segment to the message, and if the checksum from the header matches + that computed from the data, dispatches it. Otherwise, it adds the + new segment to the list of other segments for that message, and + checks to see if the message is complete. If all segments are present, + the message is reassembled and dispatched.""" + + msg_sha = segment.master_sha() + nsegs = segment.total_segments() + addr = segment.address() + segno = segment.segment_number() + + msg_seq_num = segment.message_sequence_number() + msg_key = (addr[0], msg_seq_num, msg_sha, nsegs) + + if self._dispatched.has_key(msg_key): + # We already dispatched this message, this segment is useless + return + # First segment in the message + if not self._incoming.has_key(msg_key): + self._incoming[msg_key] = Message((addr[0], self._port), msg_seq_num, msg_sha, nsegs) + # Acknowledge the message if it didn't come from us + if addr[0] not in self._local_ips: + ack_key = (msg_seq_num, msg_sha, addr[0]) + if not self._acks.has_key(ack_key): + self._send_ack_for_message(msg_seq_num, msg_sha, addr[0]) + + message = self._incoming[msg_key] + # Look for a dupe, and if so, drop the new segment + if message.has_segment(segno): + return + message.add_segment(segment) + + # Dispatch the message if all segments are present and the sha is correct + if message.complete(): + (msg_data, complete_data_sha) = message.data() + if msg_sha == complete_data_sha: + self._data_cb(addr, msg_data, self._user_data) + self._dispatched[msg_key] = message + message.set_dispatch_time() + del self._incoming[msg_key] + return + + def _segment_retransmit_cb(self, key, segment): + """Add a segment ot the outgoing queue and schedule its transmission.""" + del self._sent[key] + self._outgoing.append(segment) + self._schedule_send_worker() + return False + + def _schedule_segment_retransmit(self, key, segment, when, now): + """Schedule retransmission of a segment if one is not already scheduled.""" + if segment.userdata: + # Already scheduled for retransmit + return + + if when <= now: + # Immediate retransmission + self._segment_retransmit_cb(key, segment) + else: + # convert time to milliseconds + timeout = int((when - now) * 1000) + segment.userdata = gobject.timeout_add(timeout, self._segment_retransmit_cb, + key, segment) + + _STD_RETRANSMIT_INTERVAL = 0.05 # 50ms (in seconds) + def _process_retransmit_request(self, segment): + """Validate and process a retransmission request.""" + key = (segment.rt_msg_seq_num(), segment.rt_master_sha(), segment.rt_segment_number()) + if not self._sent.has_key(key): + # Either we don't know about the segment, or it was already culled + return + + # Calculate next retransmission time and schedule packet for retransmit + segment = self._sent[key] + # only retransmit segments every 150ms or more + now = time.time() + next_transmit = max(now, segment.last_transmit() + self._STD_RETRANSMIT_INTERVAL) + self._schedule_segment_retransmit(key, segment, next_transmit, now) + + def _ack_check_worker(self): + """Periodically check for messages that haven't received an ack + yet, and retransmit them.""" + try: + now = time.time() + for key in self._sent.keys()[:]: + segment = self._sent[key] + # We only care about retransmitting the first segment + # of a message, since if other machines don't have the + # rest of the segments, they'll issue retransmit requests + if segment.segment_number() != 1: + continue + if segment.last_transmit() > now - 0.150: # 150ms + # Was just retransmitted recently, wait longer + # before retransmitting it + continue + ack_key = None + for ip in self._local_ips: + ack_key = (segment.message_sequence_number(), segment.master_sha(), ip) + if self._acks.has_key(ack_key): + break + ack_key = None + # If the segment already has been acked, don't send it + # again unless somebody explicitly requests a retransmit + if ack_key is not None: + continue + + del self._sent[key] + self._outgoing.append(segment) + self._schedule_send_worker() + except KeyboardInterrupt: + return False + return True + + def _send_ack_for_message(self, ack_msg_seq_num, ack_msg_sha, ack_addr): + """Send an ack segment for a message.""" + msg_seq_num = self._next_msg_seq() + full_remote_addr = (self._remote_addr, self._port) + ack = AckSegment.new_from_parts(full_remote_addr, msg_seq_num, + ack_msg_seq_num, ack_msg_sha, ack_addr) + self._outgoing.append(ack) + self._schedule_send_worker() + self._process_incoming_ack(ack) + + def _process_incoming_ack(self, segment): + """Save the ack so that we don't send an ack when we start getting the segments + the ack was acknowledging.""" + # If the ack is supposed to be for a message we sent, only accept it if + # we actually sent the message to which it refers + ack_addr = segment.ack_addr() + ack_master_sha = segment.ack_master_sha() + ack_msg_seq_num = segment.ack_msg_seq_num() + if ack_addr in self._local_ips: + sent_key = (ack_msg_seq_num, ack_master_sha, 1) + if not self._sent.has_key(sent_key): + return + ack_key = (ack_msg_seq_num, ack_master_sha, ack_addr) + if not self._acks.has_key(ack_key): + self._acks[ack_key] = time.time() + + def set_drop_probability(self, prob=4): + """Debugging function to randomly drop incoming packets. + The prob argument should be an integer between 1 and 10 to drop, + or 0 to drop none. Higher numbers drop more packets.""" + if not isinstance(prob, int): + raise ValueError("Drop probability must be an integer.") + if prob < 1 or prob > 10: + raise ValueError("Drop probability must be between 1 and 10 inclusive.") + self._drop_prob = prob + + def _handle_incoming_data(self, source, condition): + """Handle incoming network data by making a message segment out of it + sending it off to the processing function.""" + if not (condition & gobject.IO_IN): + return True + msg = {} + data, addr = source.recvfrom(self._UDP_MSG_SIZE) + + should_drop = False + p = random.random() * 10.0 + if self._drop_prob > 0 and p <= self._drop_prob: + should_drop = True + + try: + segment = SegmentBase.new_from_data(addr, data) + if should_drop: + print "(MRP): Dropped segment %d." % segment.segment_number() + else: + stype = segment.segment_type() + if stype == SegmentBase.type_data(): + self._process_incoming_data(segment) + elif stype == SegmentBase.type_retransmit(): + self._process_retransmit_request(segment) + elif stype == SegmentBase.type_ack(): + self._process_incoming_ack(segment) + except ValueError, exc: + print "(MRP): Bad segment: %s" % exc + return True + + def _next_msg_seq(self): + self._seq_counter = self._seq_counter + 1 + if self._seq_counter > 65535: + self._seq_counter = 1 + return self._seq_counter + + def send(self, data): + """Break data up into chunks and queue for later transmission.""" + if not self._started: + raise Exception("Can't send anything until started!") + + msg_seq = self._next_msg_seq() + + # Pack the data into network byte order + template = "! %ds" % len(str(data)) + data = struct.pack(template, str(data)) + master_sha = _sha_data(data) + + # Split up the data into segments + left = length = len(data) + mtu = SegmentBase.mtu() + nmessages = length / mtu + if length % mtu > 0: + nmessages = nmessages + 1 + seg_num = 1 + while left > 0: + seg = DataSegment.new_from_parts(seg_num, nmessages, + msg_seq, master_sha, data[:mtu]) + self._outgoing.append(seg) + seg_num = seg_num + 1 + data = data[mtu:] + left = left - mtu + self._schedule_send_worker() + + def _schedule_send_worker(self): + if len(self._outgoing) > 0 and self._send_worker == 0: + self._send_worker = gobject.timeout_add(50, self._send_worker_cb) + + def _send_worker_cb(self): + """Send all queued segments that have yet to be transmitted.""" + self._send_worker = 0 + nsent = 0 + for segment in self._outgoing: + packet = segment.packetize() + segment.inc_transmits() + addr = (self._remote_addr, self._port) + if segment.address(): + addr = segment.address() + self._send_sock.sendto(packet, addr) + if segment.userdata: + gobject.source_remove(segment.userdata) + segment.userdata = None # Retransmission GSource + key = (segment.message_sequence_number(), segment.master_sha(), segment.segment_number()) + self._sent[key] = segment + nsent = nsent + 1 + if nsent > 10: + break + self._outgoing = self._outgoing[nsent:] + if len(self._outgoing): + self._schedule_send_worker() + return False ################################################################# @@ -988,348 +988,348 @@ import unittest class SegmentBaseTestCase(unittest.TestCase): - _DEF_SEGNO = 1 - _DEF_TOT_SEGS = 5 - _DEF_MSG_SEQ_NUM = 4556 - _DEF_MASTER_SHA = "12345678901234567890" - _DEF_SEG_TYPE = 0 + _DEF_SEGNO = 1 + _DEF_TOT_SEGS = 5 + _DEF_MSG_SEQ_NUM = 4556 + _DEF_MASTER_SHA = "12345678901234567890" + _DEF_SEG_TYPE = 0 - _DEF_ADDRESS = ('123.3.2.1', 3333) - _SEG_MAGIC = 0xbaea4304 + _DEF_ADDRESS = ('123.3.2.1', 3333) + _SEG_MAGIC = 0xbaea4304 class SegmentBaseInitTestCase(SegmentBaseTestCase): - def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg): - try: - seg = SegmentBase(segno, total_segs, msg_seq_num, master_sha) - except ValueError, exc: - pass - else: - self.fail("expected a ValueError for %s." % fail_msg) - - def testSegmentBase(self): - assert SegmentBase.magic() == self._SEG_MAGIC, "Segment magic wasn't correct!" - assert SegmentBase.header_len() > 0, "header size was not greater than zero." - assert SegmentBase.mtu() > 0, "MTU was not greater than zero." - assert SegmentBase.mtu() + SegmentBase.header_len() == _UDP_DATAGRAM_SIZE, "MTU + header size didn't equal expected %d." % _UDP_DATAGRAM_SIZE - - def testGoodInit(self): - seg = SegmentBase(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) - assert seg.stime() < time.time(), "segment start time is less than now!" - assert not seg.address(), "Segment address was not None after init." - assert seg.segment_number() == self._DEF_SEGNO, "Segment number wasn't correct after init." - assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments wasn't correct after init." - assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number wasn't correct after init." - assert seg.master_sha() == self._DEF_MASTER_SHA, "Message master SHA wasn't correct after init." - assert seg.segment_type() == None, "Segment type was not None after init." - assert seg.transmits() == 0, "Segment transmits was not 0 after init." - assert seg.last_transmit() == 0, "Segment last transmit was not 0 after init." - assert seg.data() == None, "Segment data was not None after init." - - def testSegmentNumber(self): - self._test_init_fail(0, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - self._test_init_fail(65536, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - self._test_init_fail(None, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - self._test_init_fail("", self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - - def testTotalMessageSegmentNumber(self): - self._test_init_fail(self._DEF_SEGNO, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") - self._test_init_fail(self._DEF_SEGNO, 65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") - self._test_init_fail(self._DEF_SEGNO, None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") - self._test_init_fail(self._DEF_SEGNO, "", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") - - def testMessageSequenceNumber(self): - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 0, self._DEF_MASTER_SHA, "invalid message sequence number") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 65536, self._DEF_MASTER_SHA, "invalid message sequence number") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, None, self._DEF_MASTER_SHA, "invalid message sequence number") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, "", self._DEF_MASTER_SHA, "invalid message sequence number") - - def testMasterSHA(self): - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 19, "invalid SHA1 data hash") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 21, "invalid SHA1 data hash") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, None, "invalid SHA1 data hash") - self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, 1234, "invalid SHA1 data hash") - - def _testNewFromDataFail(self, addr, data, fail_msg): - try: - seg = SegmentBase.new_from_data(addr, data) - except ValueError, exc: - pass - else: - self.fail("expected a ValueError about %s." % fail_msg) - - def testNewFromDataAddress(self): - self._testNewFromDataFail(None, None, "bad address") - self._testNewFromDataFail('', None, "bad address") - self._testNewFromDataFail((''), None, "bad address") - self._testNewFromDataFail((1), None, "bad address") - self._testNewFromDataFail(('', ''), None, "bad address") - self._testNewFromDataFail((1, 3333), None, "bad address") - self._testNewFromDataFail(('', 0), None, "bad address") - self._testNewFromDataFail(('', 65536), None, "bad address") - - def testNewFromDataData(self): - """Only test generic new_from_data() bits, not type-specific ones.""" - self._testNewFromDataFail(self._DEF_ADDRESS, None, "invalid data") - - really_short_data = "111" - self._testNewFromDataFail(self._DEF_ADDRESS, really_short_data, "data too short") - - only_header_data = "1" * SegmentBase.header_len() - self._testNewFromDataFail(self._DEF_ADDRESS, only_header_data, "data too short") - - too_much_data = "1" * (_UDP_DATAGRAM_SIZE + 1) - self._testNewFromDataFail(self._DEF_ADDRESS, too_much_data, "too much data") - - header_template = SegmentBase.header_template() - bad_magic_data = struct.pack(header_template, 0x12345678, self._DEF_SEG_TYPE, - self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) - self._testNewFromDataFail(self._DEF_ADDRESS, bad_magic_data, "invalid magic") - - bad_type_data = struct.pack(header_template, self._SEG_MAGIC, -1, self._DEF_SEGNO, - self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) - self._testNewFromDataFail(self._DEF_ADDRESS, bad_type_data, "invalid segment type") - - # Test master_sha that doesn't match data's SHA - header = struct.pack(header_template, self._SEG_MAGIC, self._DEF_SEG_TYPE, 1, 1, - self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) - data = struct.pack("! 15s", "7" * 15) - self._testNewFromDataFail(self._DEF_ADDRESS, header + data, "single-segment message SHA mismatch") - - def addToSuite(suite): - suite.addTest(SegmentBaseInitTestCase("testGoodInit")) - suite.addTest(SegmentBaseInitTestCase("testSegmentNumber")) - suite.addTest(SegmentBaseInitTestCase("testTotalMessageSegmentNumber")) - suite.addTest(SegmentBaseInitTestCase("testMessageSequenceNumber")) - suite.addTest(SegmentBaseInitTestCase("testMasterSHA")) - suite.addTest(SegmentBaseInitTestCase("testNewFromDataAddress")) - suite.addTest(SegmentBaseInitTestCase("testNewFromDataData")) - addToSuite = staticmethod(addToSuite) + def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg): + try: + seg = SegmentBase(segno, total_segs, msg_seq_num, master_sha) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError for %s." % fail_msg) + + def testSegmentBase(self): + assert SegmentBase.magic() == self._SEG_MAGIC, "Segment magic wasn't correct!" + assert SegmentBase.header_len() > 0, "header size was not greater than zero." + assert SegmentBase.mtu() > 0, "MTU was not greater than zero." + assert SegmentBase.mtu() + SegmentBase.header_len() == _UDP_DATAGRAM_SIZE, "MTU + header size didn't equal expected %d." % _UDP_DATAGRAM_SIZE + + def testGoodInit(self): + seg = SegmentBase(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + assert seg.stime() < time.time(), "segment start time is less than now!" + assert not seg.address(), "Segment address was not None after init." + assert seg.segment_number() == self._DEF_SEGNO, "Segment number wasn't correct after init." + assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments wasn't correct after init." + assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number wasn't correct after init." + assert seg.master_sha() == self._DEF_MASTER_SHA, "Message master SHA wasn't correct after init." + assert seg.segment_type() == None, "Segment type was not None after init." + assert seg.transmits() == 0, "Segment transmits was not 0 after init." + assert seg.last_transmit() == 0, "Segment last transmit was not 0 after init." + assert seg.data() == None, "Segment data was not None after init." + + def testSegmentNumber(self): + self._test_init_fail(0, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(65536, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(None, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail("", self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + + def testTotalMessageSegmentNumber(self): + self._test_init_fail(self._DEF_SEGNO, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, 65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + self._test_init_fail(self._DEF_SEGNO, "", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid total segments") + + def testMessageSequenceNumber(self): + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 0, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, 65536, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, None, self._DEF_MASTER_SHA, "invalid message sequence number") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, "", self._DEF_MASTER_SHA, "invalid message sequence number") + + def testMasterSHA(self): + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 19, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, "1" * 21, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, None, "invalid SHA1 data hash") + self._test_init_fail(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, 1234, "invalid SHA1 data hash") + + def _testNewFromDataFail(self, addr, data, fail_msg): + try: + seg = SegmentBase.new_from_data(addr, data) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError about %s." % fail_msg) + + def testNewFromDataAddress(self): + self._testNewFromDataFail(None, None, "bad address") + self._testNewFromDataFail('', None, "bad address") + self._testNewFromDataFail((''), None, "bad address") + self._testNewFromDataFail((1), None, "bad address") + self._testNewFromDataFail(('', ''), None, "bad address") + self._testNewFromDataFail((1, 3333), None, "bad address") + self._testNewFromDataFail(('', 0), None, "bad address") + self._testNewFromDataFail(('', 65536), None, "bad address") + + def testNewFromDataData(self): + """Only test generic new_from_data() bits, not type-specific ones.""" + self._testNewFromDataFail(self._DEF_ADDRESS, None, "invalid data") + + really_short_data = "111" + self._testNewFromDataFail(self._DEF_ADDRESS, really_short_data, "data too short") + + only_header_data = "1" * SegmentBase.header_len() + self._testNewFromDataFail(self._DEF_ADDRESS, only_header_data, "data too short") + + too_much_data = "1" * (_UDP_DATAGRAM_SIZE + 1) + self._testNewFromDataFail(self._DEF_ADDRESS, too_much_data, "too much data") + + header_template = SegmentBase.header_template() + bad_magic_data = struct.pack(header_template, 0x12345678, self._DEF_SEG_TYPE, + self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + self._testNewFromDataFail(self._DEF_ADDRESS, bad_magic_data, "invalid magic") + + bad_type_data = struct.pack(header_template, self._SEG_MAGIC, -1, self._DEF_SEGNO, + self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + self._testNewFromDataFail(self._DEF_ADDRESS, bad_type_data, "invalid segment type") + + # Test master_sha that doesn't match data's SHA + header = struct.pack(header_template, self._SEG_MAGIC, self._DEF_SEG_TYPE, 1, 1, + self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + data = struct.pack("! 15s", "7" * 15) + self._testNewFromDataFail(self._DEF_ADDRESS, header + data, "single-segment message SHA mismatch") + + def addToSuite(suite): + suite.addTest(SegmentBaseInitTestCase("testGoodInit")) + suite.addTest(SegmentBaseInitTestCase("testSegmentNumber")) + suite.addTest(SegmentBaseInitTestCase("testTotalMessageSegmentNumber")) + suite.addTest(SegmentBaseInitTestCase("testMessageSequenceNumber")) + suite.addTest(SegmentBaseInitTestCase("testMasterSHA")) + suite.addTest(SegmentBaseInitTestCase("testNewFromDataAddress")) + suite.addTest(SegmentBaseInitTestCase("testNewFromDataData")) + addToSuite = staticmethod(addToSuite) class DataSegmentTestCase(SegmentBaseTestCase): - """Test DataSegment class specific initialization and stuff.""" - - def testInit(self): - seg = DataSegment(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA) - assert seg.segment_type() == SegmentBase.type_data(), "Segment wasn't a data segment." - - def testNewFromParts(self): - try: - seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, - self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, None) - except ValueError, exc: - pass - else: - self.fail("Expected ValueError about invalid data.") - - # Ensure message data is same as we stuff in after object is instantiated - payload = "How are you today?" - seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, payload) - assert seg.data() == payload, "Data after segment creation didn't match expected." - - def testNewFromData(self): - """Test DataSegment's new_from_data() functionality.""" - - # Make sure something valid actually works - header_template = SegmentBase.header_template() - payload_str = "How are you today?" - payload = struct.pack("! %ds" % len(payload_str), payload_str) - payload_sha = _sha_data(payload) - header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_data(), self._DEF_SEGNO, - self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, payload_sha) - seg = SegmentBase.new_from_data(self._DEF_ADDRESS, header + payload) - - assert seg.address() == self._DEF_ADDRESS, "Segment address did not match expected." - assert seg.segment_type() == SegmentBase.type_data(), "Segment type did not match expected." - assert seg.segment_number() == self._DEF_SEGNO, "Segment number did not match expected." - assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments did not match expected." - assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number did not match expected." - assert seg.master_sha() == payload_sha, "Message master SHA did not match expected." - assert seg.data() == payload, "Segment data did not match expected payload." - - def addToSuite(suite): - suite.addTest(DataSegmentTestCase("testInit")) - suite.addTest(DataSegmentTestCase("testNewFromParts")) - suite.addTest(DataSegmentTestCase("testNewFromData")) - addToSuite = staticmethod(addToSuite) + """Test DataSegment class specific initialization and stuff.""" + + def testInit(self): + seg = DataSegment(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA) + assert seg.segment_type() == SegmentBase.type_data(), "Segment wasn't a data segment." + + def testNewFromParts(self): + try: + seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, + self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, None) + except ValueError, exc: + pass + else: + self.fail("Expected ValueError about invalid data.") + + # Ensure message data is same as we stuff in after object is instantiated + payload = "How are you today?" + seg = DataSegment.new_from_parts(self._DEF_SEGNO, self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, payload) + assert seg.data() == payload, "Data after segment creation didn't match expected." + + def testNewFromData(self): + """Test DataSegment's new_from_data() functionality.""" + + # Make sure something valid actually works + header_template = SegmentBase.header_template() + payload_str = "How are you today?" + payload = struct.pack("! %ds" % len(payload_str), payload_str) + payload_sha = _sha_data(payload) + header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_data(), self._DEF_SEGNO, + self._DEF_TOT_SEGS, self._DEF_MSG_SEQ_NUM, payload_sha) + seg = SegmentBase.new_from_data(self._DEF_ADDRESS, header + payload) + + assert seg.address() == self._DEF_ADDRESS, "Segment address did not match expected." + assert seg.segment_type() == SegmentBase.type_data(), "Segment type did not match expected." + assert seg.segment_number() == self._DEF_SEGNO, "Segment number did not match expected." + assert seg.total_segments() == self._DEF_TOT_SEGS, "Total segments did not match expected." + assert seg.message_sequence_number() == self._DEF_MSG_SEQ_NUM, "Message sequence number did not match expected." + assert seg.master_sha() == payload_sha, "Message master SHA did not match expected." + assert seg.data() == payload, "Segment data did not match expected payload." + + def addToSuite(suite): + suite.addTest(DataSegmentTestCase("testInit")) + suite.addTest(DataSegmentTestCase("testNewFromParts")) + suite.addTest(DataSegmentTestCase("testNewFromData")) + addToSuite = staticmethod(addToSuite) class RetransmitSegmentTestCase(SegmentBaseTestCase): - """Test RetransmitSegment class specific initialization and stuff.""" - - def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg): - try: - seg = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) - except ValueError, exc: - pass - else: - self.fail("expected a ValueError for %s." % fail_msg) - - def testInit(self): - self._test_init_fail(0, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - self._test_init_fail(2, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") - self._test_init_fail(1, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments") - self._test_init_fail(1, 2, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments") - - # Something that's supposed to work - seg = RetransmitSegment(1, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) - assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't a retransmit segment." - - def _test_new_from_parts_fail(self, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg): - try: - seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, msg_seq_num, rt_msg_seq_num, - rt_master_sha, rt_segment_number) - except ValueError, exc: - pass - else: - self.fail("expected a ValueError for %s." % fail_msg) - - def testNewFromParts(self): - """Test RetransmitSegment's new_from_parts() functionality.""" - self._test_new_from_parts_fail(0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid message sequence number") - self._test_new_from_parts_fail(65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid message sequence number") - self._test_new_from_parts_fail(None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid message sequence number") - self._test_new_from_parts_fail("", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid message sequence number") - - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, 0, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid retransmit message sequence number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, 65536, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid retransmit message sequence number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, None, self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid retransmit message sequence number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, "", self._DEF_MASTER_SHA, - self._DEF_SEGNO, "invalid retransmit message sequence number") - - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, "1" * 19, - self._DEF_SEGNO, "invalid retransmit message master SHA") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, "1" * 21, - self._DEF_SEGNO, "invalid retransmit message master SHA") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, None, - self._DEF_SEGNO, "invalid retransmit message master SHA") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, 1234, - self._DEF_SEGNO, "invalid retransmit message master SHA") - - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, 0, "invalid retransmit message segment number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, 65536, "invalid retransmit message segment number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, None, "invalid retransmit message segment number") - self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, "", "invalid retransmit message segment number") - - # Ensure message data is same as we stuff in after object is instantiated - seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, self._DEF_SEGNO) - assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "RT message sequence number after segment creation didn't match expected." - assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "RT master SHA after segment creation didn't match expected." - assert seg.rt_segment_number() == self._DEF_SEGNO, "RT segment number after segment creation didn't match expected." - - def _new_from_data(self, rt_msg_seq_num, rt_master_sha, rt_segment_number): - payload = struct.pack(RetransmitSegment.data_template(), rt_msg_seq_num, rt_master_sha, rt_segment_number) - payload_sha = _sha_data(payload) - header_template = SegmentBase.header_template() - header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_retransmit(), 1, 1, - self._DEF_MSG_SEQ_NUM, payload_sha) - return header + payload - - def _test_new_from_data_fail(self, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg): - try: - packet = self._new_from_data(rt_msg_seq_num, rt_master_sha, rt_segment_number) - seg = SegmentBase.new_from_data(self._DEF_ADDRESS, packet) - except ValueError, exc: - pass - else: - self.fail("Expected a ValueError about %s." % fail_msg) - - def testNewFromData(self): - """Test DataSegment's new_from_data() functionality.""" - self._test_new_from_data_fail(0, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number") - self._test_new_from_data_fail(65536, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number") - - self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 0, "invalid RT segment number") - self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 65536, "invalid RT segment number") - - # Ensure something that should work - packet = self._new_from_data(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, self._DEF_SEGNO) - seg = SegmentBase.new_from_data(self._DEF_ADDRESS, packet) - assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't expected type." - assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected." - assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected." - assert seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected." - - def testPartsToData(self): - seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, - self._DEF_MASTER_SHA, self._DEF_SEGNO) - new_seg = SegmentBase.new_from_data(self._DEF_ADDRESS, seg.packetize()) - assert new_seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected." - assert new_seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected." - assert new_seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected." - - def addToSuite(suite): - suite.addTest(RetransmitSegmentTestCase("testInit")) - suite.addTest(RetransmitSegmentTestCase("testNewFromParts")) - suite.addTest(RetransmitSegmentTestCase("testNewFromData")) - suite.addTest(RetransmitSegmentTestCase("testPartsToData")) - addToSuite = staticmethod(addToSuite) + """Test RetransmitSegment class specific initialization and stuff.""" + + def _test_init_fail(self, segno, total_segs, msg_seq_num, master_sha, fail_msg): + try: + seg = RetransmitSegment(segno, total_segs, msg_seq_num, master_sha) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError for %s." % fail_msg) + + def testInit(self): + self._test_init_fail(0, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(2, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid segment number") + self._test_init_fail(1, 0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments") + self._test_init_fail(1, 2, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, "invalid number of total segments") + + # Something that's supposed to work + seg = RetransmitSegment(1, 1, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA) + assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't a retransmit segment." + + def _test_new_from_parts_fail(self, msg_seq_num, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg): + try: + seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, msg_seq_num, rt_msg_seq_num, + rt_master_sha, rt_segment_number) + except ValueError, exc: + pass + else: + self.fail("expected a ValueError for %s." % fail_msg) + + def testNewFromParts(self): + """Test RetransmitSegment's new_from_parts() functionality.""" + self._test_new_from_parts_fail(0, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid message sequence number") + self._test_new_from_parts_fail(65536, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid message sequence number") + self._test_new_from_parts_fail(None, self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid message sequence number") + self._test_new_from_parts_fail("", self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid message sequence number") + + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, 0, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid retransmit message sequence number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, 65536, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid retransmit message sequence number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, None, self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid retransmit message sequence number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, "", self._DEF_MASTER_SHA, + self._DEF_SEGNO, "invalid retransmit message sequence number") + + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, "1" * 19, + self._DEF_SEGNO, "invalid retransmit message master SHA") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, "1" * 21, + self._DEF_SEGNO, "invalid retransmit message master SHA") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, None, + self._DEF_SEGNO, "invalid retransmit message master SHA") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, 1234, + self._DEF_SEGNO, "invalid retransmit message master SHA") + + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, 0, "invalid retransmit message segment number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, 65536, "invalid retransmit message segment number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, None, "invalid retransmit message segment number") + self._test_new_from_parts_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, "", "invalid retransmit message segment number") + + # Ensure message data is same as we stuff in after object is instantiated + seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, self._DEF_SEGNO) + assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "RT message sequence number after segment creation didn't match expected." + assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "RT master SHA after segment creation didn't match expected." + assert seg.rt_segment_number() == self._DEF_SEGNO, "RT segment number after segment creation didn't match expected." + + def _new_from_data(self, rt_msg_seq_num, rt_master_sha, rt_segment_number): + payload = struct.pack(RetransmitSegment.data_template(), rt_msg_seq_num, rt_master_sha, rt_segment_number) + payload_sha = _sha_data(payload) + header_template = SegmentBase.header_template() + header = struct.pack(header_template, self._SEG_MAGIC, SegmentBase.type_retransmit(), 1, 1, + self._DEF_MSG_SEQ_NUM, payload_sha) + return header + payload + + def _test_new_from_data_fail(self, rt_msg_seq_num, rt_master_sha, rt_segment_number, fail_msg): + try: + packet = self._new_from_data(rt_msg_seq_num, rt_master_sha, rt_segment_number) + seg = SegmentBase.new_from_data(self._DEF_ADDRESS, packet) + except ValueError, exc: + pass + else: + self.fail("Expected a ValueError about %s." % fail_msg) + + def testNewFromData(self): + """Test DataSegment's new_from_data() functionality.""" + self._test_new_from_data_fail(0, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number") + self._test_new_from_data_fail(65536, self._DEF_MASTER_SHA, self._DEF_SEGNO, "invalid RT message sequence number") + + self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 0, "invalid RT segment number") + self._test_new_from_data_fail(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, 65536, "invalid RT segment number") + + # Ensure something that should work + packet = self._new_from_data(self._DEF_MSG_SEQ_NUM, self._DEF_MASTER_SHA, self._DEF_SEGNO) + seg = SegmentBase.new_from_data(self._DEF_ADDRESS, packet) + assert seg.segment_type() == SegmentBase.type_retransmit(), "Segment wasn't expected type." + assert seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected." + assert seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected." + assert seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected." + + def testPartsToData(self): + seg = RetransmitSegment.new_from_parts(self._DEF_ADDRESS, self._DEF_MSG_SEQ_NUM, self._DEF_MSG_SEQ_NUM, + self._DEF_MASTER_SHA, self._DEF_SEGNO) + new_seg = SegmentBase.new_from_data(self._DEF_ADDRESS, seg.packetize()) + assert new_seg.rt_msg_seq_num() == self._DEF_MSG_SEQ_NUM, "Segment RT message sequence number didn't match expected." + assert new_seg.rt_master_sha() == self._DEF_MASTER_SHA, "Segment RT master SHA didn't match expected." + assert new_seg.rt_segment_number() == self._DEF_SEGNO, "Segment RT segment number didn't match expected." + + def addToSuite(suite): + suite.addTest(RetransmitSegmentTestCase("testInit")) + suite.addTest(RetransmitSegmentTestCase("testNewFromParts")) + suite.addTest(RetransmitSegmentTestCase("testNewFromData")) + suite.addTest(RetransmitSegmentTestCase("testPartsToData")) + addToSuite = staticmethod(addToSuite) class SHAUtilsTestCase(unittest.TestCase): - def testSHA(self): - data = "235jklqt3hjwasdv879wfe89723rqjh32tr3hwaejksdvd89udsv89dsgiougjktqjhk23tjht23hjt3qhjewagthjasgdgsd" - data_sha = _sha_data(data) - assert len(data_sha) == 20, "SHA wasn't correct size." - known_sha = "\xee\x9e\xb9\x1d\xe8\x96\x75\xcb\x12\xf1\x25\x22\x0f\x76\xf7\xf3\xc8\x4e\xbf\xcd" - assert data_sha == known_sha, "SHA didn't match known SHA." + def testSHA(self): + data = "235jklqt3hjwasdv879wfe89723rqjh32tr3hwaejksdvd89udsv89dsgiougjktqjhk23tjht23hjt3qhjewagthjasgdgsd" + data_sha = _sha_data(data) + assert len(data_sha) == 20, "SHA wasn't correct size." + known_sha = "\xee\x9e\xb9\x1d\xe8\x96\x75\xcb\x12\xf1\x25\x22\x0f\x76\xf7\xf3\xc8\x4e\xbf\xcd" + assert data_sha == known_sha, "SHA didn't match known SHA." - def testStringifySHA(self): - data = "jlkwjlkaegdjlksgdjklsdgajklganjtwn23n325n23tjwgeajkga nafDA fwqnjlqtjkl23tjk2365jlk235jkl2356jlktjkltewjlktewjklewtjklaggsda" - data_known_sha = "9650c23db78092a0ffda4577c87ebf36d25c868e" - assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." - # Do it twice for kicks - assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." + def testStringifySHA(self): + data = "jlkwjlkaegdjlksgdjklsdgajklganjtwn23n325n23tjwgeajkga nafDA fwqnjlqtjkl23tjk2365jlk235jkl2356jlktjkltewjlktewjklewtjklaggsda" + data_known_sha = "9650c23db78092a0ffda4577c87ebf36d25c868e" + assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." + # Do it twice for kicks + assert _stringify_sha(_sha_data(data)) == data_known_sha, "SHA stringify didn't return correct SHA." - def addToSuite(suite): - suite.addTest(SHAUtilsTestCase("testSHA")) - suite.addTest(SHAUtilsTestCase("testStringifySHA")) - addToSuite = staticmethod(addToSuite) + def addToSuite(suite): + suite.addTest(SHAUtilsTestCase("testSHA")) + suite.addTest(SHAUtilsTestCase("testStringifySHA")) + addToSuite = staticmethod(addToSuite) def unit_test(): - suite = unittest.TestSuite() - SegmentBaseInitTestCase.addToSuite(suite) - DataSegmentTestCase.addToSuite(suite) - RetransmitSegmentTestCase.addToSuite(suite) - SHAUtilsTestCase.addToSuite(suite) + suite = unittest.TestSuite() + SegmentBaseInitTestCase.addToSuite(suite) + DataSegmentTestCase.addToSuite(suite) + RetransmitSegmentTestCase.addToSuite(suite) + SHAUtilsTestCase.addToSuite(suite) - runner = unittest.TextTestRunner() - runner.run(suite) + runner = unittest.TextTestRunner() + runner.run(suite) def got_data(addr, data, user_data=None): - print "Got data from %s, writing to %s." % (addr, user_data) - fl = open(user_data, "w+") - fl.write(data) - fl.close() + print "Got data from %s, writing to %s." % (addr, user_data) + fl = open(user_data, "w+") + fl.write(data) + fl.close() def simple_test(): - import sys - pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data, sys.argv[2]) -# pipe.set_drop_probability(4) - pipe.start() - fl = open(sys.argv[1], "r") - data = fl.read() - fl.close() - msg = """The said Eliza, John, and Georgiana were now clustered round their mama in the drawing-room: + import sys + pipe = MostlyReliablePipe('', '224.0.0.222', 2293, got_data, sys.argv[2]) +# pipe.set_drop_probability(4) + pipe.start() + fl = open(sys.argv[1], "r") + data = fl.read() + fl.close() + msg = """The said Eliza, John, and Georgiana were now clustered round their mama in the drawing-room: she lay reclined on a sofa by the fireside, and with her darlings about her (for the time neither quarrelling nor crying) looked perfectly happy. Me, she had dispensed from joining the group; saying, 'She regretted to be under the necessity of keeping me at a distance; but that until she heard from @@ -1337,58 +1337,58 @@ Bessie, and could discover by her own observation, that I was endeavouring in go a more sociable and childlike disposition, a more attractive and sprightly manner -- something lighter, franker, more natural, as it were -- she really must exclude me from privileges intended only for contented, happy, little children.'""" - pipe.send(data) - try: - gtk.main() - except KeyboardInterrupt: - print 'Ctrl+C pressed, exiting...' + pipe.send(data) + try: + gtk.main() + except KeyboardInterrupt: + print 'Ctrl+C pressed, exiting...' def net_test_got_data(addr, data, user_data=None): - # Don't report data if we are a sender - if user_data: - return - print "%s (%s)" % (data, addr) + # Don't report data if we are a sender + if user_data: + return + print "%s (%s)" % (data, addr) idstamp = 0 def transmit_data(pipe): - global idstamp - msg = "Message #%d" % idstamp - print "Sending '%s'" % msg - pipe.send(msg) - idstamp = idstamp + 1 - return True + global idstamp + msg = "Message #%d" % idstamp + print "Sending '%s'" % msg + pipe.send(msg) + idstamp = idstamp + 1 + return True def network_test(): - import sys, os - send = False - if len(sys.argv) != 2: - print "Need one arg, either 'send' or 'recv'" - os._exit(1) - if sys.argv[1] == "send": - send = True - elif sys.argv[1] == "recv": - send = False - else: - print "Arg should be either 'send' or 'recv'" - os._exit(1) - - pipe = MostlyReliablePipe('', '224.0.0.222', 2293, net_test_got_data, send) - pipe.start() - if send: - gobject.timeout_add(1000, transmit_data, pipe) - try: - gtk.main() - except KeyboardInterrupt: - print 'Ctrl+C pressed, exiting...' + import sys, os + send = False + if len(sys.argv) != 2: + print "Need one arg, either 'send' or 'recv'" + os._exit(1) + if sys.argv[1] == "send": + send = True + elif sys.argv[1] == "recv": + send = False + else: + print "Arg should be either 'send' or 'recv'" + os._exit(1) + + pipe = MostlyReliablePipe('', '224.0.0.222', 2293, net_test_got_data, send) + pipe.start() + if send: + gobject.timeout_add(1000, transmit_data, pipe) + try: + gtk.main() + except KeyboardInterrupt: + print 'Ctrl+C pressed, exiting...' def main(): -# unit_test() -# simple_test() - network_test() +# unit_test() +# simple_test() + network_test() if __name__ == "__main__": - main() + main() diff --git a/sugar/p2p/NotificationListener.py b/sugar/p2p/NotificationListener.py index f68bbb2..42668ad 100644 --- a/sugar/p2p/NotificationListener.py +++ b/sugar/p2p/NotificationListener.py @@ -21,18 +21,18 @@ from sugar.p2p.Notifier import Notifier from sugar.p2p import network class NotificationListener: - def __init__(self, service): - logging.debug('Start notification listener. Service %s, address %s, port %s' % (service.get_type(), service.get_address(), service.get_port())) - server = network.GroupServer(service.get_address(), - service.get_port(), - self._recv_multicast) - server.start() - - self._listeners = [] - - def add_listener(self, listener): - self._listeners.append(listener) - - def _recv_multicast(self, msg): - for listener in self._listeners: - listener(msg) + def __init__(self, service): + logging.debug('Start notification listener. Service %s, address %s, port %s' % (service.get_type(), service.get_address(), service.get_port())) + server = network.GroupServer(service.get_address(), + service.get_port(), + self._recv_multicast) + server.start() + + self._listeners = [] + + def add_listener(self, listener): + self._listeners.append(listener) + + def _recv_multicast(self, msg): + for listener in self._listeners: + listener(msg) diff --git a/sugar/p2p/Notifier.py b/sugar/p2p/Notifier.py index f216fda..69d0af6 100644 --- a/sugar/p2p/Notifier.py +++ b/sugar/p2p/Notifier.py @@ -18,10 +18,10 @@ from sugar.p2p import network class Notifier: - def __init__(self, service): - address = service.get_address() - port = service.get_port() - self._client = network.GroupClient(address, port) - - def notify(self, msg): - self._client.send_msg(msg) + def __init__(self, service): + address = service.get_address() + port = service.get_port() + self._client = network.GroupClient(address, port) + + def notify(self, msg): + self._client.send_msg(msg) diff --git a/sugar/p2p/Stream.py b/sugar/p2p/Stream.py index edb4d1b..b3239b3 100644 --- a/sugar/p2p/Stream.py +++ b/sugar/p2p/Stream.py @@ -26,135 +26,135 @@ from MostlyReliablePipe import MostlyReliablePipe from sugar.presence import Service def is_multicast_address(address): - """Simple numerical check for whether an IP4 address - is in the range for multicast addresses or not.""" - if not address: - return False - if address[3] != '.': - return False - first = int(float(address[:3])) - if first >= 224 and first <= 239: - return True - return False + """Simple numerical check for whether an IP4 address + is in the range for multicast addresses or not.""" + if not address: + return False + if address[3] != '.': + return False + first = int(float(address[:3])) + if first >= 224 and first <= 239: + return True + return False class Stream(object): - def __init__(self, service): - if not service.get_port(): - raise ValueError("service must have an address.") - self._service = service - self._reader_port = self._service.get_port() - self._writer_port = self._reader_port - self._address = self._service.get_address() - self._callback = None - - def new_from_service(service, start_reader=True): - if is_multicast_address(service.get_address()): - return MulticastStream(service) - else: - return UnicastStream(service, start_reader) - new_from_service = staticmethod(new_from_service) - - def set_data_listener(self, callback): - self._callback = callback - - def _recv(self, address, data): - if self._callback: - self._callback(address, data) + def __init__(self, service): + if not service.get_port(): + raise ValueError("service must have an address.") + self._service = service + self._reader_port = self._service.get_port() + self._writer_port = self._reader_port + self._address = self._service.get_address() + self._callback = None + + def new_from_service(service, start_reader=True): + if is_multicast_address(service.get_address()): + return MulticastStream(service) + else: + return UnicastStream(service, start_reader) + new_from_service = staticmethod(new_from_service) + + def set_data_listener(self, callback): + self._callback = callback + + def _recv(self, address, data): + if self._callback: + self._callback(address, data) class UnicastStreamWriter(object): - def __init__(self, stream, service): - # set up the writer - self._service = service - if not service.get_address(): - raise ValueError("service must have a valid address.") - self._address = self._service.get_address() - self._port = self._service.get_port() - self._xmlrpc_addr = "http://%s:%d" % (self._address, self._port) - self._writer = network.GlibServerProxy(self._xmlrpc_addr) - - def write(self, xmlrpc_data): - """Write some data to the default endpoint of this pipe on the remote server.""" - try: - self._writer.message(None, None, xmlrpc_data) - return True - except (socket.error, xmlrpclib.Fault, xmlrpclib.ProtocolError): - traceback.print_exc() - return False - - def custom_request(self, method_name, request_cb, user_data, *args): - """Call a custom XML-RPC method on the remote server.""" - try: - method = getattr(self._writer, method_name) - method(request_cb, user_data, *args) - return True - except (socket.error, xmlrpclib.Fault, xmlrpclib.ProtocolError): - traceback.print_exc() - return False + def __init__(self, stream, service): + # set up the writer + self._service = service + if not service.get_address(): + raise ValueError("service must have a valid address.") + self._address = self._service.get_address() + self._port = self._service.get_port() + self._xmlrpc_addr = "http://%s:%d" % (self._address, self._port) + self._writer = network.GlibServerProxy(self._xmlrpc_addr) + + def write(self, xmlrpc_data): + """Write some data to the default endpoint of this pipe on the remote server.""" + try: + self._writer.message(None, None, xmlrpc_data) + return True + except (socket.error, xmlrpclib.Fault, xmlrpclib.ProtocolError): + traceback.print_exc() + return False + + def custom_request(self, method_name, request_cb, user_data, *args): + """Call a custom XML-RPC method on the remote server.""" + try: + method = getattr(self._writer, method_name) + method(request_cb, user_data, *args) + return True + except (socket.error, xmlrpclib.Fault, xmlrpclib.ProtocolError): + traceback.print_exc() + return False class UnicastStream(Stream): - def __init__(self, service, start_reader=True): - """Initializes the stream. If the 'start_reader' argument is True, - the stream will initialize and start a new stream reader, if it - is False, no reader will be created and the caller must call the - start_reader() method to start the stream reader and be able to - receive any data from the stream.""" - Stream.__init__(self, service) - if start_reader: - self.start_reader() - - def start_reader(self): - """Start the stream's reader, which for UnicastStream objects is - and XMLRPC server. If there's a port conflict with some other - service, the reader will try to find another port to use instead. - Returns the port number used for the reader.""" - # Set up the reader - self._reader = network.GlibXMLRPCServer(("", self._reader_port)) - self._reader.register_function(self._message, "message") - - def _message(self, message): - """Called by the XMLRPC server when network data arrives.""" - address = network.get_authinfo() - self._recv(address, message) - return True - - def register_reader_handler(self, handler, name): - """Register a custom message handler with the reader. This call - adds a custom XMLRPC method call with the name 'name' to the reader's - XMLRPC server, which then calls the 'handler' argument back when - a method call for it arrives over the network.""" - if name == "message": - raise ValueError("Handler name 'message' is a reserved handler.") - self._reader.register_function(handler, name) - - def new_writer(self, service): - """Return a new stream writer object.""" - return UnicastStreamWriter(self, service) + def __init__(self, service, start_reader=True): + """Initializes the stream. If the 'start_reader' argument is True, + the stream will initialize and start a new stream reader, if it + is False, no reader will be created and the caller must call the + start_reader() method to start the stream reader and be able to + receive any data from the stream.""" + Stream.__init__(self, service) + if start_reader: + self.start_reader() + + def start_reader(self): + """Start the stream's reader, which for UnicastStream objects is + and XMLRPC server. If there's a port conflict with some other + service, the reader will try to find another port to use instead. + Returns the port number used for the reader.""" + # Set up the reader + self._reader = network.GlibXMLRPCServer(("", self._reader_port)) + self._reader.register_function(self._message, "message") + + def _message(self, message): + """Called by the XMLRPC server when network data arrives.""" + address = network.get_authinfo() + self._recv(address, message) + return True + + def register_reader_handler(self, handler, name): + """Register a custom message handler with the reader. This call + adds a custom XMLRPC method call with the name 'name' to the reader's + XMLRPC server, which then calls the 'handler' argument back when + a method call for it arrives over the network.""" + if name == "message": + raise ValueError("Handler name 'message' is a reserved handler.") + self._reader.register_function(handler, name) + + def new_writer(self, service): + """Return a new stream writer object.""" + return UnicastStreamWriter(self, service) class MulticastStream(Stream): - def __init__(self, service): - Stream.__init__(self, service) - self._service = service - self._internal_start_reader() - - def start_reader(self): - return self._reader_port - - def _internal_start_reader(self): - logging.debug('Start multicast stream, address %s, port %d' % (self._address, self._reader_port)) - if not self._service.get_address(): - raise ValueError("service must have a valid address.") - self._pipe = MostlyReliablePipe('', self._address, self._reader_port, - self._recv_data_cb) - self._pipe.start() - - def write(self, data): - self._pipe.send(data) - - def _recv_data_cb(self, address, data, user_data=None): - self._recv(address[0], data) - - def new_writer(self, service=None): - return self + def __init__(self, service): + Stream.__init__(self, service) + self._service = service + self._internal_start_reader() + + def start_reader(self): + return self._reader_port + + def _internal_start_reader(self): + logging.debug('Start multicast stream, address %s, port %d' % (self._address, self._reader_port)) + if not self._service.get_address(): + raise ValueError("service must have a valid address.") + self._pipe = MostlyReliablePipe('', self._address, self._reader_port, + self._recv_data_cb) + self._pipe.start() + + def write(self, data): + self._pipe.send(data) + + def _recv_data_cb(self, address, data, user_data=None): + self._recv(address[0], data) + + def new_writer(self, service=None): + return self diff --git a/sugar/p2p/network.py b/sugar/p2p/network.py index 6718669..e5b4e4b 100644 --- a/sugar/p2p/network.py +++ b/sugar/p2p/network.py @@ -35,347 +35,347 @@ RESULT_SUCCESS = 1 __authinfos = {} def _add_authinfo(authinfo): - __authinfos[threading.currentThread()] = authinfo + __authinfos[threading.currentThread()] = authinfo def get_authinfo(): - return __authinfos.get(threading.currentThread()) + return __authinfos.get(threading.currentThread()) def _del_authinfo(): - del __authinfos[threading.currentThread()] + del __authinfos[threading.currentThread()] class GlibTCPServer(SocketServer.TCPServer): - """GlibTCPServer + """GlibTCPServer - Integrate socket accept into glib mainloop. - """ + Integrate socket accept into glib mainloop. + """ - allow_reuse_address = True - request_queue_size = 20 + allow_reuse_address = True + request_queue_size = 20 - def __init__(self, server_address, RequestHandlerClass): - SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) - self.socket.setblocking(0) # Set nonblocking + def __init__(self, server_address, RequestHandlerClass): + SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) + self.socket.setblocking(0) # Set nonblocking - # Watch the listener socket for data - gobject.io_add_watch(self.socket, gobject.IO_IN, self._handle_accept) + # Watch the listener socket for data + gobject.io_add_watch(self.socket, gobject.IO_IN, self._handle_accept) - def _handle_accept(self, source, condition): - """Process incoming data on the server's socket by doing an accept() - via handle_request().""" - if not (condition & gobject.IO_IN): - return True - self.handle_request() - return True + def _handle_accept(self, source, condition): + """Process incoming data on the server's socket by doing an accept() + via handle_request().""" + if not (condition & gobject.IO_IN): + return True + self.handle_request() + return True class GlibXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): - """ GlibXMLRPCRequestHandler - - The stock SimpleXMLRPCRequestHandler and server don't allow any way to pass - the client's address and/or SSL certificate into the function that actually - _processes_ the request. So we have to store it in a thread-indexed dict. - """ - - def do_POST(self): - _add_authinfo(self.client_address) - try: - SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.do_POST(self) - except socket.timeout: - pass - except socket.error, e: - print "Error (%s): socket error - '%s'" % (self.client_address, e) - except: - print "Error while processing POST:" - traceback.print_exc() - _del_authinfo() + """ GlibXMLRPCRequestHandler + + The stock SimpleXMLRPCRequestHandler and server don't allow any way to pass + the client's address and/or SSL certificate into the function that actually + _processes_ the request. So we have to store it in a thread-indexed dict. + """ + + def do_POST(self): + _add_authinfo(self.client_address) + try: + SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.do_POST(self) + except socket.timeout: + pass + except socket.error, e: + print "Error (%s): socket error - '%s'" % (self.client_address, e) + except: + print "Error while processing POST:" + traceback.print_exc() + _del_authinfo() class GlibXMLRPCServer(GlibTCPServer, SimpleXMLRPCServer.SimpleXMLRPCDispatcher): - """GlibXMLRPCServer - - Use nonblocking sockets and handle the accept via glib rather than - blocking on accept(). - """ - - def __init__(self, addr, requestHandler=GlibXMLRPCRequestHandler, logRequests=0): - self.logRequests = logRequests - SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self) - GlibTCPServer.__init__(self, addr, requestHandler) - - def _marshaled_dispatch(self, data, dispatch_method = None): - """Dispatches an XML-RPC method from marshalled (XML) data. - - XML-RPC methods are dispatched from the marshalled (XML) data - using the _dispatch method and the result is returned as - marshalled data. For backwards compatibility, a dispatch - function can be provided as an argument (see comment in - SimpleXMLRPCRequestHandler.do_POST) but overriding the - existing method through subclassing is the prefered means - of changing method dispatch behavior. - """ - - params, method = xmlrpclib.loads(data) - - # generate response - try: - if dispatch_method is not None: - response = dispatch_method(method, params) - else: - response = self._dispatch(method, params) - # wrap response in a singleton tuple - response = (response,) - response = xmlrpclib.dumps(response, methodresponse=1) - except xmlrpclib.Fault, fault: - response = xmlrpclib.dumps(fault) - except: - print "Exception while processing request:" - traceback.print_exc() - - # report exception back to server - response = xmlrpclib.dumps( - xmlrpclib.Fault(1, "%s:%s" % (sys.exc_type, sys.exc_value)) - ) - - return response + """GlibXMLRPCServer + + Use nonblocking sockets and handle the accept via glib rather than + blocking on accept(). + """ + + def __init__(self, addr, requestHandler=GlibXMLRPCRequestHandler, logRequests=0): + self.logRequests = logRequests + SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self) + GlibTCPServer.__init__(self, addr, requestHandler) + + def _marshaled_dispatch(self, data, dispatch_method = None): + """Dispatches an XML-RPC method from marshalled (XML) data. + + XML-RPC methods are dispatched from the marshalled (XML) data + using the _dispatch method and the result is returned as + marshalled data. For backwards compatibility, a dispatch + function can be provided as an argument (see comment in + SimpleXMLRPCRequestHandler.do_POST) but overriding the + existing method through subclassing is the prefered means + of changing method dispatch behavior. + """ + + params, method = xmlrpclib.loads(data) + + # generate response + try: + if dispatch_method is not None: + response = dispatch_method(method, params) + else: + response = self._dispatch(method, params) + # wrap response in a singleton tuple + response = (response,) + response = xmlrpclib.dumps(response, methodresponse=1) + except xmlrpclib.Fault, fault: + response = xmlrpclib.dumps(fault) + except: + print "Exception while processing request:" + traceback.print_exc() + + # report exception back to server + response = xmlrpclib.dumps( + xmlrpclib.Fault(1, "%s:%s" % (sys.exc_type, sys.exc_value)) + ) + + return response class GlibHTTP(httplib.HTTP): - """Subclass HTTP so we can return it's connection class' socket.""" - def connect(self, host=None, port=None): - httplib.HTTP.connect(self, host, port) - self._conn.sock.setblocking(0) - def get_sock(self): - return self._conn.sock + """Subclass HTTP so we can return it's connection class' socket.""" + def connect(self, host=None, port=None): + httplib.HTTP.connect(self, host, port) + self._conn.sock.setblocking(0) + def get_sock(self): + return self._conn.sock class GlibXMLRPCTransport(xmlrpclib.Transport): - """Integrate the request with the glib mainloop rather than blocking.""" - ## - # Connect to server. - # - # @param host Target host. - # @return A connection handle. - - def __init__(self): - pass - - def make_connection(self, host): - """Use our own connection object so we can get its socket.""" - # create a HTTP connection object from a host descriptor - host, extra_headers, x509 = self.get_host_info(host) - return GlibHTTP(host) - - ## - # Send a complete request, and parse the response. - # - # @param host Target host. - # @param handler Target PRC handler. - # @param request_body XML-RPC request body. - # @param verbose Debugging flag. - # @return Parsed response. - - def start_request(self, host, handler, request_body, verbose=0, request_cb=None, user_data=None): - """Do the first half of the request by sending data to the remote - server. The bottom half bits get run when the remote server's response - actually comes back.""" - # issue XML-RPC request - - h = self.make_connection(host) - if verbose: - h.set_debuglevel(1) - - self.send_request(h, handler, request_body) - self.send_host(h, host) - self.send_user_agent(h) - self.send_content(h, request_body) - - # Schedule a GIOWatch so we don't block waiting for the response - gobject.io_add_watch(h.get_sock(), gobject.IO_IN, self._finish_request, - h, host, handler, verbose, request_cb, user_data) - - def _finish_request(self, source, condition, h, host, handler, verbose, request_cb, user_data): - """Parse and return response when the remote server actually returns it.""" - if not (condition & gobject.IO_IN): - return True - - try: - errcode, errmsg, headers = h.getreply() - except socket.error, err: - if err[0] != 104: - raise socket.error(err) - else: - gobject.idle_add(request_cb, RESULT_FAILED, None, user_data) - return False - - if errcode != 200: - raise xmlrpclib.ProtocolError(host + handler, errcode, errmsg, headers) - self.verbose = verbose - response = self._parse_response(h.getfile(), h.get_sock()) - if request_cb: - if len(response) == 1: - response = response[0] - gobject.idle_add(request_cb, RESULT_SUCCESS, response, user_data) - return False + """Integrate the request with the glib mainloop rather than blocking.""" + ## + # Connect to server. + # + # @param host Target host. + # @return A connection handle. + + def __init__(self): + pass + + def make_connection(self, host): + """Use our own connection object so we can get its socket.""" + # create a HTTP connection object from a host descriptor + host, extra_headers, x509 = self.get_host_info(host) + return GlibHTTP(host) + + ## + # Send a complete request, and parse the response. + # + # @param host Target host. + # @param handler Target PRC handler. + # @param request_body XML-RPC request body. + # @param verbose Debugging flag. + # @return Parsed response. + + def start_request(self, host, handler, request_body, verbose=0, request_cb=None, user_data=None): + """Do the first half of the request by sending data to the remote + server. The bottom half bits get run when the remote server's response + actually comes back.""" + # issue XML-RPC request + + h = self.make_connection(host) + if verbose: + h.set_debuglevel(1) + + self.send_request(h, handler, request_body) + self.send_host(h, host) + self.send_user_agent(h) + self.send_content(h, request_body) + + # Schedule a GIOWatch so we don't block waiting for the response + gobject.io_add_watch(h.get_sock(), gobject.IO_IN, self._finish_request, + h, host, handler, verbose, request_cb, user_data) + + def _finish_request(self, source, condition, h, host, handler, verbose, request_cb, user_data): + """Parse and return response when the remote server actually returns it.""" + if not (condition & gobject.IO_IN): + return True + + try: + errcode, errmsg, headers = h.getreply() + except socket.error, err: + if err[0] != 104: + raise socket.error(err) + else: + gobject.idle_add(request_cb, RESULT_FAILED, None, user_data) + return False + + if errcode != 200: + raise xmlrpclib.ProtocolError(host + handler, errcode, errmsg, headers) + self.verbose = verbose + response = self._parse_response(h.getfile(), h.get_sock()) + if request_cb: + if len(response) == 1: + response = response[0] + gobject.idle_add(request_cb, RESULT_SUCCESS, response, user_data) + return False class _Method: - """Right, so python people thought it would be funny to make this - class private to xmlrpclib.py...""" - # some magic to bind an XML-RPC method to an RPC server. - # supports "nested" methods (e.g. examples.getStateName) - def __init__(self, send, name): - self.__send = send - self.__name = name - def __getattr__(self, name): - return _Method(self.__send, "%s.%s" % (self.__name, name)) - def __call__(self, request_cb, user_data, *args): - return self.__send(self.__name, request_cb, user_data, args) + """Right, so python people thought it would be funny to make this + class private to xmlrpclib.py...""" + # some magic to bind an XML-RPC method to an RPC server. + # supports "nested" methods (e.g. examples.getStateName) + def __init__(self, send, name): + self.__send = send + self.__name = name + def __getattr__(self, name): + return _Method(self.__send, "%s.%s" % (self.__name, name)) + def __call__(self, request_cb, user_data, *args): + return self.__send(self.__name, request_cb, user_data, args) class GlibServerProxy(xmlrpclib.ServerProxy): - """Subclass xmlrpclib.ServerProxy so we can run the XML-RPC request - in two parts, integrated with the glib mainloop, such that we don't - block anywhere. - - Using this object is somewhat special; it requires more arguments to each - XML-RPC request call than the normal xmlrpclib.ServerProxy object: - - client = GlibServerProxy("http://127.0.0.1:8888") - user_data = "bar" - xmlrpc_arg1 = "test" - xmlrpc_arg2 = "foo" - client.test(xmlrpc_test_cb, user_data, xmlrpc_arg1, xmlrpc_arg2) - - Here, 'xmlrpc_test_cb' is the callback function, which has the following - signature: - - def xmlrpc_test_cb(result_status, response, user_data=None): - ... - """ - def __init__(self, uri, encoding=None, verbose=0, allow_none=0): - self._transport = GlibXMLRPCTransport() - self._encoding = encoding - self._verbose = verbose - self._allow_none = allow_none - xmlrpclib.ServerProxy.__init__(self, uri, self._transport, encoding, verbose, allow_none) - - # get the url - import urllib - urltype, uri = urllib.splittype(uri) - if urltype not in ("http", "https"): - raise IOError, "unsupported XML-RPC protocol" - self._host, self._handler = urllib.splithost(uri) - if not self._handler: - self._handler = "/RPC2" - - def __request(self, methodname, request_cb, user_data, params): - """Call the method on the remote server. We just start the request here - and the transport itself takes care of scheduling the response callback - when the remote server returns the response. We don't want to block anywhere.""" - - request = xmlrpclib.dumps(params, methodname, encoding=self._encoding, - allow_none=self._allow_none) - - try: - response = self._transport.start_request( - self._host, - self._handler, - request, - verbose=self._verbose, - request_cb=request_cb, - user_data=user_data - ) - except socket.error, exc: - gobject.idle_add(request_cb, RESULT_FAILED, None, user_data) - - def __getattr__(self, name): - # magic method dispatcher - return _Method(self.__request, name) + """Subclass xmlrpclib.ServerProxy so we can run the XML-RPC request + in two parts, integrated with the glib mainloop, such that we don't + block anywhere. + + Using this object is somewhat special; it requires more arguments to each + XML-RPC request call than the normal xmlrpclib.ServerProxy object: + + client = GlibServerProxy("http://127.0.0.1:8888") + user_data = "bar" + xmlrpc_arg1 = "test" + xmlrpc_arg2 = "foo" + client.test(xmlrpc_test_cb, user_data, xmlrpc_arg1, xmlrpc_arg2) + + Here, 'xmlrpc_test_cb' is the callback function, which has the following + signature: + + def xmlrpc_test_cb(result_status, response, user_data=None): + ... + """ + def __init__(self, uri, encoding=None, verbose=0, allow_none=0): + self._transport = GlibXMLRPCTransport() + self._encoding = encoding + self._verbose = verbose + self._allow_none = allow_none + xmlrpclib.ServerProxy.__init__(self, uri, self._transport, encoding, verbose, allow_none) + + # get the url + import urllib + urltype, uri = urllib.splittype(uri) + if urltype not in ("http", "https"): + raise IOError, "unsupported XML-RPC protocol" + self._host, self._handler = urllib.splithost(uri) + if not self._handler: + self._handler = "/RPC2" + + def __request(self, methodname, request_cb, user_data, params): + """Call the method on the remote server. We just start the request here + and the transport itself takes care of scheduling the response callback + when the remote server returns the response. We don't want to block anywhere.""" + + request = xmlrpclib.dumps(params, methodname, encoding=self._encoding, + allow_none=self._allow_none) + + try: + response = self._transport.start_request( + self._host, + self._handler, + request, + verbose=self._verbose, + request_cb=request_cb, + user_data=user_data + ) + except socket.error, exc: + gobject.idle_add(request_cb, RESULT_FAILED, None, user_data) + + def __getattr__(self, name): + # magic method dispatcher + return _Method(self.__request, name) class GroupServer(object): - _MAX_MSG_SIZE = 500 + _MAX_MSG_SIZE = 500 - def __init__(self, address, port, data_cb): - self._address = address - self._port = port - self._data_cb = data_cb + def __init__(self, address, port, data_cb): + self._address = address + self._port = port + self._data_cb = data_cb - self._setup_listener() + self._setup_listener() - def _setup_listener(self): - # Listener socket - self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + def _setup_listener(self): + # Listener socket + self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # Set some options to make it multicast-friendly - self._listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 20) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) + # Set some options to make it multicast-friendly + self._listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 20) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1) - def start(self): - # Set some more multicast options - self._listen_sock.bind(('', self._port)) - self._listen_sock.settimeout(2) - intf = socket.gethostbyname(socket.gethostname()) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(intf) + socket.inet_aton('0.0.0.0')) - self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self._address) + socket.inet_aton('0.0.0.0')) + def start(self): + # Set some more multicast options + self._listen_sock.bind(('', self._port)) + self._listen_sock.settimeout(2) + intf = socket.gethostbyname(socket.gethostname()) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(intf) + socket.inet_aton('0.0.0.0')) + self._listen_sock.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self._address) + socket.inet_aton('0.0.0.0')) - # Watch the listener socket for data - gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) + # Watch the listener socket for data + gobject.io_add_watch(self._listen_sock, gobject.IO_IN, self._handle_incoming_data) - def _handle_incoming_data(self, source, condition): - if not (condition & gobject.IO_IN): - return True - msg = {} - msg['data'], (msg['addr'], msg['port']) = source.recvfrom(self._MAX_MSG_SIZE) - if self._data_cb: - self._data_cb(msg) - return True + def _handle_incoming_data(self, source, condition): + if not (condition & gobject.IO_IN): + return True + msg = {} + msg['data'], (msg['addr'], msg['port']) = source.recvfrom(self._MAX_MSG_SIZE) + if self._data_cb: + self._data_cb(msg) + return True class GroupClient(object): - _MAX_MSG_SIZE = 500 + _MAX_MSG_SIZE = 500 - def __init__(self, address, port): - self._address = address - self._port = port + def __init__(self, address, port): + self._address = address + self._port = port - self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # Make the socket multicast-aware, and set TTL. - self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit + self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # Make the socket multicast-aware, and set TTL. + self._send_sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20) # Change TTL (=20) to suit - def send_msg(self, data): - self._send_sock.sendto(data, (self._address, self._port)) + def send_msg(self, data): + self._send_sock.sendto(data, (self._address, self._port)) class Test(object): - def test(self, arg1): - print "Request got %s" % arg1 - return "success" + def test(self, arg1): + print "Request got %s" % arg1 + return "success" def xmlrpc_test_cb(response, user_data=None): - print "Response was %s, user_data was %s" % (response, user_data) - import gtk - gtk.main_quit() + print "Response was %s, user_data was %s" % (response, user_data) + import gtk + gtk.main_quit() def xmlrpc_test(): - client = GlibServerProxy("http://127.0.0.1:8888") - client.test(xmlrpc_test_cb, "bar", "test data") + client = GlibServerProxy("http://127.0.0.1:8888") + client.test(xmlrpc_test_cb, "bar", "test data") def main(): - import gtk - server = GlibXMLRPCServer(("", 8888)) - inst = Test() - server.register_instance(inst) - - gobject.idle_add(xmlrpc_test) - - try: - gtk.main() - except KeyboardInterrupt: - print 'Ctrl+C pressed, exiting...' - print "Done." + import gtk + server = GlibXMLRPCServer(("", 8888)) + inst = Test() + server.register_instance(inst) + + gobject.idle_add(xmlrpc_test) + + try: + gtk.main() + except KeyboardInterrupt: + print 'Ctrl+C pressed, exiting...' + print "Done." if __name__ == "__main__": - main() + main() |