Goose  invoke.cpp at [9f242fb2ff]

File bs/builtins/types/overloadset/invoke.cpp artifact b217794e06 part of check-in 9f242fb2ff


#include "builtins/builtins.h"

//#define OVL_TC_DEBUG_CANDIDATES
//#define OVL_TC_DEBUG

using namespace goose::sema;

namespace goose::builtins
{
    // Invocation rule applied when the callee of a call is an overload set.
    //
    // It selects one overload of the set for the given call arguments and hands
    // the call over to the invocation rule attached to that overload.
    class OverloadSetInvocationRule : public InvocationRule
    {
        public:
            // Entry point of the rule.
            //
            // Extracts the overload set from callee, then either reuses a resolution
            // previously cached by the set for these args (forwarding to the cached
            // overload's resolveInvocation), or runs a full overload resolution.
            //
            // NOTE(review): the result of FromValue is dereferenced unchecked, so this
            // presumably relies on the rule only ever being matched against overload
            // set values -- confirm against the rule registration.
            Value resolveInvocation( Context& c, LocationId loc, const Value& callee, const Term& args ) const final
            {
                ProfileZoneScoped;

                auto pOverloads = *FromValue< ptr< OverloadSet > >( callee );

                #if TRACY_ENABLE
                    // Name the profiling zone after the overload set's identity.
                    stringstream zoneLabel;
                    zoneLabel << pOverloads->identity();
                    const auto zoneName = zoneLabel.str();
                    ProfileZoneName( zoneName.c_str(), zoneName.size() );
                #endif

                auto cached = pOverloads->getResolutionFromCache( args );
                if( !cached )
                    return resolve( c, loc, pOverloads, args );

                return cached->pInvRule->resolveInvocation( c, loc, *cached->callee, args );
            }

        private:
            // Full overload resolution.
            //
            // Typechecks the call pattern against the overload set, keeps the candidate
            // with the highest score, and invokes it. Emits a diagnostic and returns
            // PoisonValue() when the best score is shared by several candidates
            // ("ambiguous function call.") or when no candidate is usable
            // ("function arguments mismatch."). On success the chosen overload is
            // added to the set's resolution cache before being invoked.
            Value resolve( Context& c, LocationId loc, const ptr< OverloadSet >& pOvlSet, const Term& args ) const
            {
                const OverloadSet::Overload* pWinner = nullptr;
                optional< TypeCheckingContext > winnerTCC;
                optional< Term > winnerSolution;

                {
                    ProfileZoneScopedN( "Overload resolution" );

                    // Set when a candidate ties with the current winner's score.
                    bool tie = false;

                #if defined( OVL_TC_DEBUG ) && !defined( NDEBUG )
                    cout << "#### Invoking " << pOvlSet->identity() << endl;
                #endif

                    // The pattern that is typechecked is the argument list with a hole
                    // prepended to it.
                    auto callPattern = PrependToVectorTerm( args, HOLE( "_"_sid, "fwd"_sid ) );
                    TypeCheckingContext rootTCC( c );
                    for( auto&& [solution, candidate, candidateTCC] : pOvlSet->typeCheck( callPattern, rootTCC ) )
                    {
                        // Discard candidates whose context still reports unknown values.
                        if( candidateTCC.numUnknownValues() )
                            continue;

                        auto substituted = Substitute( solution, candidateTCC );

                        // Typechecking rules tend to strip pieces off the original
                        // types, and the overload to prefer is the one for which the
                        // least was stripped.
                        //
                        // Typical case: given one overload taking a reference and
                        // another taking a value of the same type, a call made with a
                        // reference should go to the overload whose solution kept the
                        // reference.
                        //
                        // To account for this, the weight of the original call pattern
                        // is added to the cost and the weight of the typechecking
                        // solution is subtracted from it.
                        int32_t adjustedCost = candidateTCC.cost();
                        adjustedCost += GetWeight( callPattern );
                        adjustedCost -= GetWeight( substituted );
                        candidateTCC.setCost( adjustedCost );

                    #ifdef OVL_TC_DEBUG_CANDIDATES
                        cout << "  ## CANDIDATE: " << candidateTCC.score() << "  " << substituted << endl;
                    #endif

                        // A candidate scoring strictly lower than the current winner is
                        // dropped before any postprocessing is attempted.
                        auto candidateScore = candidateTCC.score();
                        if( winnerTCC && candidateScore < winnerTCC->score() )
                            continue;

                        auto processed = Postprocess( substituted, candidateTCC );
                        if( !processed )
                            continue;

                        if( winnerTCC && candidateScore == winnerTCC->score() )
                        {
                            // Same score as the current winner: remember the tie but
                            // keep the winner, a later candidate may still beat both.
                            tie = true;
                        }
                        else
                        {
                            // First usable candidate, or strictly better than the
                            // current winner: it takes over and clears any tie.
                            //
                            // NOTE(review): pWinner is dereferenced after the loop, so
                            // this presumably relies on typeCheck() yielding a reference
                            // to an overload that outlives the iteration -- confirm.
                            winnerTCC = candidateTCC;
                            winnerSolution = move( *processed );
                            pWinner = &candidate;
                            tie = false;
                        }
                    }

                    if( tie )
                    {
                        // TODO display details
                        DiagnosticsManager::GetInstance().emitErrorMessage( loc,
                            "ambiguous function call." );
                        return PoisonValue();
                    }

                    if( !winnerSolution )
                    {
                        // TODO display details
                        DiagnosticsManager::GetInstance().emitErrorMessage( loc,
                            "function arguments mismatch." );
                        return PoisonValue();
                    }

                #if defined( OVL_TC_DEBUG ) && !defined( NDEBUG )
                    winnerTCC->DumpParamsTraces( cout );
                    cout << endl;
                #endif
                }

                // Record the outcome so that resolveInvocation can find it in the
                // set's cache on a later call with the same args.
                pOvlSet->addResolutionToCache( args, *pWinner );

                return pWinner->pInvRule->invoke( c, loc, *pWinner->callee, args, *winnerSolution, *winnerTCC );
            }
    };

    // Returns a reference to the single OverloadSetInvocationRule instance.
    //
    // The instance is held by a function-local static, so it is created by the
    // first call and every call returns a reference to that same pointer.
    ptr< InvocationRule >& GetOverloadSetInvocationRule()
    {
        static ptr< InvocationRule > s_pOverloadSetRule = make_shared< OverloadSetInvocationRule >();
        return s_pOverloadSetRule;
    }

    // Registers the overload set invocation rule in the invocation rule set of e.
    //
    // The rule is added under a pattern built from a Value whose type is the
    // value type of ptr< OverloadSet > and whose content is ANYTERM( _ );
    // presumably this matches any overload set value -- confirm against the
    // pattern matching rules.
    void SetupOverloadSetInvocationRule( Env& e )
    {
        auto&& ruleSet = e.invocationRuleSet();
        auto& pRule = GetOverloadSetInvocationRule();

        ruleSet->addRule(
            ValueToEIR( Value( GetValueType< ptr< OverloadSet > >(), ANYTERM( _ ) ) ),
            pRule );
    }
}