Implement table-based switches in B3/Air
[WebKit-https.git] / Source / JavaScriptCore / b3 / B3LowerMacros.cpp
1 /*
2  * Copyright (C) 2015-2016 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24  */
25
26 #include "config.h"
27 #include "B3LowerMacros.h"
28
29 #if ENABLE(B3_JIT)
30
31 #include "AllowMacroScratchRegisterUsage.h"
32 #include "B3BasicBlockInlines.h"
33 #include "B3BlockInsertionSet.h"
34 #include "B3CCallValue.h"
35 #include "B3CaseCollectionInlines.h"
36 #include "B3ConstPtrValue.h"
37 #include "B3InsertionSetInlines.h"
38 #include "B3MemoryValue.h"
39 #include "B3PatchpointValue.h"
40 #include "B3PhaseScope.h"
41 #include "B3ProcedureInlines.h"
42 #include "B3StackmapGenerationParams.h"
43 #include "B3SwitchValue.h"
44 #include "B3UpsilonValue.h"
45 #include "B3ValueInlines.h"
46 #include "CCallHelpers.h"
47 #include "LinkBuffer.h"
48 #include <cmath>
49
50 namespace JSC { namespace B3 {
51
52 namespace {
53
54 class LowerMacros {
55 public:
56     LowerMacros(Procedure& proc)
57         : m_proc(proc)
58         , m_blockInsertionSet(proc)
59         , m_insertionSet(proc)
60     {
61     }
62
63     bool run()
64     {
65         for (BasicBlock* block : m_proc) {
66             m_block = block;
67             processCurrentBlock();
68         }
69         m_changed |= m_blockInsertionSet.execute();
70         if (m_changed) {
71             m_proc.resetReachability();
72             m_proc.invalidateCFG();
73         }
74         return m_changed;
75     }
76     
77 private:
78     void processCurrentBlock()
79     {
80         for (m_index = 0; m_index < m_block->size(); ++m_index) {
81             m_value = m_block->at(m_index);
82             m_origin = m_value->origin();
83             switch (m_value->opcode()) {
84             case Mod: {
85                 double (*fmodDouble)(double, double) = fmod;
86                 if (m_value->type() == Double) {
87                     Value* functionAddress = m_insertionSet.insert<ConstPtrValue>(m_index, m_origin, fmodDouble);
88                     Value* result = m_insertionSet.insert<CCallValue>(m_index, Double, m_origin,
89                         Effects::none(),
90                         functionAddress,
91                         m_value->child(0),
92                         m_value->child(1));
93                     m_value->replaceWithIdentity(result);
94                     m_changed = true;
95                 } else if (m_value->type() == Float) {
96                     Value* numeratorAsDouble = m_insertionSet.insert<Value>(m_index, FloatToDouble, m_origin, m_value->child(0));
97                     Value* denominatorAsDouble = m_insertionSet.insert<Value>(m_index, FloatToDouble, m_origin, m_value->child(1));
98                     Value* functionAddress = m_insertionSet.insert<ConstPtrValue>(m_index, m_origin, fmodDouble);
99                     Value* doubleMod = m_insertionSet.insert<CCallValue>(m_index, Double, m_origin,
100                         Effects::none(),
101                         functionAddress,
102                         numeratorAsDouble,
103                         denominatorAsDouble);
104                     Value* result = m_insertionSet.insert<Value>(m_index, DoubleToFloat, m_origin, doubleMod);
105                     m_value->replaceWithIdentity(result);
106                     m_changed = true;
107                 } else if (isARM64()) {
108                     Value* divResult = m_insertionSet.insert<Value>(m_index, ChillDiv, m_origin, m_value->child(0), m_value->child(1));
109                     Value* multipliedBack = m_insertionSet.insert<Value>(m_index, Mul, m_origin, divResult, m_value->child(1));
110                     Value* result = m_insertionSet.insert<Value>(m_index, Sub, m_origin, m_value->child(0), multipliedBack);
111                     m_value->replaceWithIdentity(result);
112                     m_changed = true;
113                 }
114                 break;
115             }
116             case ChillDiv: {
117                 makeDivisionChill(Div);
118                 break;
119             }
120
121             case ChillMod: {
122                 if (isARM64()) {
123                     BasicBlock* before = m_blockInsertionSet.splitForward(m_block, m_index, &m_insertionSet);
124                     BasicBlock* zeroDenCase = m_blockInsertionSet.insertBefore(m_block);
125                     BasicBlock* normalModCase = m_blockInsertionSet.insertBefore(m_block);
126
127                     before->replaceLastWithNew<Value>(m_proc, Branch, m_origin, m_value->child(1));
128                     before->setSuccessors(
129                         FrequentedBlock(normalModCase, FrequencyClass::Normal),
130                         FrequentedBlock(zeroDenCase, FrequencyClass::Rare));
131
132                     Value* divResult = normalModCase->appendNew<Value>(m_proc, ChillDiv, m_origin, m_value->child(0), m_value->child(1));
133                     Value* multipliedBack = normalModCase->appendNew<Value>(m_proc, Mul, m_origin, divResult, m_value->child(1));
134                     Value* result = normalModCase->appendNew<Value>(m_proc, Sub, m_origin, m_value->child(0), multipliedBack);
135                     UpsilonValue* normalResult = normalModCase->appendNew<UpsilonValue>(m_proc, m_origin, result);
136                     normalModCase->appendNew<Value>(m_proc, Jump, m_origin);
137                     normalModCase->setSuccessors(FrequentedBlock(m_block));
138
139                     UpsilonValue* zeroResult = zeroDenCase->appendNew<UpsilonValue>(
140                         m_proc, m_origin,
141                         zeroDenCase->appendIntConstant(m_proc, m_value, 0));
142                     zeroDenCase->appendNew<Value>(m_proc, Jump, m_origin);
143                     zeroDenCase->setSuccessors(FrequentedBlock(m_block));
144
145                     Value* phi = m_insertionSet.insert<Value>(m_index, Phi, m_value->type(), m_origin);
146                     normalResult->setPhi(phi);
147                     zeroResult->setPhi(phi);
148                     m_value->replaceWithIdentity(phi);
149                     m_changed = true;
150                 } else
151                     makeDivisionChill(Mod);
152                 break;
153             }
154
155             case Switch: {
156                 SwitchValue* switchValue = m_value->as<SwitchValue>();
157                 Vector<SwitchCase> cases;
158                 for (const SwitchCase& switchCase : switchValue->cases(m_block))
159                     cases.append(switchCase);
160                 std::sort(
161                     cases.begin(), cases.end(),
162                     [] (const SwitchCase& left, const SwitchCase& right) {
163                         return left.caseValue() < right.caseValue();
164                     });
165                 FrequentedBlock fallThrough = m_block->fallThrough();
166                 m_block->values().removeLast();
167                 recursivelyBuildSwitch(cases, fallThrough, 0, false, cases.size(), m_block);
168                 m_proc.deleteValue(switchValue);
169                 m_block->updatePredecessorsAfter();
170                 m_changed = true;
171                 break;
172             }
173
174             default:
175                 break;
176             }
177         }
178         m_insertionSet.execute(m_block);
179     }
180
181     void makeDivisionChill(Opcode nonChillOpcode)
182     {
183         ASSERT(nonChillOpcode == Div || nonChillOpcode == Mod);
184
185         // ARM supports this instruction natively.
186         if (isARM64())
187             return;
188
189         // We implement "res = ChillDiv/ChillMod(num, den)" as follows:
190         //
191         //     if (den + 1 <=_unsigned 1) {
192         //         if (!den) {
193         //             res = 0;
194         //             goto done;
195         //         }
196         //         if (num == -2147483648) {
197         //             res = isDiv ? num : 0;
198         //             goto done;
199         //         }
200         //     }
201         //     res = num (/ or %) dev;
202         // done:
203         m_changed = true;
204
205         Value* num = m_value->child(0);
206         Value* den = m_value->child(1);
207
208         Value* one = m_insertionSet.insertIntConstant(m_index, m_value, 1);
209         Value* isDenOK = m_insertionSet.insert<Value>(
210             m_index, Above, m_origin,
211             m_insertionSet.insert<Value>(m_index, Add, m_origin, den, one),
212             one);
213
214         BasicBlock* before = m_blockInsertionSet.splitForward(m_block, m_index, &m_insertionSet);
215
216         BasicBlock* normalDivCase = m_blockInsertionSet.insertBefore(m_block);
217         BasicBlock* shadyDenCase = m_blockInsertionSet.insertBefore(m_block);
218         BasicBlock* zeroDenCase = m_blockInsertionSet.insertBefore(m_block);
219         BasicBlock* neg1DenCase = m_blockInsertionSet.insertBefore(m_block);
220         BasicBlock* intMinCase = m_blockInsertionSet.insertBefore(m_block);
221
222         before->replaceLastWithNew<Value>(m_proc, Branch, m_origin, isDenOK);
223         before->setSuccessors(
224             FrequentedBlock(normalDivCase, FrequencyClass::Normal),
225             FrequentedBlock(shadyDenCase, FrequencyClass::Rare));
226
227         UpsilonValue* normalResult = normalDivCase->appendNew<UpsilonValue>(
228             m_proc, m_origin,
229             normalDivCase->appendNew<Value>(m_proc, nonChillOpcode, m_origin, num, den));
230         normalDivCase->appendNew<Value>(m_proc, Jump, m_origin);
231         normalDivCase->setSuccessors(FrequentedBlock(m_block));
232
233         shadyDenCase->appendNew<Value>(m_proc, Branch, m_origin, den);
234         shadyDenCase->setSuccessors(
235             FrequentedBlock(neg1DenCase, FrequencyClass::Normal),
236             FrequentedBlock(zeroDenCase, FrequencyClass::Rare));
237
238         UpsilonValue* zeroResult = zeroDenCase->appendNew<UpsilonValue>(
239             m_proc, m_origin,
240             zeroDenCase->appendIntConstant(m_proc, m_value, 0));
241         zeroDenCase->appendNew<Value>(m_proc, Jump, m_origin);
242         zeroDenCase->setSuccessors(FrequentedBlock(m_block));
243
244         int64_t badNumeratorConst = 0;
245         switch (m_value->type()) {
246         case Int32:
247             badNumeratorConst = std::numeric_limits<int32_t>::min();
248             break;
249         case Int64:
250             badNumeratorConst = std::numeric_limits<int64_t>::min();
251             break;
252         default:
253             ASSERT_NOT_REACHED();
254             badNumeratorConst = 0;
255         }
256
257         Value* badNumerator =
258             neg1DenCase->appendIntConstant(m_proc, m_value, badNumeratorConst);
259
260         neg1DenCase->appendNew<Value>(
261             m_proc, Branch, m_origin,
262             neg1DenCase->appendNew<Value>(
263                 m_proc, Equal, m_origin, num, badNumerator));
264         neg1DenCase->setSuccessors(
265             FrequentedBlock(intMinCase, FrequencyClass::Rare),
266             FrequentedBlock(normalDivCase, FrequencyClass::Normal));
267
268         Value* intMinResult = nonChillOpcode == Div ? badNumerator : intMinCase->appendIntConstant(m_proc, m_value, 0);
269         UpsilonValue* intMinResultUpsilon = intMinCase->appendNew<UpsilonValue>(
270             m_proc, m_origin, intMinResult);
271         intMinCase->appendNew<Value>(m_proc, Jump, m_origin);
272         intMinCase->setSuccessors(FrequentedBlock(m_block));
273
274         Value* phi = m_insertionSet.insert<Value>(
275             m_index, Phi, m_value->type(), m_origin);
276         normalResult->setPhi(phi);
277         zeroResult->setPhi(phi);
278         intMinResultUpsilon->setPhi(phi);
279
280         m_value->replaceWithIdentity(phi);
281         before->updatePredecessorsAfter();
282     }
283
284     void recursivelyBuildSwitch(
285         const Vector<SwitchCase>& cases, FrequentedBlock fallThrough, unsigned start, bool hardStart,
286         unsigned end, BasicBlock* before)
287     {
288         Value* child = m_value->child(0);
289         Type type = child->type();
290         
291         // It's a good idea to use a table-based switch in some cases: the number of cases has to be
292         // large enough and they have to be dense enough. This could probably be improved a lot. For
293         // example, we could still use a jump table in cases where the inputs are sparse so long as we
294         // shift off the uninteresting bits. On the other hand, it's not clear that this would
295         // actually be any better than what we have done here and it's not clear that it would be
296         // better than a binary switch.
297         const unsigned minCasesForTable = 7;
298         const unsigned densityLimit = 4;
299         if (end - start >= minCasesForTable) {
300             int64_t firstValue = cases[start].caseValue();
301             int64_t lastValue = cases[end - 1].caseValue();
302             if ((lastValue - firstValue + 1) / (end - start) < densityLimit) {
303                 BasicBlock* switchBlock = m_blockInsertionSet.insertAfter(m_block);
304                 Value* index = before->appendNew<Value>(
305                     m_proc, Sub, m_origin, child,
306                     before->appendIntConstant(m_proc, m_origin, type, firstValue));
307                 before->appendNew<Value>(
308                     m_proc, Branch, m_origin,
309                     before->appendNew<Value>(
310                         m_proc, Above, m_origin, index,
311                         before->appendIntConstant(m_proc, m_origin, type, lastValue - firstValue)));
312                 before->setSuccessors(fallThrough, FrequentedBlock(switchBlock));
313                 
314                 size_t tableSize = lastValue - firstValue + 1;
315                 
316                 if (index->type() != pointerType() && index->type() == Int32)
317                     index = switchBlock->appendNew<Value>(m_proc, ZExt32, m_origin, index);
318                 
319                 PatchpointValue* patchpoint =
320                     switchBlock->appendNew<PatchpointValue>(m_proc, Void, m_origin);
321
322                 // Even though this loads from the jump table, the jump table is immutable. For the
323                 // purpose of alias analysis, reading something immutable is like reading nothing.
324                 patchpoint->effects = Effects();
325                 patchpoint->effects.terminal = true;
326                 
327                 patchpoint->appendSomeRegister(index);
328                 patchpoint->numGPScratchRegisters++;
329                 // Technically, we don't have to clobber macro registers on X86_64. This is probably
330                 // OK though.
331                 patchpoint->clobber(RegisterSet::macroScratchRegisters());
332                 
333                 BitVector handledIndices;
334                 for (unsigned i = start; i < end; ++i) {
335                     FrequentedBlock block = cases[i].target();
336                     int64_t value = cases[i].caseValue();
337                     switchBlock->appendSuccessor(block);
338                     size_t index = value - firstValue;
339                     ASSERT(!handledIndices.get(index));
340                     handledIndices.set(index);
341                 }
342                 
343                 bool hasUnhandledIndex = false;
344                 for (unsigned i = 0; i < tableSize; ++i) {
345                     if (!handledIndices.get(i)) {
346                         hasUnhandledIndex = true;
347                         break;
348                     }
349                 }
350                 
351                 if (hasUnhandledIndex)
352                     switchBlock->appendSuccessor(fallThrough);
353
354                 patchpoint->setGenerator(
355                     [=] (CCallHelpers& jit, const StackmapGenerationParams& params) {
356                         AllowMacroScratchRegisterUsage allowScratch(jit);
357                         
358                         MacroAssemblerCodePtr* jumpTable = static_cast<MacroAssemblerCodePtr*>(
359                             params.proc().addDataSection(sizeof(MacroAssemblerCodePtr) * tableSize));
360                         
361                         GPRReg index = params[0].gpr();
362                         GPRReg scratch = params.gpScratch(0);
363                         
364                         jit.move(CCallHelpers::TrustedImmPtr(jumpTable), scratch);
365                         jit.jump(CCallHelpers::BaseIndex(scratch, index, CCallHelpers::timesPtr()));
366                         
367                         // These labels are guaranteed to be populated before either late paths or
368                         // link tasks run.
369                         Vector<Box<CCallHelpers::Label>> labels = params.successorLabels();
370                         
371                         jit.addLinkTask(
372                             [=] (LinkBuffer& linkBuffer) {
373                                 if (hasUnhandledIndex) {
374                                     MacroAssemblerCodePtr fallThrough =
375                                         linkBuffer.locationOf(*labels.last());
376                                     for (unsigned i = tableSize; i--;)
377                                         jumpTable[i] = fallThrough;
378                                 }
379                                 
380                                 unsigned labelIndex = 0;
381                                 for (unsigned tableIndex : handledIndices) {
382                                     jumpTable[tableIndex] =
383                                         linkBuffer.locationOf(*labels[labelIndex++]);
384                                 }
385                             });
386                     });
387                 return;
388             }
389         }
390         
391         // See comments in jit/BinarySwitch.cpp for a justification of this algorithm. The only
392         // thing we do differently is that we don't use randomness.
393
394         const unsigned leafThreshold = 3;
395
396         unsigned size = end - start;
397
398         if (size <= leafThreshold) {
399             bool allConsecutive = false;
400
401             if ((hardStart || (start && cases[start - 1].caseValue() == cases[start].caseValue() - 1))
402                 && end < cases.size()
403                 && cases[end - 1].caseValue() == cases[end].caseValue() - 1) {
404                 allConsecutive = true;
405                 for (unsigned i = 0; i < size - 1; ++i) {
406                     if (cases[start + i].caseValue() + 1 != cases[start + i + 1].caseValue()) {
407                         allConsecutive = false;
408                         break;
409                     }
410                 }
411             }
412
413             unsigned limit = allConsecutive ? size - 1 : size;
414             
415             for (unsigned i = 0; i < limit; ++i) {
416                 BasicBlock* nextCheck = m_blockInsertionSet.insertAfter(m_block);
417                 before->appendNew<Value>(
418                     m_proc, Branch, m_origin,
419                     before->appendNew<Value>(
420                         m_proc, Equal, m_origin, child,
421                         before->appendIntConstant(
422                             m_proc, m_origin, type,
423                             cases[start + i].caseValue())));
424                 before->setSuccessors(cases[start + i].target(), FrequentedBlock(nextCheck));
425
426                 before = nextCheck;
427             }
428
429             before->appendNew<Value>(m_proc, Jump, m_origin);
430             if (allConsecutive)
431                 before->setSuccessors(cases[end - 1].target());
432             else
433                 before->setSuccessors(fallThrough);
434             return;
435         }
436
437         unsigned medianIndex = (start + end) / 2;
438
439         BasicBlock* left = m_blockInsertionSet.insertAfter(m_block);
440         BasicBlock* right = m_blockInsertionSet.insertAfter(m_block);
441
442         before->appendNew<Value>(
443             m_proc, Branch, m_origin,
444             before->appendNew<Value>(
445                 m_proc, LessThan, m_origin, child,
446                 before->appendIntConstant(
447                     m_proc, m_origin, type,
448                     cases[medianIndex].caseValue())));
449         before->setSuccessors(FrequentedBlock(left), FrequentedBlock(right));
450
451         recursivelyBuildSwitch(cases, fallThrough, start, hardStart, medianIndex, left);
452         recursivelyBuildSwitch(cases, fallThrough, medianIndex, true, end, right);
453     }
454     
455     Procedure& m_proc;
456     BlockInsertionSet m_blockInsertionSet;
457     InsertionSet m_insertionSet;
458     BasicBlock* m_block;
459     unsigned m_index;
460     Value* m_value;
461     Origin m_origin;
462     bool m_changed { false };
463 };
464
465 bool lowerMacrosImpl(Procedure& proc)
466 {
467     LowerMacros lowerMacros(proc);
468     return lowerMacros.run();
469 }
470
471 } // anonymous namespace
472
473 bool lowerMacros(Procedure& proc)
474 {
475     PhaseScope phaseScope(proc, "lowerMacros");
476     bool result = lowerMacrosImpl(proc);
477     if (shouldValidateIR())
478         RELEASE_ASSERT(!lowerMacrosImpl(proc));
479     return result;
480 }
481
482 } } // namespace JSC::B3
483
484 #endif // ENABLE(B3_JIT)
485