webkitpy: Notify parent process when a worker is spawned
[WebKit-https.git] / Tools / Scripts / webkitpy / common / message_pool.py
1 # Copyright (C) 2011 Google Inc. All rights reserved.
2 #
3 # Redistribution and use in source and binary forms, with or without
4 # modification, are permitted provided that the following conditions are
5 # met:
6 #
7 #     * Redistributions of source code must retain the above copyright
8 # notice, this list of conditions and the following disclaimer.
9 #     * Redistributions in binary form must reproduce the above
10 # copyright notice, this list of conditions and the following disclaimer
11 # in the documentation and/or other materials provided with the
12 # distribution.
13 #     * Neither the name of Google Inc. nor the names of its
14 # contributors may be used to endorse or promote products derived from
15 # this software without specific prior written permission.
16 #
17 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29 """Module for handling messages and concurrency for run-webkit-tests
30 and test-webkitpy. This module follows the design for multiprocessing.Pool
and concurrent.futures.ProcessPoolExecutor, with the following differences:
32
33 * Tasks are executed in stateful subprocesses via objects that implement the
34   Worker interface - this allows the workers to share state across tasks.
35 * The pool provides an asynchronous event-handling interface so the caller
36   may receive events as tasks are processed.
37
If you don't need these features, use multiprocessing.Pool or concurrent.futures
instead.
40
41 """
42
43 import cPickle
44 import logging
45 import multiprocessing
46 import Queue
47 import sys
48 import time
49 import traceback
50
51
52 from webkitpy.common.host import Host
53 from webkitpy.common.system import stack_utils
54
55
56 _log = logging.getLogger(__name__)
57
58
def get(caller, worker_factory, num_workers, worker_startup_delay_secs=0.0, host=None):
    """Returns an object that exposes a run() method that takes a list of test shards and runs them in parallel.

    Args:
        caller: object whose handle(name, source, *args) method receives
            events posted back by the workers.
        worker_factory: callable taking the worker process wrapper and
            returning an object implementing the Worker interface.
        num_workers: number of workers to use; 1 means run inline in this
            process without forking.
        worker_startup_delay_secs: optional pause between spawning
            successive workers.
        host: optional Host shared with the workers (only sent across the
            process boundary when it is picklable; see _MessagePool._start_workers).
    """
    return _MessagePool(caller, worker_factory, num_workers, worker_startup_delay_secs, host)
62
63
class _MessagePool(object):
    """Manages a pool of _Worker processes and routes messages between them
    and the caller.

    Messages to the workers go on _messages_to_worker; replies (user events,
    buffered log records, and internal control messages) come back on
    _messages_to_manager.  When num_workers == 1 the pool runs the single
    worker inline in this process and uses plain Queue.Queues, avoiding the
    overhead and pickling restrictions of multiprocessing.
    """

    def __init__(self, caller, worker_factory, num_workers, worker_startup_delay_secs=0.0, host=None):
        # caller receives events via caller.handle(name, source, *args).
        self._caller = caller
        # worker_factory(worker) builds the object implementing the Worker
        # interface (handle(), plus optional start()/stop()).
        self._worker_factory = worker_factory
        self._num_workers = num_workers
        # Optional delay between spawning successive workers (staggers startup).
        self._worker_startup_delay_secs = worker_startup_delay_secs
        self._workers = []
        # Names of workers that have posted their 'done' control message.
        self._workers_stopped = set()
        self._host = host
        self._name = 'manager'
        # With exactly one worker we skip multiprocessing entirely.
        self._running_inline = (self._num_workers == 1)
        if self._running_inline:
            self._messages_to_worker = Queue.Queue()
            self._messages_to_manager = Queue.Queue()
        else:
            self._messages_to_worker = multiprocessing.Queue()
            self._messages_to_manager = multiprocessing.Queue()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always tear the pool down; returning False re-raises any exception.
        self._close()
        return False

    def run(self, shards):
        """Posts a list of messages to the pool and waits for them to complete."""
        # Each shard is a (message_name, arg1, arg2, ...) tuple.
        for message in shards:
            self._messages_to_worker.put(_Message(self._name, message[0], message[1:], from_user=True, logs=()))

        # Queue one 'stop' sentinel per worker so every worker's loop exits
        # after the real work has been drained.
        for _ in xrange(self._num_workers):
            self._messages_to_worker.put(_Message(self._name, 'stop', message_args=(), from_user=False, logs=()))

        self.wait()

    def _start_workers(self):
        # Constructs all workers and starts them (start() is a no-op inline).
        assert not self._workers
        self._workers_stopped = set()
        host = None
        # Only hand the host across the process boundary when it survives
        # pickling; otherwise each worker builds its own Host (see _Worker.run).
        if self._running_inline or self._can_pickle(self._host):
            host = self._host

        for worker_number in xrange(self._num_workers):
            worker = _Worker(host, self._messages_to_manager, self._messages_to_worker, self._worker_factory, worker_number, self._running_inline, self if self._running_inline else None, self._worker_log_level())
            self._workers.append(worker)
            worker.start()
            if not self._running_inline:
                # Notify the caller that a real subprocess now exists.
                self._caller.handle('did_spawn_worker', worker_number)
            if self._worker_startup_delay_secs:
                time.sleep(self._worker_startup_delay_secs)

    def _worker_log_level(self):
        # Returns the most verbose (numerically lowest) level of any
        # configured root handler, so workers don't drop records that the
        # manager's handlers would have emitted.
        log_level = logging.NOTSET
        for handler in logging.root.handlers:
            if handler.level != logging.NOTSET:
                if log_level == logging.NOTSET:
                    log_level = handler.level
                else:
                    log_level = min(log_level, handler.level)
        return log_level

    def wait(self):
        """Starts the workers and processes their messages until all are done."""
        try:
            self._start_workers()
            if self._running_inline:
                # The single worker drains its entire queue synchronously,
                # then we dispatch whatever it posted back, without blocking.
                self._workers[0].run()
                self._loop(block=False)
            else:
                self._loop(block=True)
        finally:
            self._close()

    def _close(self):
        # Terminates any live workers and releases the queues; safe to call
        # more than once (wait() and __exit__ both call it).
        for worker in self._workers:
            if worker.is_alive():
                worker.terminate()
                worker.join()
        self._workers = []
        if not self._running_inline:
            # FIXME: This is a hack to get multiprocessing to not log tracebacks during shutdown :(.
            multiprocessing.util._exiting = True
            if self._messages_to_worker:
                self._messages_to_worker.close()
                self._messages_to_worker = None
            if self._messages_to_manager:
                self._messages_to_manager.close()
                self._messages_to_manager = None

    def _log_messages(self, messages):
        # Replays log records captured in a worker through this process's
        # root handlers.
        for message in messages:
            logging.root.handle(message)

    def _handle_done(self, source):
        # Control message: the named worker has finished its message loop.
        self._workers_stopped.add(source)

    @staticmethod
    def _handle_worker_exception(source, exception_type, exception_value, _):
        # Control message: re-raise a worker's exception in the manager.
        # KeyboardInterrupt propagates as-is; everything else is wrapped in
        # WorkerException so callers have a single type to catch (the
        # traceback was already logged worker-side and is not picklable).
        if exception_type == KeyboardInterrupt:
            raise exception_type(exception_value)
        raise WorkerException(str(exception_value))

    def _can_pickle(self, host):
        # True if host survives cPickle, i.e. can be sent to a subprocess.
        # NOTE(review): only TypeError is caught; cPickle can also raise
        # PicklingError for some objects — confirm that path never occurs here.
        try:
            cPickle.dumps(host)
            return True
        except TypeError:
            return False

    def _loop(self, block):
        # Dispatches messages from the workers.  Exits via Queue.Empty once
        # every worker has reported 'done' and the queue has drained.
        try:
            while True:
                # Once all workers are stopped, switch to non-blocking gets
                # so the loop terminates when the queue empties.
                if len(self._workers_stopped) == len(self._workers):
                    block = False
                message = self._messages_to_manager.get(block)
                self._log_messages(message.logs)
                if message.from_user:
                    self._caller.handle(message.name, message.src, *message.args)
                    continue
                # Internal control messages dispatch to _handle_<name>.
                method = getattr(self, '_handle_' + message.name)
                assert method, 'bad message %s' % repr(message)
                method(message.src, *message.args)
        except Queue.Empty:
            pass
187
188
class WorkerException(BaseException):
    """Signals that a worker sent back an unexpected/unknown exception.

    NOTE(review): deliberately (it appears) derives from BaseException rather
    than Exception, so generic `except Exception` handlers won't swallow it —
    confirm callers rely on this before changing the base class.
    """
    pass
192
193
194 class _Message(object):
195     def __init__(self, src, message_name, message_args, from_user, logs):
196         self.src = src
197         self.name = message_name
198         self.args = message_args
199         self.from_user = from_user
200         self.logs = logs
201
202     def __repr__(self):
203         return '_Message(src=%s, name=%s, args=%s, from_user=%s, logs=%s)' % (self.src, self.name, self.args, self.from_user, self.logs)
204
205
class _Worker(multiprocessing.Process):
    """One worker: a subprocess normally, or an inline object when the pool
    runs with a single worker.

    Wraps the user-supplied worker object (built by worker_factory) and runs
    its message loop: pull _Messages off messages_to_worker, hand user
    messages to the worker object, and exit on the 'stop' sentinel.
    """

    def __init__(self, host, messages_to_manager, messages_to_worker, worker_factory, worker_number, running_inline, manager, log_level):
        super(_Worker, self).__init__()
        # May be None if the manager's host wasn't picklable; run() then
        # constructs a fresh Host in the child.
        self.host = host
        self.worker_number = worker_number
        self.name = 'worker/%d' % worker_number
        # Log records buffered by _WorkerLogHandler; shipped with each _post().
        self.log_messages = []
        self.log_level = log_level
        self._running_inline = running_inline
        # Only non-None when inline, so _yield_to_manager() can pump the
        # manager's loop directly in the same process.
        self._manager = manager

        self._messages_to_manager = messages_to_manager
        self._messages_to_worker = messages_to_worker
        self._worker = worker_factory(self)
        self._logger = None
        self._log_handler = None

    def terminate(self):
        # Gives the user's worker a chance to clean up before the process is killed.
        if self._worker:
            if hasattr(self._worker, 'stop'):
                self._worker.stop()
            self._worker = None
        if self.is_alive():
            super(_Worker, self).terminate()

    def _close(self):
        # Detaches the queue-forwarding log handler installed by _set_up_logging().
        if self._log_handler and self._logger:
            self._logger.removeHandler(self._log_handler)
        self._log_handler = None
        self._logger = None

    def start(self):
        # Inline workers are driven synchronously via run(); don't fork.
        if not self._running_inline:
            super(_Worker, self).start()

    def run(self):
        # Entry point in the child process (called directly when inline).
        if not self.host:
            # The host didn't cross the process boundary; build one here.
            self.host = Host()
        if not self._running_inline:
            self._set_up_logging()

        worker = self._worker
        exception_msg = ""  # NOTE(review): appears unused.
        _log.debug("%s starting" % self.name)

        try:
            if hasattr(worker, 'start'):
                worker.start()
            while True:
                message = self._messages_to_worker.get()
                if message.from_user:
                    worker.handle(message.name, message.src, *message.args)
                    self._yield_to_manager()
                else:
                    # 'stop' is the only internal message workers expect.
                    assert message.name == 'stop', 'bad message %s' % repr(message)
                    break

            _log.debug("%s exiting" % self.name)
        except Queue.Empty:
            assert False, '%s: ran out of messages in worker queue.' % self.name
        except KeyboardInterrupt, e:
            self._raise(sys.exc_info())
        except Exception, e:
            self._raise(sys.exc_info())
        finally:
            try:
                if hasattr(worker, 'stop'):
                    worker.stop()
            finally:
                # Always report 'done' to the manager, even if stop() raised.
                self._post(name='done', args=(), from_user=False)
            self._close()

    def post(self, name, *args):
        """Sends a user message back to the manager (the Worker-side API)."""
        self._post(name, args, from_user=True)
        self._yield_to_manager()

    def _yield_to_manager(self):
        # Inline there is no second process, so pump the manager's message
        # loop by hand to keep events flowing to the caller.
        if self._running_inline:
            self._manager._loop(block=False)

    def _post(self, name, args, from_user):
        # Attach (and reset) the buffered log records so the manager replays them.
        log_messages = self.log_messages
        self.log_messages = []
        self._messages_to_manager.put(_Message(self.name, name, args, from_user, log_messages))

    def _raise(self, exc_info):
        # Reports an exception to the manager; inline workers just re-raise
        # it into the manager's own stack.
        exception_type, exception_value, exception_traceback = exc_info
        if self._running_inline:
            raise exception_type, exception_value, exception_traceback

        if exception_type == KeyboardInterrupt:
            _log.debug("%s: interrupted, exiting" % self.name)
            stack_utils.log_traceback(_log.debug, exception_traceback)
        else:
            _log.error("%s: %s('%s') raised:" % (self.name, exception_value.__class__.__name__, str(exception_value)))
            stack_utils.log_traceback(_log.error, exception_traceback)
        # Since tracebacks aren't picklable, send the extracted stack instead.
        stack = traceback.extract_tb(exception_traceback)
        self._post(name='worker_exception', args=(exception_type, exception_value, stack), from_user=False)

    def _set_up_logging(self):
        self._logger = logging.getLogger()

        # The unix multiprocessing implementation clones any log handlers into the child process,
        # so we remove them to avoid duplicate logging.
        for h in self._logger.handlers:
            self._logger.removeHandler(h)

        # Buffer records on self.log_messages instead; _post() forwards them.
        self._log_handler = _WorkerLogHandler(self)
        self._logger.addHandler(self._log_handler)
        self._logger.setLevel(self.log_level)
317
318
319 class _WorkerLogHandler(logging.Handler):
320     def __init__(self, worker):
321         logging.Handler.__init__(self)
322         self._worker = worker
323         self.setLevel(worker.log_level)
324
325     def emit(self, record):
326         self._worker.log_messages.append(record)