Update std::expected to match libc++ coding style
[WebKit-https.git] / Source / WebKit / NetworkProcess / webrtc / NetworkRTCProvider.cpp
1 /*
2  * Copyright (C) 2017 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "NetworkRTCProvider.h"
28
29 #if USE(LIBWEBRTC)
30
31 #include "Logging.h"
32 #include "NetworkConnectionToWebProcess.h"
33 #include "NetworkProcess.h"
34 #include "NetworkRTCResolver.h"
35 #include "NetworkRTCSocket.h"
36 #include "WebRTCResolverMessages.h"
37 #include "WebRTCSocketMessages.h"
38 #include <WebCore/LibWebRTCMacros.h>
39 #include <webrtc/base/asyncpacketsocket.h>
40 #include <wtf/MainThread.h>
41 #include <wtf/text/WTFString.h>
42
43 namespace WebKit {
44
45 static inline std::unique_ptr<rtc::Thread> createThread()
46 {
47     auto thread = rtc::Thread::CreateWithSocketServer();
48     auto result = thread->Start();
49     ASSERT_UNUSED(result, result);
50     // FIXME: Set thread name.
51     return thread;
52 }
53
54 NetworkRTCProvider::NetworkRTCProvider(NetworkConnectionToWebProcess& connection)
55     : m_connection(&connection)
56     , m_rtcMonitor(*this)
57     , m_rtcNetworkThread(createThread())
58     , m_packetSocketFactory(makeUniqueRef<rtc::BasicPacketSocketFactory>(m_rtcNetworkThread.get()))
59 {
60 #if defined(NDEBUG)
61     rtc::LogMessage::LogToDebug(rtc::LS_NONE);
62 #else
63     if (WebKit2LogWebRTC.state != WTFLogChannelOn)
64         rtc::LogMessage::LogToDebug(rtc::LS_WARNING);
65 #endif
66 }
67
68 NetworkRTCProvider::~NetworkRTCProvider()
69 {
70     ASSERT(!m_connection);
71     ASSERT(!m_sockets.size());
72     ASSERT(!m_rtcMonitor.isStarted());
73 }
74
75 void NetworkRTCProvider::close()
76 {
77     // Cancel all pending DNS resolutions.
78     while (!m_resolvers.isEmpty())
79         stopResolver(*m_resolvers.keys().begin());
80
81     m_connection = nullptr;
82     m_rtcMonitor.stopUpdating();
83
84     callOnRTCNetworkThread([this]() {
85         m_sockets.clear();
86         callOnMainThread([provider = makeRef(*this)]() {
87             if (provider->m_rtcNetworkThread)
88                 provider->m_rtcNetworkThread->Stop();
89         });
90     });
91 }
92
93 void NetworkRTCProvider::createSocket(uint64_t identifier, std::unique_ptr<rtc::AsyncPacketSocket>&& socket, LibWebRTCSocketClient::Type type)
94 {
95     if (!socket) {
96         sendFromMainThread([identifier](IPC::Connection& connection) {
97             connection.send(Messages::WebRTCSocket::SignalClose(1), identifier);
98         });
99         return;
100     }
101     addSocket(identifier, std::make_unique<LibWebRTCSocketClient>(identifier, *this, WTFMove(socket), type));
102 }
103
104 void NetworkRTCProvider::createUDPSocket(uint64_t identifier, const RTCNetwork::SocketAddress& address, uint16_t minPort, uint16_t maxPort)
105 {
106     callOnRTCNetworkThread([this, identifier, address = RTCNetwork::isolatedCopy(address.value), minPort, maxPort]() {
107         std::unique_ptr<rtc::AsyncPacketSocket> socket(m_packetSocketFactory->CreateUdpSocket(address, minPort, maxPort));
108         createSocket(identifier, WTFMove(socket), LibWebRTCSocketClient::Type::UDP);
109     });
110 }
111
112 void NetworkRTCProvider::createServerTCPSocket(uint64_t identifier, const RTCNetwork::SocketAddress& address, uint16_t minPort, uint16_t maxPort, int options)
113 {
114     if (!m_isListeningSocketAuthorized) {
115         if (m_connection)
116             m_connection->connection().send(Messages::WebRTCSocket::SignalClose(1), identifier);
117         return;
118     }
119
120     callOnRTCNetworkThread([this, identifier, address = RTCNetwork::isolatedCopy(address.value), minPort, maxPort, options]() {
121         std::unique_ptr<rtc::AsyncPacketSocket> socket(m_packetSocketFactory->CreateServerTcpSocket(address, minPort, maxPort, options));
122         createSocket(identifier, WTFMove(socket), LibWebRTCSocketClient::Type::ServerTCP);
123     });
124 }
125
126 void NetworkRTCProvider::createClientTCPSocket(uint64_t identifier, const RTCNetwork::SocketAddress& localAddress, const RTCNetwork::SocketAddress& remoteAddress, int options)
127 {
128     callOnRTCNetworkThread([this, identifier, localAddress = RTCNetwork::isolatedCopy(localAddress.value), remoteAddress = RTCNetwork::isolatedCopy(remoteAddress.value), options]() {
129         std::unique_ptr<rtc::AsyncPacketSocket> socket(m_packetSocketFactory->CreateClientTcpSocket(localAddress, remoteAddress, { }, { }, options));
130         createSocket(identifier, WTFMove(socket), LibWebRTCSocketClient::Type::ClientTCP);
131     });
132 }
133
134 void NetworkRTCProvider::wrapNewTCPConnection(uint64_t identifier, uint64_t newConnectionSocketIdentifier)
135 {
136     callOnRTCNetworkThread([this, identifier, newConnectionSocketIdentifier]() {
137         std::unique_ptr<rtc::AsyncPacketSocket> socket = m_pendingIncomingSockets.take(newConnectionSocketIdentifier);
138         addSocket(identifier, std::make_unique<LibWebRTCSocketClient>(identifier, *this, WTFMove(socket), LibWebRTCSocketClient::Type::ServerConnectionTCP));
139     });
140 }
141
142 void NetworkRTCProvider::addSocket(uint64_t identifier, std::unique_ptr<LibWebRTCSocketClient>&& socket)
143 {
144     m_sockets.add(identifier, WTFMove(socket));
145 }
146
147 std::unique_ptr<LibWebRTCSocketClient> NetworkRTCProvider::takeSocket(uint64_t identifier)
148 {
149     return m_sockets.take(identifier);
150 }
151
152 void NetworkRTCProvider::newConnection(LibWebRTCSocketClient& serverSocket, std::unique_ptr<rtc::AsyncPacketSocket>&& newSocket)
153 {
154     sendFromMainThread([identifier = serverSocket.identifier(), incomingSocketIdentifier = ++m_incomingSocketIdentifier, remoteAddress = RTCNetwork::isolatedCopy(newSocket->GetRemoteAddress())](IPC::Connection& connection) {
155         connection.send(Messages::WebRTCSocket::SignalNewConnection(incomingSocketIdentifier, RTCNetwork::SocketAddress(remoteAddress)), identifier);
156     });
157     m_pendingIncomingSockets.add(m_incomingSocketIdentifier, WTFMove(newSocket));
158 }
159
160 void NetworkRTCProvider::didReceiveNetworkRTCSocketMessage(IPC::Connection& connection, IPC::Decoder& decoder)
161 {
162     NetworkRTCSocket(decoder.destinationID(), *this).didReceiveMessage(connection, decoder);
163 }
164
165 void NetworkRTCProvider::createResolver(uint64_t identifier, const String& address)
166 {
167     auto resolver = std::make_unique<NetworkRTCResolver>([this, identifier](NetworkRTCResolver::AddressesOrError&& result) mutable {
168         if (!result.has_value()) {
169             if (result.error() != NetworkRTCResolver::Error::Cancelled)
170                 m_connection->connection().send(Messages::WebRTCResolver::ResolvedAddressError(1), identifier);
171             return;
172         }
173         m_connection->connection().send(Messages::WebRTCResolver::SetResolvedAddress(result.value()), identifier);
174     });
175     resolver->start(address);
176     m_resolvers.add(identifier, WTFMove(resolver));
177 }
178
179 void NetworkRTCProvider::stopResolver(uint64_t identifier)
180 {
181     if (auto resolver = m_resolvers.take(identifier))
182         resolver->stop();
183 }
184
185 void NetworkRTCProvider::closeListeningSockets(Function<void()>&& completionHandler)
186 {
187     if (!m_isListeningSocketAuthorized) {
188         completionHandler();
189         return;
190     }
191
192     m_isListeningSocketAuthorized = false;
193     callOnRTCNetworkThread([this, completionHandler = WTFMove(completionHandler)]() mutable {
194         Vector<uint64_t> listeningSocketIdentifiers;
195         for (auto& keyValue : m_sockets) {
196             if (keyValue.value->type() == LibWebRTCSocketClient::Type::ServerTCP)
197                 listeningSocketIdentifiers.append(keyValue.key);
198         }
199         for (auto id : listeningSocketIdentifiers)
200             m_sockets.get(id)->close();
201
202         callOnMainThread([provider = makeRef(*this), listeningSocketIdentifiers = WTFMove(listeningSocketIdentifiers), completionHandler = WTFMove(completionHandler)] {
203             if (provider->m_connection) {
204                 for (auto identifier : listeningSocketIdentifiers)
205                     provider->m_connection->connection().send(Messages::WebRTCSocket::SignalClose(ECONNABORTED), identifier);
206             }
207             completionHandler();
208         });
209     });
210 }
211
212 struct NetworkMessageData : public rtc::MessageData {
213     NetworkMessageData(Ref<NetworkRTCProvider>&& rtcProvider, Function<void()>&& callback)
214         : rtcProvider(WTFMove(rtcProvider))
215         , callback(WTFMove(callback))
216     { }
217     Ref<NetworkRTCProvider> rtcProvider;
218     Function<void()> callback;
219 };
220
221 void NetworkRTCProvider::OnMessage(rtc::Message* message)
222 {
223     ASSERT(message->message_id == 1);
224     auto* data = static_cast<NetworkMessageData*>(message->pdata);
225     data->callback();
226     delete data;
227 }
228
229 void NetworkRTCProvider::callOnRTCNetworkThread(Function<void()>&& callback)
230 {
231     m_rtcNetworkThread->Post(RTC_FROM_HERE, this, 1, new NetworkMessageData(*this, WTFMove(callback)));
232 }
233
234 void NetworkRTCProvider::callSocket(uint64_t identifier, Function<void(LibWebRTCSocketClient&)>&& callback)
235 {
236     callOnRTCNetworkThread([this, identifier, callback = WTFMove(callback)]() {
237         if (auto* socket = m_sockets.get(identifier))
238             callback(*socket);
239     });
240 }
241
242 void NetworkRTCProvider::sendFromMainThread(Function<void(IPC::Connection&)>&& callback)
243 {
244     callOnMainThread([provider = makeRef(*this), callback = WTFMove(callback)]() {
245         if (provider->m_connection)
246             callback(provider->m_connection->connection());
247     });
248 }
249
250 } // namespace WebKit
251
252 #endif // USE(LIBWEBRTC)