#!/usr/bin/env python import os import sys import shutil import time import threading import traceback import gobject import gtk import dbus import dbus.glib # seems to automatically integrate dbus with GTK main loop PROGNAME = "soas-assimilator" try : import hashlib sha1 = hashlib.sha1 except ImportError : import sha sha1 = sha.new _checksum_buffer_size = 1024*1024 def checksum(f, max_size) : checksum = sha1() buf = f.read(min(_checksum_buffer_size, max_size)) while buf : checksum.update(buf) max_size -= len(buf) if not max_size : break buf = f.read(min(_checksum_buffer_size, max_size)) return checksum.hexdigest() _sizeUnits = " KMGTP" def formatSize(size) : if not size : return "0" unitIdx = 0 while ((size % 1024) == 0) : size //= 1024 unitIdx += 1 return "%d%sB" % (size, _sizeUnits[unitIdx]) class tColoredListWindow (gtk.Window) : def __init__(self, title) : gtk.Window.__init__(self) self.set_title(title) self.initWidgets() def initWidgets(self) : self.vbox = gtk.VBox(False, 10) self.add(self.vbox) self._scrolledList = gtk.ScrolledWindow() self._scrolledList.set_policy(gtk.POLICY_AUTOMATIC, gtk.POLICY_AUTOMATIC) self.vbox.pack_start(self._scrolledList) self.listModel = gtk.ListStore(*[_type for (name, _type, r) in self.listColumns]) self._listView = gtk.TreeView(self.listModel) self._listView.get_selection().set_mode(gtk.SELECTION_NONE) self._scrolledList.add(self._listView) self._addListColumns() self.maximize() self.show_all() def _addListColumns(self) : for (idx, (name, _type, renderer)) in enumerate(self.listColumns) : if not renderer : continue r = renderer() r.set_property('background-set', True) if (_type == bool) : col = gtk.TreeViewColumn(name, r, active=idx, background=self.colorColIdx) else : col = gtk.TreeViewColumn(name, r, text=idx, background=self.colorColIdx) col.set_resizable(True) col.set_sort_column_id(idx) self._listView.append_column(col) def findRow(self, col, content) : for idx in range(len(self.listModel)) : if (self.listModel[idx][col] == content) : return idx class tStatusListWindow (tColoredListWindow) : listColumns = [ ('device', str, gtk.CellRendererText), ('status', str, gtk.CellRendererText), ('image', str, gtk.CellRendererText), ('model', str, gtk.CellRendererText), ('info', str, gtk.CellRendererText), ('color', str, None), ] _deviceColIdx = 0 colorColIdx = len(listColumns)-1 _colorMap = { 'waiting': '#D0D0D0', 'reading': '#FFFF00', 'writing': '#FFFF00', 'verifying': '#6060FF', 'success': '#00FF00', 'error': '#FF0000', } def __init__(self, title) : tColoredListWindow.__init__(self, "%s: %s" % (PROGNAME, title)) def setStatus(self, device, info) : row = self.findRow(device) if row is None : row = self.listModel.append() info = dict(info) info['color'] = self._colorMap[info['status']] self.listModel[row] = [info.get(name, '') for (name, t, r) in self.listColumns] def findRow(self, device) : return tColoredListWindow.findRow(self, self._deviceColIdx, device) class tReadWindow (tStatusListWindow) : def __init__(self) : tStatusListWindow.__init__(self, "Read USB sticks") def initWidgets(self) : tStatusListWindow.initWidgets(self) self._hButtonBox = gtk.HButtonBox() self.writeButton = gtk.Button("Continue to write mode") self.writeButton.modify_bg(gtk.STATE_NORMAL, gtk.gdk.color_parse("#F00000")) self.writeButton.modify_bg(gtk.STATE_PRELIGHT, gtk.gdk.color_parse("#FF0000")) self.writeButton.show() self._hButtonBox.pack_start(self.writeButton) self._hButtonBox.show() self.vbox.pack_end(self._hButtonBox, expand=False) class tWriteWindow (tStatusListWindow) : def __init__(self) : tStatusListWindow.__init__(self, "Write USB sticks") class ReadThread (threading.Thread) : _bufferSize = 1024*1024 def __init__(self, dev_path, model, dev_size, image_prefix, status_cb) : threading.Thread.__init__(self) self.dev_path = dev_path self.model = model self.dev_size = dev_size self.image_prefix = image_prefix self.image_path = image_prefix + ".img" self.checksum_path = image_prefix + ".sha1" self.status_cb = status_cb def run(self) : start_time = time.time() self.status_cb(device=self.dev_path, status='reading', image=self.image_path, image_size=self.dev_size, model=self.model, device_size=self.dev_size) try : imageF, devF = file(self.image_path, "wb"), file(self.dev_path, "rb") try : # copy checksum = self._copy(imageF, devF) imageF.flush() finally : imageF.close(), devF.close() file(self.checksum_path, "w").write("%s\n" % (checksum,)) except : self.status_cb(device=self.dev_path, status='error', error=sys.exc_info()[1]) else : self.status_cb(device=self.dev_path, status='success', read_time=(time.time() - start_time)) def _copy(self, imageF, devF) : checksum = sha1() buf = devF.read(self._bufferSize) while buf : checksum.update(buf) imageF.write(buf) buf = devF.read(self._bufferSize) return checksum.hexdigest() class WriteThread (threading.Thread) : def __init__(self, image_path, dev_path, model, dev_size, image_size, checksum, status_cb) : threading.Thread.__init__(self) self.image_path = image_path self.dev_path = dev_path self.model = model self.image_size = image_size self.dev_size = dev_size self.checksum = checksum self.status_cb = status_cb def run(self) : start_time = time.time() self.status_cb(device=self.dev_path, status='writing', image=self.image_path, image_size=self.image_size, model=self.model, device_size=self.dev_size) try : imageF, devF = file(self.image_path, "rb"), file(self.dev_path, "wb+") try : # copy shutil.copyfileobj(imageF, devF) devF.flush() devF.seek(0) cur_time = time.time() self.status_cb(device=self.dev_path, status="verifying", write_time=(cur_time - start_time)) start_time = cur_time # verify dev_checksum = checksum(devF, self.image_size) if (dev_checksum != self.checksum) : return self.status_cb(device=self.dev_path, status='error', error="verification failed: %r (image) != %r (device)" % (self.checksum, dev_checksum)) finally : imageF.close(), devF.close() except : self.status_cb(device=self.dev_path, status='error', error=sys.exc_info()[1]) else : self.status_cb(device=self.dev_path, status='success', verify_time=(time.time() - start_time)) class WriteQThread (threading.Thread) : def __init__(self, q, cond) : threading.Thread.__init__(self) self._q = q self._cond = cond def run(self) : while True : self._cond.acquire() try : if self._q : args = self._q.pop(0) self._cond.release() thread = WriteThread(*args) thread.start() thread.join() self._cond.acquire() else : self._cond.wait() finally : self._cond.release() class tDeviceListener (object) : def __init__(self, image_dir, status_cb, use_parallel) : self.mode = 'read' self.image_dir = image_dir self.status_cb = status_cb self._images_dict = {} self._images_flat = [] self._images_lastupdate = 0 self._use_write_q = not use_parallel self._write_q = [] self._write_q_cond = threading.Condition() self.bus = dbus.Bus(dbus.Bus.TYPE_SYSTEM) self.bus.add_signal_receiver(self.hal_device_added, "DeviceAdded", "org.freedesktop.Hal.Manager", "org.freedesktop.Hal", "/org/freedesktop/Hal/Manager") self._start_write_q() def setMode(self, mode) : if mode not in ['read', 'write'] : raise ValueError("Unknown mode: %r" % (mode,)) self.mode = mode def hal_device_added(self, udi) : try : proxy = self.bus.get_object("org.freedesktop.Hal", udi) dev = dbus.Interface(proxy, "org.freedesktop.Hal.Device") props = dev.GetAllProperties() if (props.get("info.category") != "storage") \ or (props.get("storage.removable") != 1) \ or (props.get("storage.drive_type") != "disk") : # not a USB stick return # for debugging # for (k,v) in props.items() : # print "%30s %s" % (k,v) if self.mode == 'write' : self.assimilate(str(props['storage.model']), str(props['block.device']), int(props['storage.removable.media_size']), str(props.get('storage.serial', 'n_a'))) elif self.mode == 'read' : self.readImage(str(props['storage.model']), str(props['block.device']), int(props['storage.removable.media_size']), str(props['storage.serial'])) except : # don't break if we have trouble with a single device, but show the error traceback.print_exc() def assimilate(self, model, dev_path, dev_size, serial_nr) : model = "%s (serial# %s)" % (model, serial_nr) try : image_size, (image_path, image_checksum) = self._find_image(dev_size) except ValueError : if not self._images_flat : msg = "No images available." else : msg = "No matching image found (device size %d, smallest image size %d)" % (dev_size, self._images_flat[0][0]) return self.status_cb(dev_path, model=model, device_size=dev_size, status="error", error=msg) self._addToWriteQ(image_path, dev_path, model, dev_size, image_size, image_checksum) def _addToWriteQ(self, image_path, dev_path, model, dev_size, image_size, image_checksum) : if not self._use_write_q : thread = WriteThread(image_path, dev_path, model, dev_size, image_size, image_checksum, self.status_cb) thread.start() return self._write_q_cond.acquire() try : self.status_cb(device=dev_path, status='waiting', image=image_path, image_size=image_size, model=model, device_size=dev_size) self._write_q.append((image_path, dev_path, model, dev_size, image_size, image_checksum, self.status_cb)) self._write_q_cond.notify() finally : self._write_q_cond.release() def readImage(self, model, dev_path, dev_size, serial_nr) : image_prefix = "dump-%s-%s" % (serial_nr.replace('/', '_').replace(':', '_'), formatSize(dev_size)) model = "%s (serial# %s)" % (model, serial_nr) thread = ReadThread(dev_path, model, dev_size, os.path.join(self.image_dir, image_prefix), self.status_cb) thread.start() def _find_image(self, size) : """ Try to find image to put on current stick. Will return the largest one still fitting on the stick. """ self._cache_images() try : return [(isize, image) for (isize, image) in self._images_flat if (isize <= size)][-1] except IndexError : # no matching image found raise ValueError("No matching image found") _images_cachetime = 10 # rescan images every 10 seconds def _cache_images(self) : curTime = time.time() if (curTime - self._images_lastupdate) < self._images_cachetime : return image_names = [os.path.join(self.image_dir, fname) for fname in os.listdir(self.image_dir) if fname.endswith(".img")] for name in image_names : checksum_fname = name[:-3]+"sha1" if (name in self._images_dict) or (not os.path.exists(checksum_fname)) : continue self._images_dict[os.stat(name).st_size] = (name,file(checksum_fname).read().strip()) self._images_flat = sorted(self._images_dict.items()) self._images_lastupdate = curTime def _start_write_q(self) : if self._use_write_q : thread = WriteQThread(self._write_q, self._write_q_cond) thread.start() class tApp (object) : def __init__(self) : self._imageDir = self._chooseImageDir() self._devStatus = {} parallel = self._askParallel() self._deviceListener = tDeviceListener(self._imageDir, self._updateStatus, parallel) if self._askReadImage() : self._win = tReadWindow() self._win.writeButton.connect("clicked", self._showWriteWindow) self._destroyCbHandle = self._win.connect('destroy', lambda *w: self._close()) else : self._win = None self._showWriteWindow() def _showWriteWindow(self, *args) : if self._win : # close read window, but prevent closing the application self._win.handler_disconnect(self._destroyCbHandle) self._win.destroy() self._win = tWriteWindow() self._destroyCbHandle = self._win.connect('destroy', lambda *w: self._close()) self._deviceListener.setMode('write') def _updateStatus(self, device, status, **kwArgs) : devStatus = self._devStatus.setdefault(device, {'device': device}) devStatus.update(kwArgs, status=status) info = {'device': device, 'status': status, 'model': "%s (%s)" % (devStatus['model'], formatSize(devStatus['device_size'])), 'image': "%s (%s)" % (devStatus.get('image', ''), formatSize(devStatus.get('image_size', 0))), 'info': ''} if (status in ['verifying', 'success']) : if (self._deviceListener.mode == "read") : info['info'] = "reading took %ds, %.3fMB/s" % (devStatus['read_time'], float(devStatus['image_size'])/1024/1024/devStatus['read_time']) else : info['info'] = "writing took %ds, %.3fMB/s" % (devStatus['write_time'], float(devStatus['image_size'])/1024/1024/devStatus['write_time']) if (status == 'success') and (self._deviceListener.mode == "write") : info['info'] += ", verification took %ds, %.3fMB/s" % (devStatus['verify_time'], float(devStatus['image_size'])/1024/1024/devStatus['verify_time']) elif (status == 'error') : info['info'] = devStatus['error'] self._win.setStatus(device, info) def _chooseImageDir(self) : """ Query image directory from user. Will exit program in case user cancels. """ imageChooserDialog = gtk.FileChooserDialog(title="Choose directory containing image files", action=gtk.FILE_CHOOSER_ACTION_SELECT_FOLDER, buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_CANCEL, gtk.STOCK_OPEN, gtk.RESPONSE_OK)) imageChooserDialog.set_default_response(gtk.RESPONSE_OK) if imageChooserDialog.run() != gtk.RESPONSE_OK : sys.exit(10) dir = imageChooserDialog.get_filename() imageChooserDialog.destroy() return dir def _askReadImage(self) : """ Ask the user whether to create an image of a USB stick. """ askDialog = gtk.MessageDialog(type=gtk.MESSAGE_QUESTION, buttons=gtk.BUTTONS_NONE) askDialog.set_markup("\n".join([x.strip() for x in """ Once the write mode is active, it will overwrite ANY USB stick you plug in. You now (and only now) have the chance to read an existing USB stick and save an image of it to the directory you just chose. After choosing "Create image first", any stick you plug in will be read out automatically. What do you want to do? """.split("\n")])) continueButton = askDialog.add_button("Continue to write mode", gtk.RESPONSE_CANCEL) continueButton.modify_bg(gtk.STATE_NORMAL, gtk.gdk.color_parse("#F00000")) continueButton.modify_bg(gtk.STATE_PRELIGHT, gtk.gdk.color_parse("#FF0000")) askDialog.add_button("Create image first", gtk.RESPONSE_OK) askDialog.set_title("Read images first?") res = askDialog.run() askDialog.destroy() return (res == gtk.RESPONSE_OK) def _askParallel(self) : """ Ask the user whether to restrict writing to a single device at a time or write in parallel. Returns True for the latter, False otherwise. """ askDialog = gtk.MessageDialog(type=gtk.MESSAGE_QUESTION, buttons=gtk.BUTTONS_NONE, message_format="\n".join([x.strip() for x in """ If you are using a slow (i.e. 12Mbps / FullSpeed) hub or your system is having trouble with USB, it might make sense to write access a single device at a time. Usually parallel mode is much faster and thus recommended, though. """.split("\n")])) askDialog.add_buttons("Parallel mode", gtk.RESPONSE_OK, "Single device mode", gtk.RESPONSE_CANCEL) askDialog.set_title("Read images first?") res = askDialog.run() askDialog.destroy() return (res == gtk.RESPONSE_OK) def run(self) : gtk.main() def _close(self) : gtk.main_quit() def printSyntax(myName) : print "Syntax: %s" % (myName,) return 100 def main(myName, args) : if args : return printSyntax(myName) gobject.threads_init() tApp().run() sys.exit(main(sys.argv[0], sys.argv[1:]))