//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is a part of the ORC runtime support library. // //===----------------------------------------------------------------------===// #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H #define ORC_RT_WRAPPER_FUNCTION_UTILS_H #include "orc_rt/c_api.h" #include "common.h" #include "error.h" #include "executor_address.h" #include "simple_packed_serialization.h" #include namespace __orc_rt { /// C++ wrapper function result: Same as CWrapperFunctionResult but /// auto-releases memory. class WrapperFunctionResult { public: /// Create a default WrapperFunctionResult. WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); } /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This /// instance takes ownership of the result object and will automatically /// call dispose on the result upon destruction. WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {} WrapperFunctionResult(const WrapperFunctionResult &) = delete; WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; WrapperFunctionResult(WrapperFunctionResult &&Other) { orc_rt_CWrapperFunctionResultInit(&R); std::swap(R, Other.R); } WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { orc_rt_CWrapperFunctionResult Tmp; orc_rt_CWrapperFunctionResultInit(&Tmp); std::swap(Tmp, Other.R); std::swap(R, Tmp); return *this; } ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); } /// Relinquish ownership of and return the /// orc_rt_CWrapperFunctionResult. orc_rt_CWrapperFunctionResult release() { orc_rt_CWrapperFunctionResult Tmp; orc_rt_CWrapperFunctionResultInit(&Tmp); std::swap(R, Tmp); return Tmp; } /// Get a pointer to the data contained in this instance. char *data() { return orc_rt_CWrapperFunctionResultData(&R); } /// Returns the size of the data contained in this instance. size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); } /// Returns true if this value is equivalent to a default-constructed /// WrapperFunctionResult. bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); } /// Create a WrapperFunctionResult with the given size and return a pointer /// to the underlying memory. static WrapperFunctionResult allocate(size_t Size) { WrapperFunctionResult R; R.R = orc_rt_CWrapperFunctionResultAllocate(Size); return R; } /// Copy from the given char range. static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); } /// Copy from the given null-terminated string (includes the null-terminator). static WrapperFunctionResult copyFrom(const char *Source) { return orc_rt_CreateCWrapperFunctionResultFromString(Source); } /// Copy from the given std::string (includes the null terminator). static WrapperFunctionResult copyFrom(const std::string &Source) { return copyFrom(Source.c_str()); } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const char *Msg) { return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { return createOutOfBandError(Msg.c_str()); } template static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { auto Result = allocate(SPSArgListT::size(Args...)); SPSOutputBuffer OB(Result.data(), Result.size()); if (!SPSArgListT::serialize(OB, Args...)) return createOutOfBandError( "Error serializing arguments to blob in call"); return Result; } /// If this value is an out-of-band error then this returns the error message, /// otherwise returns nullptr. const char *getOutOfBandError() const { return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); } private: orc_rt_CWrapperFunctionResult R; }; namespace detail { template class WrapperFunctionHandlerCaller { public: template static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { return std::forward(H)(std::get(Args)...); } }; template <> class WrapperFunctionHandlerCaller { public: template static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { std::forward(H)(std::get(Args)...); return SPSEmpty(); } }; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper< decltype(&std::remove_reference_t::operator()), ResultSerializer, SPSTagTs...> {}; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper { public: using ArgTuple = std::tuple...>; using ArgIndices = std::make_index_sequence::value>; template static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, size_t ArgSize) { ArgTuple Args; if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) return WrapperFunctionResult::createOutOfBandError( "Could not deserialize arguments for wrapper function call"); auto HandlerResult = WrapperFunctionHandlerCaller::call( std::forward(H), Args, ArgIndices{}); return ResultSerializer::serialize( std::move(HandlerResult)); } private: template static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, std::index_sequence) { SPSInputBuffer IB(ArgData, ArgSize); return SPSArgList::deserialize(IB, std::get(Args)...); } }; // Map function pointers to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map non-const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; template class ResultSerializer { public: static WrapperFunctionResult serialize(RetT Result) { return WrapperFunctionResult::fromSPSArgs>(Result); } }; template class ResultSerializer { public: static WrapperFunctionResult serialize(Error Err) { return WrapperFunctionResult::fromSPSArgs>( toSPSSerializable(std::move(Err))); } }; template class ResultSerializer> { public: static WrapperFunctionResult serialize(Expected E) { return WrapperFunctionResult::fromSPSArgs>( toSPSSerializable(std::move(E))); } }; template class ResultDeserializer { public: static void makeSafe(RetT &Result) {} static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); if (!SPSArgList::deserialize(IB, Result)) return make_error( "Error deserializing return value from blob in call"); return Error::success(); } }; template <> class ResultDeserializer { public: static void makeSafe(Error &Err) { cantFail(std::move(Err)); } static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableError BSE; if (!SPSArgList::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call"); Err = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; template class ResultDeserializer, Expected> { public: static void makeSafe(Expected &E) { cantFail(E.takeError()); } static Error deserialize(Expected &E, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableExpected BSE; if (!SPSArgList>::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call"); E = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; } // end namespace detail template class WrapperFunction; template class WrapperFunction { private: template using ResultSerializer = detail::ResultSerializer; public: template static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { // RetT might be an Error or Expected value. Set the checked flag now: // we don't want the user to have to check the unused result if this // operation fails. detail::ResultDeserializer::makeSafe(Result); // Since the functions cannot be zero/unresolved on Windows, the following // reference taking would always be non-zero, thus generating a compiler // warning otherwise. #if !defined(_WIN32) if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) return make_error("__orc_rt_jit_dispatch_ctx not set"); if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) return make_error("__orc_rt_jit_dispatch not set"); #endif auto ArgBuffer = WrapperFunctionResult::fromSPSArgs>(Args...); if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) return make_error(ErrMsg); WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch( &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size()); if (auto ErrMsg = ResultBuffer.getOutOfBandError()) return make_error(ErrMsg); return detail::ResultDeserializer::deserialize( Result, ResultBuffer.data(), ResultBuffer.size()); } template static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, HandlerT &&Handler) { using WFHH = detail::WrapperFunctionHandlerHelper, ResultSerializer, SPSTagTs...>; return WFHH::apply(std::forward(Handler), ArgData, ArgSize); } private: template static const T &makeSerializable(const T &Value) { return Value; } static detail::SPSSerializableError makeSerializable(Error Err) { return detail::toSPSSerializable(std::move(Err)); } template static detail::SPSSerializableExpected makeSerializable(Expected E) { return detail::toSPSSerializable(std::move(E)); } }; template class WrapperFunction : private WrapperFunction { public: template static Error call(const void *FnTag, const ArgTs &...Args) { SPSEmpty BE; return WrapperFunction::call(FnTag, BE, Args...); } using WrapperFunction::handle; }; /// A function object that takes an ExecutorAddr as its first argument, /// casts that address to a ClassT*, then calls the given method on that /// pointer passing in the remaining function arguments. This utility /// removes some of the boilerplate from writing wrappers for method calls. /// /// @code{.cpp} /// class MyClass { /// public: /// void myMethod(uint32_t, bool) { ... } /// }; /// /// // SPS Method signature -- note MyClass object address as first argument. /// using SPSMyMethodWrapperSignature = /// SPSTuple; /// /// WrapperFunctionResult /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { /// return WrapperFunction::handle( /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); /// } /// @endcode /// template class MethodWrapperHandler { public: using MethodT = RetT (ClassT::*)(ArgTs...); MethodWrapperHandler(MethodT M) : M(M) {} RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { return (ObjAddr.toPtr()->*M)(std::forward(Args)...); } private: MethodT M; }; /// Create a MethodWrapperHandler object from the given method pointer. template MethodWrapperHandler makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { return MethodWrapperHandler(Method); } /// Represents a call to a wrapper function. class WrapperFunctionCall { public: // FIXME: Switch to a SmallVector once ORC runtime has a // smallvector. using ArgDataBufferType = std::vector; /// Create a WrapperFunctionCall using the given SPS serializer to serialize /// the arguments. template static Expected Create(ExecutorAddr FnAddr, const ArgTs &...Args) { ArgDataBufferType ArgData; ArgData.resize(SPSSerializer::size(Args...)); SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), ArgData.size()); if (SPSSerializer::serialize(OB, Args...)) return WrapperFunctionCall(FnAddr, std::move(ArgData)); return make_error("Cannot serialize arguments for " "AllocActionCall"); } WrapperFunctionCall() = default; /// Create a WrapperFunctionCall from a target function and arg buffer. WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} /// Returns the address to be called. const ExecutorAddr &getCallee() const { return FnAddr; } /// Returns the argument data. const ArgDataBufferType &getArgData() const { return ArgData; } /// WrapperFunctionCalls convert to true if the callee is non-null. explicit operator bool() const { return !!FnAddr; } /// Run call returning raw WrapperFunctionResult. WrapperFunctionResult run() const { using FnTy = orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize); return WrapperFunctionResult( FnAddr.toPtr()(ArgData.data(), ArgData.size())); } /// Run call and deserialize result using SPS. template std::enable_if_t::value, Error> runWithSPSRet(RetT &RetVal) const { auto WFR = run(); if (const char *ErrMsg = WFR.getOutOfBandError()) return make_error(ErrMsg); SPSInputBuffer IB(WFR.data(), WFR.size()); if (!SPSSerializationTraits::deserialize(IB, RetVal)) return make_error("Could not deserialize result from " "serialized wrapper function call"); return Error::success(); } /// Overload for SPS functions returning void. template std::enable_if_t::value, Error> runWithSPSRet() const { SPSEmpty E; return runWithSPSRet(E); } /// Run call and deserialize an SPSError result. SPSError returns and /// deserialization failures are merged into the returned error. Error runWithSPSRetErrorMerged() const { detail::SPSSerializableError RetErr; if (auto Err = runWithSPSRet(RetErr)) return Err; return detail::fromSPSSerializable(std::move(RetErr)); } private: ExecutorAddr FnAddr; std::vector ArgData; }; using SPSWrapperFunctionCall = SPSTuple>; template <> class SPSSerializationTraits { public: static size_t size(const WrapperFunctionCall &WFC) { return SPSArgList>::size( WFC.getCallee(), WFC.getArgData()); } static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { return SPSArgList>::serialize( OB, WFC.getCallee(), WFC.getArgData()); } static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { ExecutorAddr FnAddr; WrapperFunctionCall::ArgDataBufferType ArgData; if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) return false; WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); return true; } }; } // end namespace __orc_rt #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H