Teach Call ICs how to call Wasm
[WebKit-https.git] / Source / JavaScriptCore / wasm / js / WebAssemblyFunction.cpp
index 42468ef..02ef41e 100644 (file)
 #include "JSWebAssemblyMemory.h"
 #include "JSWebAssemblyRuntimeError.h"
 #include "LLIntThunks.h"
+#include "LinkBuffer.h"
 #include "ProtoCallFrame.h"
 #include "VM.h"
 #include "WasmCallee.h"
+#include "WasmCallingConvention.h"
 #include "WasmContextInlines.h"
 #include "WasmFormat.h"
 #include "WasmMemory.h"
+#include "WasmMemoryInformation.h"
+#include "WasmModuleInformation.h"
 #include "WasmSignatureInlines.h"
 #include <wtf/FastTLS.h>
 #include <wtf/StackPointer.h>
@@ -148,6 +152,335 @@ static EncodedJSValue JSC_HOST_CALL callWebAssemblyFunction(ExecState* exec)
     return rawResult;
 }
 
+bool WebAssemblyFunction::useTagRegisters() const
+{
+    const auto& signature = Wasm::SignatureInformation::get(signatureIndex());
+    return signature.argumentCount() || signature.returnType() != Wasm::Void;
+}
+
+RegisterSet WebAssemblyFunction::calleeSaves() const
+{
+    RegisterSet toSave = Wasm::PinnedRegisterInfo::get().toSave(instance()->memoryMode());
+    if (useTagRegisters()) {
+        RegisterSet tagRegisters = RegisterSet::runtimeTagRegisters();
+        // We rely on these being disjoint sets.
+#if !ASSERT_DISABLED
+        for (Reg reg : tagRegisters)
+            ASSERT(!toSave.contains(reg));
+#endif
+        toSave.merge(tagRegisters);
+    }
+    return toSave;
+}
+
+RegisterAtOffsetList WebAssemblyFunction::usedCalleeSaveRegisters() const
+{
+    return RegisterAtOffsetList { calleeSaves(), RegisterAtOffsetList::OffsetBaseType::FramePointerBased };
+}
+
+ptrdiff_t WebAssemblyFunction::previousInstanceOffset() const
+{
+    ptrdiff_t result = calleeSaves().numberOfSetRegisters() * sizeof(CPURegister);
+    result = -result - sizeof(CPURegister);
+#if !ASSERT_DISABLED
+    ptrdiff_t minOffset = 1;
+    for (const RegisterAtOffset& regAtOffset : usedCalleeSaveRegisters()) {
+        ptrdiff_t offset = regAtOffset.offset();
+        ASSERT(offset < 0);
+        minOffset = std::min(offset, minOffset);
+    }
+    ASSERT(minOffset - static_cast<ptrdiff_t>(sizeof(CPURegister)) == result);
+#endif
+    return result;
+}
+
+Wasm::Instance* WebAssemblyFunction::previousInstance(CallFrame* callFrame)
+{
+    ASSERT(callFrame->callee().rawPtr() == m_jsToWasmICCallee.get());
+    auto* result = *bitwise_cast<Wasm::Instance**>(bitwise_cast<char*>(callFrame) + previousInstanceOffset());
+    return result;
+}
+
+MacroAssemblerCodePtr<JSEntryPtrTag> WebAssemblyFunction::jsCallEntrypointSlow()
+{
+    VM& vm = *this->vm();
+    CCallHelpers jit;
+
+    const auto& signature = Wasm::SignatureInformation::get(signatureIndex());
+    const auto& pinnedRegs = Wasm::PinnedRegisterInfo::get();
+    RegisterAtOffsetList registersToSpill = usedCalleeSaveRegisters();
+
+    auto& moduleInformation = instance()->instance().module().moduleInformation();
+
+    unsigned totalFrameSize = registersToSpill.size() * sizeof(CPURegister);
+    totalFrameSize += sizeof(CPURegister); // Slot for the VM's previous wasm instance.
+    totalFrameSize += Wasm::WasmCallingConvention::headerSizeInBytes();
+    totalFrameSize -= sizeof(CallerFrameAndPC);
+
+    unsigned numGPRs = 0;
+    unsigned numFPRs = 0;
+    bool argumentsIncludeI64 = false;
+    for (unsigned i = 0; i < signature.argumentCount(); i++) {
+        switch (signature.argument(i)) {
+        case Wasm::I64:
+            argumentsIncludeI64 = true;
+            break;
+        case Wasm::I32:
+            if (numGPRs >= Wasm::wasmCallingConvention().m_gprArgs.size())
+                totalFrameSize += sizeof(CPURegister);
+            ++numGPRs;
+            break;
+        case Wasm::F32:
+        case Wasm::F64:
+            if (numFPRs >= Wasm::wasmCallingConvention().m_fprArgs.size())
+                totalFrameSize += sizeof(CPURegister);
+            ++numFPRs;
+            break;
+        default:
+            RELEASE_ASSERT_NOT_REACHED();
+        }
+    }
+
+    if (argumentsIncludeI64)
+        return nullptr;
+
+    totalFrameSize = WTF::roundUpToMultipleOf(stackAlignmentBytes(), totalFrameSize);
+
+    jit.emitFunctionPrologue();
+    jit.subPtr(MacroAssembler::TrustedImm32(totalFrameSize), MacroAssembler::stackPointerRegister);
+    jit.store64(CCallHelpers::TrustedImm64(0), CCallHelpers::addressFor(CallFrameSlot::codeBlock));
+
+    for (const RegisterAtOffset& regAtOffset : registersToSpill) {
+        GPRReg reg = regAtOffset.reg().gpr();
+        ptrdiff_t offset = regAtOffset.offset();
+        jit.storePtr(reg, CCallHelpers::Address(GPRInfo::callFrameRegister, offset));
+    }
+
+    GPRReg scratchGPR = Wasm::wasmCallingConventionAir().prologueScratch(1);
+    GPRReg scratch2GPR = Wasm::wasmCallingConventionAir().prologueScratch(0);
+    jit.loadPtr(vm.addressOfSoftStackLimit(), scratch2GPR);
+
+    CCallHelpers::JumpList slowPath;
+    slowPath.append(jit.branchPtr(CCallHelpers::Above, MacroAssembler::stackPointerRegister, GPRInfo::callFrameRegister));
+    slowPath.append(jit.branchPtr(CCallHelpers::Below, MacroAssembler::stackPointerRegister, scratch2GPR));
+
+    // Ensure:
+    // argCountPlusThis - 1 >= signature.argumentCount()
+    // argCountPlusThis >= signature.argumentCount() + 1
+    // FIXME: We should handle mismatched arity
+    // https://bugs.webkit.org/show_bug.cgi?id=196564
+    slowPath.append(jit.branch32(CCallHelpers::Below,
+        CCallHelpers::payloadFor(CallFrameSlot::argumentCount), CCallHelpers::TrustedImm32(signature.argumentCount() + 1)));
+
+    if (useTagRegisters())
+        jit.emitMaterializeTagCheckRegisters();
+
+    // First we do stack slots for FPRs so we can use FPR argument registers as scratch.
+    // After that, we handle FPR argument registers.
+    // We also handle all GPR types here as we have GPR scratch registers.
+    {
+        CCallHelpers::Address calleeFrame = CCallHelpers::Address(MacroAssembler::stackPointerRegister, -static_cast<ptrdiff_t>(sizeof(CallerFrameAndPC)));
+        numGPRs = 0;
+        numFPRs = 0;
+        FPRReg scratchFPR = Wasm::wasmCallingConvention().m_fprArgs[0].fpr();
+
+        ptrdiff_t jsOffset = CallFrameSlot::firstArgument * sizeof(EncodedJSValue);
+
+        ptrdiff_t wasmOffset = CallFrame::headerSizeInRegisters * sizeof(CPURegister);
+        for (unsigned i = 0; i < signature.argumentCount(); i++) {
+            switch (signature.argument(i)) {
+            case Wasm::I32:
+                jit.load64(CCallHelpers::Address(GPRInfo::callFrameRegister, jsOffset), scratchGPR);
+                slowPath.append(jit.branchIfNotInt32(scratchGPR));
+                if (numGPRs >= Wasm::wasmCallingConvention().m_gprArgs.size()) {
+                    jit.store32(scratchGPR, calleeFrame.withOffset(wasmOffset));
+                    wasmOffset += sizeof(CPURegister);
+                } else {
+                    jit.zeroExtend32ToPtr(scratchGPR, Wasm::wasmCallingConvention().m_gprArgs[numGPRs].gpr());
+                    ++numGPRs;
+                }
+                break;
+            case Wasm::F32:
+            case Wasm::F64:
+                if (numFPRs >= Wasm::wasmCallingConvention().m_fprArgs.size()) {
+                    jit.load64(CCallHelpers::Address(GPRInfo::callFrameRegister, jsOffset), scratchGPR);
+                    slowPath.append(jit.branchIfNotNumber(scratchGPR));
+                    auto isInt32 = jit.branchIfInt32(scratchGPR);
+                    if (signature.argument(i) == Wasm::F32) {
+                        jit.unboxDouble(scratchGPR, scratchGPR, scratchFPR);
+                        jit.convertDoubleToFloat(scratchFPR, scratchFPR);
+                        jit.storeFloat(scratchFPR, calleeFrame.withOffset(wasmOffset));
+                    } else {
+                        jit.add64(GPRInfo::tagTypeNumberRegister, scratchGPR, scratchGPR);
+                        jit.store64(scratchGPR, calleeFrame.withOffset(wasmOffset));
+                    }
+                    auto done = jit.jump();
+
+                    isInt32.link(&jit);
+                    if (signature.argument(i) == Wasm::F32) {
+                        jit.convertInt32ToFloat(scratchGPR, scratchFPR);
+                        jit.storeFloat(scratchFPR, calleeFrame.withOffset(wasmOffset));
+                    } else {
+                        jit.convertInt32ToDouble(scratchGPR, scratchFPR);
+                        jit.storeDouble(scratchFPR, calleeFrame.withOffset(wasmOffset));
+                    }
+                    done.link(&jit);
+
+                    wasmOffset += sizeof(CPURegister);
+                } else
+                    ++numFPRs;
+                break;
+            default:
+                RELEASE_ASSERT_NOT_REACHED();
+            }
+
+            jsOffset += sizeof(EncodedJSValue);
+        }
+    }
+
+    // Now handle FPR arguments in registers.
+    {
+        numFPRs = 0;
+        ptrdiff_t jsOffset = CallFrameSlot::firstArgument * sizeof(EncodedJSValue);
+        for (unsigned i = 0; i < signature.argumentCount(); i++) {
+            switch (signature.argument(i)) {
+            case Wasm::F32:
+            case Wasm::F64:
+                if (numFPRs < Wasm::wasmCallingConvention().m_fprArgs.size()) {
+                    FPRReg argFPR = Wasm::wasmCallingConvention().m_fprArgs[numFPRs].fpr();
+                    jit.load64(CCallHelpers::Address(GPRInfo::callFrameRegister, jsOffset), scratchGPR);
+                    slowPath.append(jit.branchIfNotNumber(scratchGPR));
+                    auto isInt32 = jit.branchIfInt32(scratchGPR);
+                    jit.unboxDouble(scratchGPR, scratchGPR, argFPR);
+                    if (signature.argument(i) == Wasm::F32)
+                        jit.convertDoubleToFloat(argFPR, argFPR);
+                    auto done = jit.jump();
+
+                    isInt32.link(&jit);
+                    if (signature.argument(i) == Wasm::F32)
+                        jit.convertInt32ToFloat(scratchGPR, argFPR);
+                    else
+                        jit.convertInt32ToDouble(scratchGPR, argFPR);
+
+                    done.link(&jit);
+                    ++numFPRs;
+                }
+                break;
+            default:
+                break;
+            }
+
+            jsOffset += sizeof(EncodedJSValue);
+        }
+    }
+
+    // At this point, we're committed to doing a fast call.
+
+    if (Wasm::Context::useFastTLS()) 
+        jit.loadWasmContextInstance(scratchGPR);
+    else
+        jit.loadPtr(vm.wasmContext.pointerToInstance(), scratchGPR);
+    ptrdiff_t previousInstanceOffset = this->previousInstanceOffset();
+    jit.storePtr(scratchGPR, CCallHelpers::Address(GPRInfo::callFrameRegister, previousInstanceOffset));
+
+    jit.move(CCallHelpers::TrustedImmPtr(&instance()->instance()), scratchGPR);
+    if (Wasm::Context::useFastTLS()) 
+        jit.storeWasmContextInstance(scratchGPR);
+    else {
+        jit.move(scratchGPR, pinnedRegs.wasmContextInstancePointer);
+        jit.storePtr(scratchGPR, vm.wasmContext.pointerToInstance());
+    }
+    // This contains the cached stack limit still.
+    jit.storePtr(scratch2GPR, CCallHelpers::Address(scratchGPR, Wasm::Instance::offsetOfCachedStackLimit()));
+
+    if (!!moduleInformation.memory) {
+        GPRReg baseMemory = pinnedRegs.baseMemoryPointer;
+
+        if (instance()->memoryMode() != Wasm::MemoryMode::Signaling) {
+            ASSERT(pinnedRegs.sizeRegister != scratchGPR);
+            jit.loadPtr(CCallHelpers::Address(scratchGPR, Wasm::Instance::offsetOfCachedMemorySize()), pinnedRegs.sizeRegister);
+        }
+
+        jit.loadPtr(CCallHelpers::Address(scratchGPR, Wasm::Instance::offsetOfCachedMemory()), baseMemory);
+    }
+
+    // We use this callee to indicate how to unwind past these types of frames:
+    // 1. We need to know where to get callee saves.
+    // 2. We need to know to restore the previous wasm context.
+    if (!m_jsToWasmICCallee)
+        m_jsToWasmICCallee.set(vm, this, JSToWasmICCallee::create(vm, globalObject(), this));
+    jit.storePtr(CCallHelpers::TrustedImmPtr(m_jsToWasmICCallee.get()), CCallHelpers::addressFor(CallFrameSlot::callee));
+
+    {
+        // FIXME: Currently we just do an indirect jump. But we should teach the Module
+        // how to repatch us:
+        // https://bugs.webkit.org/show_bug.cgi?id=196570
+        jit.loadPtr(entrypointLoadLocation(), scratchGPR);
+        jit.call(scratchGPR, WasmEntryPtrTag);
+    }
+
+    ASSERT(!RegisterSet::runtimeTagRegisters().contains(GPRInfo::nonPreservedNonReturnGPR));
+    jit.loadPtr(CCallHelpers::Address(GPRInfo::callFrameRegister, previousInstanceOffset), GPRInfo::nonPreservedNonReturnGPR);
+    if (Wasm::Context::useFastTLS())
+        jit.storeWasmContextInstance(GPRInfo::nonPreservedNonReturnGPR);
+    else
+        jit.storePtr(GPRInfo::nonPreservedNonReturnGPR, vm.wasmContext.pointerToInstance());
+
+    switch (signature.returnType()) {
+    case Wasm::Void:
+        jit.moveTrustedValue(jsUndefined(), JSValueRegs { GPRInfo::returnValueGPR });
+        break;
+    case Wasm::I32:
+        jit.zeroExtend32ToPtr(GPRInfo::returnValueGPR, GPRInfo::returnValueGPR);
+        jit.boxInt32(GPRInfo::returnValueGPR, JSValueRegs { GPRInfo::returnValueGPR });
+        break;
+    case Wasm::F32:
+        jit.convertFloatToDouble(FPRInfo::returnValueFPR, FPRInfo::returnValueFPR);
+        FALLTHROUGH;
+    case Wasm::F64: {
+        jit.moveTrustedValue(jsNumber(pureNaN()), JSValueRegs { GPRInfo::returnValueGPR });
+        auto isNaN = jit.branchIfNaN(FPRInfo::returnValueFPR);
+        jit.boxDouble(FPRInfo::returnValueFPR, JSValueRegs { GPRInfo::returnValueGPR });
+        isNaN.link(&jit);
+        break;
+    }
+    case Wasm::I64:
+    case Wasm::Func:
+    case Wasm::Anyfunc:
+        return nullptr;
+    default:
+        break;
+    }
+
+    auto emitRestoreCalleeSaves = [&] {
+        for (const RegisterAtOffset& regAtOffset : registersToSpill) {
+            GPRReg reg = regAtOffset.reg().gpr();
+            ASSERT(reg != GPRInfo::returnValueGPR);
+            ptrdiff_t offset = regAtOffset.offset();
+            jit.loadPtr(CCallHelpers::Address(GPRInfo::callFrameRegister, offset), reg);
+        }
+    };
+
+    emitRestoreCalleeSaves();
+
+    jit.emitFunctionEpilogue();
+    jit.ret();
+
+    slowPath.link(&jit);
+    emitRestoreCalleeSaves();
+    jit.move(CCallHelpers::TrustedImmPtr(this), GPRInfo::regT0);
+    jit.emitFunctionEpilogue();
+    auto jumpToHostCallThunk = jit.jump();
+
+    LinkBuffer linkBuffer(jit, nullptr, JITCompilationCanFail);
+    if (UNLIKELY(linkBuffer.didFailToAllocate()))
+        return nullptr;
+
+    linkBuffer.link(jumpToHostCallThunk, CodeLocationLabel<JSEntryPtrTag>(executable()->entrypointFor(CodeForCall, MustCheckArity).executableAddress()));
+    m_jsCallEntrypoint = FINALIZE_CODE(linkBuffer, WasmEntryPtrTag, "JS->Wasm IC");
+    return m_jsCallEntrypoint.code();
+}
+
 WebAssemblyFunction* WebAssemblyFunction::create(VM& vm, JSGlobalObject* globalObject, Structure* structure, unsigned length, const String& name, JSWebAssemblyInstance* instance, Wasm::Callee& jsEntrypoint, Wasm::WasmToWasmImportableFunction::LoadLocation wasmToWasmEntrypointLoadLocation, Wasm::SignatureIndex signatureIndex)
 {
     NativeExecutable* executable = vm.getHostFunction(callWebAssemblyFunction, NoIntrinsic, callHostFunctionAsConstructor, nullptr, name);
@@ -169,6 +502,20 @@ WebAssemblyFunction::WebAssemblyFunction(VM& vm, JSGlobalObject* globalObject, S
     , m_importableFunction { signatureIndex, wasmToWasmEntrypointLoadLocation }
 { }
 
+void WebAssemblyFunction::visitChildren(JSCell* cell, SlotVisitor& visitor)
+{
+    WebAssemblyFunction* thisObject = jsCast<WebAssemblyFunction*>(cell);
+    ASSERT_GC_OBJECT_INHERITS(thisObject, info());
+
+    Base::visitChildren(thisObject, visitor);
+    visitor.append(thisObject->m_jsToWasmICCallee);
+}
+
+void WebAssemblyFunction::destroy(JSCell* cell)
+{
+    static_cast<WebAssemblyFunction*>(cell)->WebAssemblyFunction::~WebAssemblyFunction();
+}
+
 } // namespace JSC
 
 #endif // ENABLE(WEBASSEMBLY)