# Copyright (c) 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.


import logging
import multiprocessing

from test_result import TestResults


def _ShardedTestRunnable(test):
  """Standalone function needed by multiprocessing.Pool.

  Runs one shard's test runner inside a pool worker process. It has to be a
  module-level function so the pool can pickle a reference to it.

  Args:
    test: A test runner exposing a string `device` attribute (used to prefix
        log lines) and a `Run()` method.

  Returns:
    Whatever `test.Run()` returns, or an empty TestResults if the run raised
    SystemExit.
  """
  log_format = '[' + test.device + '] # %(asctime)-15s: %(message)s'
  if logging.getLogger().handlers:
    logging.getLogger().handlers[0].setFormatter(logging.Formatter(log_format))
  else:
    logging.basicConfig(format=log_format)
  # Catch SystemExit so it does not exit the worker process (the original
  # author attributes this to a Python bug). Report an empty result instead.
  try:
    return test.Run()
  except SystemExit:
    return TestResults()


def SetTestsContainer(tests_container):
  """Sets tests container.

  Used as the multiprocessing.Pool initializer. A multiprocessing.Queue can't
  be pickled across processes, so it is installed as a per-process 'global'
  (a class attribute of BaseTestSharder) in each worker instead.

  Args:
    tests_container: The object to store in BaseTestSharder.tests_container.
  """
  BaseTestSharder.tests_container = tests_container


class BaseTestSharder(object):
  """Base class for sharding tests across multiple devices.

  Subclasses override CreateShardedTestRunner (and optionally SetupSharding
  and OnTestsCompleted). RunShardedTests then drives one runner per device in
  parallel and retries broken tests up to `retries` rounds.

  Args:
    attached_devices: A list of attached devices.
  """
  # Per-process shared container installed by SetTestsContainer in each pool
  # worker; see SetTestsContainer for why this is a class attribute.
  tests_container = None

  def __init__(self, attached_devices):
    self.attached_devices = attached_devices
    # Number of rounds RunShardedTests may run (the first try plus retries).
    self.retries = 1
    # Tests to run; after each non-final round it is replaced by the names of
    # the tests that were reported broken.
    self.tests = []

  def CreateShardedTestRunner(self, device, index):
    """Factory function to create a suite-specific test runner.

    Subclasses must override this. The base implementation returns None.

    Args:
      device: Device serial where this shard will run.
      index: Index of this device in the pool.

    Returns:
      An object of BaseTestRunner type (that can provide a "Run()" method).
    """
    pass

  def SetupSharding(self, tests):
    """Called before starting the shards of each round.

    Args:
      tests: The tests to run in this round (self.tests at that point).
    """
    pass

  def OnTestsCompleted(self, test_runners, test_results):
    """Notifies that we completed the tests.

    Args:
      test_runners: The runners created for the last round that ran (empty
          if no round ran).
      test_results: The final TestResults returned by RunShardedTests.
    """
    pass

  def _RunShardsInPool(self, test_runners):
    """Runs each runner in a worker pool and merges their results.

    The pool is always shut down before returning. On success it is closed
    and joined. On any exception (including KeyboardInterrupt) its workers
    are terminated, then the exception propagates.

    Args:
      test_runners: Runners to execute, one per attached device.

    Returns:
      A TestResults merged from every runner's result.
    """
    pool = multiprocessing.Pool(len(self.attached_devices),
                                SetTestsContainer,
                                [BaseTestSharder.tests_container])
    try:
      # map can't handle KeyboardInterrupt exception. It's a python bug.
      # So use map_async, waiting with a (very large) timeout instead.
      async_results = pool.map_async(_ShardedTestRunnable, test_runners)
      results_lists = async_results.get(999999)
    except BaseException:
      # Don't leave worker processes running behind a failed/interrupted run.
      pool.terminate()
      pool.join()
      raise
    # Release the workers; a new pool is created for every round.
    pool.close()
    pool.join()
    return TestResults.FromTestResults(results_lists)

  def RunShardedTests(self):
    """Runs the tests in all connected devices.

    Runs up to `self.retries` rounds. Each round calls SetupSharding with
    `self.tests`, creates one runner per attached device, and runs them in
    parallel. After a non-final round, `self.tests` is replaced by the names
    of the tests reported broken. The loop stops early when nothing is broken.

    Returns:
      A TestResults object. Its `ok` list accumulates passes from every round
      that ran. If the final round ran, the other fields come from that round.
      If the loop stopped early, only `ok` is populated.
    """
    logging.warning('*' * 80)
    logging.warning('Sharding in %s devices.', len(self.attached_devices))
    logging.warning('Note that the output is not synchronized.')
    logging.warning('Look for the "Final result" banner in the end.')
    logging.warning('*' * 80)
    final_results = TestResults()
    # Bound up front so OnTestsCompleted gets a defined value even when
    # self.retries is 0 and the loop body never runs.
    test_runners = []
    for retry in xrange(self.retries):
      logging.warning('Try %d of %d', retry + 1, self.retries)
      self.SetupSharding(self.tests)
      test_runners = []
      for index, device in enumerate(self.attached_devices):
        logging.warning('*' * 80)
        logging.warning('Creating shard %d for %s', index, device)
        logging.warning('*' * 80)
        test_runners.append(self.CreateShardedTestRunner(device, index))
      logging.warning('Starting...')
      test_results = self._RunShardsInPool(test_runners)
      if retry == self.retries - 1:
        # Last round: report its full results, but credit the passes from
        # all earlier rounds as well.
        all_passed = final_results.ok + test_results.ok
        final_results = test_results
        final_results.ok = all_passed
        break
      final_results.ok += test_results.ok
      # Retry only the tests that did not pass in this round.
      self.tests = [t.name for t in test_results.GetAllBroken()]
      if not self.tests:
        break
    self.OnTestsCompleted(test_runners, final_results)
    return final_results