Goose  Artifact [24a6526b29]

Artifact 24a6526b2960adfb09930dbce65be31022f29fc047052156894e117d64dbbac3:

  • File bs/builtins/types/overloadset/invoke.cpp — part of check-in [0db147f117] at 2024-09-15 20:24:31 on branch cir-ssa-refactor — Add clang format settings, reformat everything (user: achavasse size: 5503)

#include "builtins/builtins.h"

// #define OVL_TC_DEBUG

using namespace goose::sema;

namespace goose::builtins
{
    /// Invocation rule applied when the callee of a call is an overload set.
    ///
    /// The overload to call is taken from the set's resolution cache when an entry exists for
    /// the given arguments; otherwise every candidate produced by OverloadSet::typeCheck is
    /// scored and the single highest-scoring one is invoked.
    class OverloadSetInvocationRule : public InvocationRule
    {
      public:
        /// Resolves a call whose callee is an overload set.
        ///
        /// @param c      Context forwarded to type checking and to the selected overload's rule.
        /// @param loc    Location used for diagnostics and forwarded to the selected rule.
        /// @param callee Value holding a ptr< OverloadSet >. The result of FromValue is
        ///               dereferenced without being checked.
        /// @param args   Call arguments.
        /// @return The value produced by the selected overload's invocation rule, or
        ///         PoisonValue() when resolution is ambiguous or no candidate matches.
        Value resolveInvocation(
            Context& c, LocationId loc, const Value& callee, const Term& args ) const final
        {
            ProfileZoneScoped;

            auto pOvlSet = *FromValue< ptr< OverloadSet > >( callee );

#if TRACY_ENABLE
            stringstream sstr;
            sstr << pOvlSet->identity();
            // str() returns a new string on every call, so materialize the zone name once.
            const auto zoneName = sstr.str();
            ProfileZoneName( zoneName.c_str(), zoneName.size() );
#endif

            // Cache hit: hand the call over to the cached overload's own rule.
            if( auto ovl = pOvlSet->getResolutionFromCache( args ) )
                return ovl->pInvRule->resolveInvocation( c, loc, *ovl->callee, args );

            return resolve( c, loc, pOvlSet, args );
        }

      private:
        /// Performs overload resolution, records the result in the set's cache and invokes
        /// the winning overload.
        ///
        /// Emits "ambiguous function call." when the best score is shared by several
        /// candidates and "function arguments mismatch." when no candidate is viable; both
        /// cases return PoisonValue().
        Value resolve(
            Context& c, LocationId loc, const ptr< OverloadSet >& pOvlSet, const Term& args ) const
        {
            const OverloadSet::Overload* bestOvl = nullptr;
            optional< TypeCheckingContext > bestTCC;
            optional< Term > bestSol;

            {
                ProfileZoneScopedN( "Overload resolution" );

                bool ambiguous = false;

                if( pOvlSet->verboseResolution() )
                {
                    DiagnosticsManager::GetInstance().emitTraceMessage(
                        loc, std::format( "invoking overload set: {}", pOvlSet->identity() ) );
                }

                auto callPat = PrependToVectorTerm( args, HOLE( "_"_sid, "fwd"_sid ) );

                // Context handed to typeCheck; each candidate comes back with its own context,
                // named candTCC below so that it does not shadow this one.
                TypeCheckingContext initialTCC( c );
                for( auto&& [s, ovl, candTCC] : pOvlSet->typeCheck( callPat, initialTCC ) )
                {
                    // Candidates that still have unknown values are not viable.
                    if( candTCC.numUnknownValues() )
                    {
                        if( pOvlSet->verboseResolution() )
                        {
                            DiagnosticsManager::GetInstance().emitTraceMessage(
                                ovl.callee ? ovl.callee->locationId() : 0,
                                std::format( "rejected candidate" ) );
                        }
                        continue;
                    }

                    auto subs = Substitute( s, candTCC );

                    // Typechecking rules often end up stripping part of the original type,
                    // and we want to invoke the overload where these removals are minimized.
                    //
                    // Obvious example: if there is an overload that accepts a reference
                    // and one that accepts a value of the same type and we started with a
                    // reference, then we want to call the overload where the typechecking
                    // solution didn't strip the reference.
                    //
                    // So we add the weight of the original arguments to the cost,
                    // and remove the cost of the typechecking solution to account for that.
                    int32_t cost = candTCC.cost();
                    cost += GetWeight( callPat );
                    cost -= GetWeight( subs );
                    candTCC.setCost( cost );

                    // Read the score only after the cost adjustment above.
                    auto score = candTCC.score();

                    if( pOvlSet->verboseResolution() )
                    {
                        DiagnosticsManager::GetInstance().emitTraceMessage(
                            ovl.callee ? ovl.callee->locationId() : 0,
                            std::format( "candidate score: {}", score ) );
                    }

                    // Strictly worse than the current best: skip before postprocessing.
                    if( bestTCC && score < bestTCC->score() )
                        continue;

                    auto pps = Postprocess( subs, candTCC );
                    if( !pps )
                        continue;

                    // A tie with the current best makes the call ambiguous, unless a strictly
                    // better candidate shows up later and resets the flag.
                    if( bestTCC && score == bestTCC->score() )
                    {
                        ambiguous = true;
                        continue;
                    }

                    bestTCC = candTCC;
                    bestSol = std::move( *pps );
                    bestOvl = &ovl;
                    ambiguous = false;
                }

                if( ambiguous )
                {
                    // TODO display details
                    DiagnosticsManager::GetInstance().emitErrorMessage(
                        loc, "ambiguous function call." );
                    return PoisonValue();
                }

                if( !bestSol )
                {
                    // TODO display details
                    DiagnosticsManager::GetInstance().emitErrorMessage(
                        loc, "function arguments mismatch." );
                    return PoisonValue();
                }

#if defined( OVL_TC_DEBUG ) && !defined( NDEBUG )
                bestTCC->DumpParamsTraces( cout );
                cout << endl;
#endif
            }

            // bestOvl, bestSol and bestTCC are always assigned together, so all three are set
            // once the !bestSol check above has passed.
            pOvlSet->addResolutionToCache( args, *bestOvl );

            return bestOvl->pInvRule->invoke( c, loc, *bestOvl->callee, args, *bestSol, *bestTCC );
        }
    };

    /// Accessor for the single OverloadSetInvocationRule instance.
    ///
    /// The instance is a function-local static, so it is created the first time this
    /// function is called. A mutable reference to the owning pointer is returned.
    ptr< InvocationRule >& GetOverloadSetInvocationRule()
    {
        static ptr< InvocationRule > s_overloadSetRule{
            make_shared< OverloadSetInvocationRule >() };

        return s_overloadSetRule;
    }

    /// Registers the overload set invocation rule in the invocation rule set of `e`.
    ///
    /// The rule is keyed on the EIR form of a Value whose type is ptr< OverloadSet > and
    /// whose content is ANYTERM( _ ).
    void SetupOverloadSetInvocationRule( Env& e )
    {
        auto&& invRules = e.invocationRuleSet();

        invRules->addRule(
            ValueToEIR( Value( GetValueType< ptr< OverloadSet > >(), ANYTERM( _ ) ) ),
            GetOverloadSetInvocationRule() );
    }
} // namespace goose::builtins