Unreviewed. Update W3C WebDriver imported tests.
[WebKit-https.git] / WebDriverTests / imported / w3c / tools / wptrunner / wptrunner / environment.py
1 import json
2 import os
3 import multiprocessing
4 import signal
5 import socket
6 import sys
7 import time
8
9 from mozlog import get_default_logger, handlers, proxy
10
11 from wptlogging import LogLevelRewriter
12 from wptserve.handlers import StringHandler
13
# Directory containing this module; get_routes and load_config read
# harness files and config.json from here.
here = os.path.dirname(__file__)
# Three directories above ``here``; get_routes reads resources/testdriver.js
# from it.
repo_root = os.path.abspath(os.path.join(here, *([os.pardir] * 3)))

# Placeholders for modules that do_delayed_imports() imports from the tests
# checkout at runtime.
serve = sslutils = None


# Host names under web-platform.test, including two IDNA ("xn--") names.
hostnames = [
    "web-platform.test",
    "www.web-platform.test",
    "www1.web-platform.test",
    "www2.web-platform.test",
    "xn--n8j6ds53lwwkrqhv28a.web-platform.test",
    "xn--lve-6lad.web-platform.test",
]
27
28
def do_delayed_imports(logger, test_paths):
    """Import the wpt ``serve`` and ``sslutils`` modules into this module.

    The tests root for the ``"/"`` URL base is inserted at the front of
    ``sys.path`` first, so the imports resolve against that checkout.
    ``serve`` is tried as ``tools.serve`` and then as ``wpt_tools.serve``.
    Both names are bound as module globals.

    If either import fails, the failure is logged with ``logger.critical``
    and the process exits with status 1.

    :param logger: logger used to report import failures.
    :param test_paths: mapping of URL base to path info; the ``"/"`` entry's
        ``"tests_path"`` is the web-platform-tests root.
    """
    global serve, sslutils

    serve_root = serve_path(test_paths)
    sys.path.insert(0, serve_root)

    failed = []

    try:
        from tools.serve import serve
    except ImportError:
        # The fallback needs its own try: a second ``except ImportError``
        # attached to the outer try can never run, and an ImportError raised
        # inside this handler would otherwise escape unreported.
        try:
            from wpt_tools.serve import serve
        except ImportError:
            failed.append("serve")

    try:
        import sslutils
    except ImportError:
        failed.append("sslutils")

    if failed:
        logger.critical(
            "Failed to import %s. Ensure that tests path %s contains web-platform-tests" %
            (", ".join(failed), serve_root))
        sys.exit(1)
54
55
def serve_path(test_paths):
    """Return the ``"tests_path"`` configured for the ``"/"`` URL base.

    :raises KeyError: if there is no ``"/"`` entry or it lacks ``"tests_path"``.
    """
    root_entry = test_paths["/"]
    return root_entry["tests_path"]
58
59
def get_ssl_kwargs(**kwargs):
    """Pick out the keyword arguments for the SSL environment class.

    ``kwargs["ssl_type"]`` decides which other entries are copied:
    ``"openssl"`` -> ``openssl_binary``; ``"pregenerated"`` ->
    ``host_key_path``, ``host_cert_path`` and ``ca_cert_path``; any other
    type -> nothing.

    :returns: a new dict containing only the selected entries.
    :raises KeyError: if ``ssl_type`` or a required entry is missing.
    """
    ssl_type = kwargs["ssl_type"]
    if ssl_type == "openssl":
        wanted = ("openssl_binary",)
    elif ssl_type == "pregenerated":
        wanted = ("host_key_path", "host_cert_path", "ca_cert_path")
    else:
        wanted = ()
    return {key: kwargs[key] for key in wanted}
70
71
def ssl_env(logger, **kwargs):
    """Instantiate the ``sslutils`` environment chosen by ``kwargs["ssl_type"]``.

    The class is looked up in ``sslutils.environments`` and constructed with
    ``logger`` plus the arguments returned by ``get_ssl_kwargs``.
    """
    env_class = sslutils.environments[kwargs["ssl_type"]]
    env_kwargs = get_ssl_kwargs(**kwargs)
    return env_class(logger, **env_kwargs)
75
76
class TestEnvironmentError(Exception):
    """Exception type presumably meant for test-environment setup failures."""
79
80
81 class TestEnvironment(object):
82     def __init__(self, test_paths, ssl_env, pause_after_test, debug_info, options, env_extras):
83         """Context manager that owns the test environment i.e. the http and
84         websockets servers"""
85         self.test_paths = test_paths
86         self.ssl_env = ssl_env
87         self.server = None
88         self.config = None
89         self.external_config = None
90         self.pause_after_test = pause_after_test
91         self.test_server_port = options.pop("test_server_port", True)
92         self.debug_info = debug_info
93         self.options = options if options is not None else {}
94
95         self.cache_manager = multiprocessing.Manager()
96         self.stash = serve.stash.StashServer()
97         self.env_extras = env_extras
98
99
100     def __enter__(self):
101         self.stash.__enter__()
102         self.ssl_env.__enter__()
103         self.cache_manager.__enter__()
104         for cm in self.env_extras:
105             cm.__enter__(self.options)
106         self.setup_server_logging()
107         self.config = self.load_config()
108         serve.set_computed_defaults(self.config)
109         self.external_config, self.servers = serve.start(self.config, self.ssl_env,
110                                                          self.get_routes())
111         if self.options.get("supports_debugger") and self.debug_info and self.debug_info.interactive:
112             self.ignore_interrupts()
113         return self
114
115     def __exit__(self, exc_type, exc_val, exc_tb):
116         self.process_interrupts()
117
118         for scheme, servers in self.servers.iteritems():
119             for port, server in servers:
120                 server.kill()
121         for cm in self.env_extras:
122             cm.__exit__(exc_type, exc_val, exc_tb)
123         self.cache_manager.__exit__(exc_type, exc_val, exc_tb)
124         self.ssl_env.__exit__(exc_type, exc_val, exc_tb)
125         self.stash.__exit__()
126
127     def ignore_interrupts(self):
128         signal.signal(signal.SIGINT, signal.SIG_IGN)
129
130     def process_interrupts(self):
131         signal.signal(signal.SIGINT, signal.SIG_DFL)
132
133     def load_config(self):
134         default_config_path = os.path.join(serve_path(self.test_paths), "config.default.json")
135         local_config_path = os.path.join(here, "config.json")
136
137         with open(default_config_path) as f:
138             default_config = json.load(f)
139
140         with open(local_config_path) as f:
141             data = f.read()
142             local_config = json.loads(data % self.options)
143
144         #TODO: allow non-default configuration for ssl
145
146         local_config["external_host"] = self.options.get("external_host", None)
147         local_config["ssl"]["encrypt_after_connect"] = self.options.get("encrypt_after_connect", False)
148
149         config = serve.merge_json(default_config, local_config)
150         config["doc_root"] = serve_path(self.test_paths)
151
152         if not self.ssl_env.ssl_enabled:
153             config["ports"]["https"] = [None]
154
155         host = self.options.get("certificate_domain", config["host"])
156         hosts = [host]
157         hosts.extend("%s.%s" % (item[0], host) for item in serve.get_subdomains(host).values())
158         key_file, certificate = self.ssl_env.host_cert_path(hosts)
159
160         config["key_file"] = key_file
161         config["certificate"] = certificate
162
163         return config
164
165     def setup_server_logging(self):
166         server_logger = get_default_logger(component="wptserve")
167         assert server_logger is not None
168         log_filter = handlers.LogLevelFilter(lambda x:x, "info")
169         # Downgrade errors to warnings for the server
170         log_filter = LogLevelRewriter(log_filter, ["error"], "warning")
171         server_logger.component_filter = log_filter
172
173         server_logger = proxy.QueuedProxyLogger(server_logger)
174
175         try:
176             #Set as the default logger for wptserve
177             serve.set_logger(server_logger)
178             serve.logger = server_logger
179         except Exception:
180             # This happens if logging has already been set up for wptserve
181             pass
182
183     def get_routes(self):
184         route_builder = serve.RoutesBuilder()
185
186         for path, format_args, content_type, route in [
187                 ("testharness_runner.html", {}, "text/html", "/testharness_runner.html"),
188                 (self.options.get("testharnessreport", "testharnessreport.js"),
189                  {"output": self.pause_after_test}, "text/javascript",
190                  "/resources/testharnessreport.js")]:
191             path = os.path.normpath(os.path.join(here, path))
192             route_builder.add_static(path, format_args, content_type, route)
193
194         data = b""
195         with open(os.path.join(repo_root, "resources", "testdriver.js"), "rb") as fp:
196             data += fp.read()
197         with open(os.path.join(here, "testdriver-extra.js"), "rb") as fp:
198             data += fp.read()
199         route_builder.add_handler(b"GET", b"/resources/testdriver.js",
200                                   StringHandler(data, "text/javascript"))
201
202         for url_base, paths in self.test_paths.iteritems():
203             if url_base == "/":
204                 continue
205             route_builder.add_mount_point(url_base, paths["tests_path"])
206
207         if "/" not in self.test_paths:
208             del route_builder.mountpoint_routes["/"]
209
210         return route_builder.get_routes()
211
212     def ensure_started(self):
213         # Pause for a while to ensure that the server has a chance to start
214         for _ in xrange(20):
215             failed = self.test_servers()
216             if not failed:
217                 return
218             time.sleep(0.5)
219         raise EnvironmentError("Servers failed to start (scheme:port): %s" % ("%s:%s" for item in failed))
220
221     def test_servers(self):
222         failed = []
223         for scheme, servers in self.servers.iteritems():
224             for port, server in servers:
225                 if self.test_server_port:
226                     s = socket.socket()
227                     try:
228                         s.connect((self.config["host"], port))
229                     except socket.error:
230                         failed.append((scheme, port))
231                     finally:
232                         s.close()
233
234                 if not server.is_alive():
235                     failed.append((scheme, port))