#include "builtins/builtins.h"
//#define OVL_TC_DEBUG_CANDIDATES
//#define OVL_TC_DEBUG
using namespace goose::sema;
namespace goose::builtins
{
// Invocation rule used when the callee is an overload set.
//
// The callee value is unpacked as a ptr< OverloadSet >. Argument lists that were
// already resolved are looked up in the overload set's resolution cache; on a miss,
// every overload is typechecked against the arguments and the best scoring one is
// invoked (and remembered in the cache).
class OverloadSetInvocationRule : public InvocationRule
{
    public:
        // Resolves a call to the overload set held by `callee` with the given `args`.
        //
        // On a cache hit, the call is delegated to the cached overload's own invocation
        // rule via resolveInvocation(). On a miss, full overload resolution is done by
        // resolve().
        Value resolveInvocation( Context& c, LocationId loc, const Value& callee, const Term& args ) const final
        {
            ProfileZoneScoped;

            // NOTE(review): the optional returned by FromValue is dereferenced unchecked;
            // this relies on the rule only being registered for ptr< OverloadSet > values.
            auto pOvlSet = *FromValue< ptr< OverloadSet > >( callee );

// Tracy keys on whether TRACY_ENABLE is *defined*; `#if TRACY_ENABLE` would be
// ill-formed if the macro were defined with an empty body.
#ifdef TRACY_ENABLE
            stringstream sstr;
            sstr << pOvlSet->identity();

            // Materialize the zone name once rather than building two temporary strings.
            auto zoneName = sstr.str();
            ProfileZoneName( zoneName.c_str(), zoneName.size() );
#endif

            if( auto ovl = pOvlSet->getResolutionFromCache( args ) )
                return ovl->pInvRule->resolveInvocation( c, loc, *ovl->callee, args );

            return resolve( c, loc, pOvlSet, args );
        }

    private:
        // Performs full overload resolution of `args` against the overloads of `pOvlSet`.
        //
        // Selection rules:
        // - Candidates whose typechecking context still has unknown values are discarded.
        // - Among the others, the highest score wins, provided its solution survives
        //   Postprocess().
        // - Two postprocessable candidates sharing the best score make the call
        //   ambiguous, unless a strictly better candidate shows up later.
        //
        // Emits "ambiguous function call." or "function arguments mismatch." and returns
        // PoisonValue() on failure. Otherwise it caches the chosen overload for `args`
        // and returns the result of invoking it.
        Value resolve( Context& c, LocationId loc, const ptr< OverloadSet >& pOvlSet, const Term& args ) const
        {
            const OverloadSet::Overload* bestOvl = nullptr;
            optional< TypeCheckingContext > bestTCC;
            optional< Term > bestSol;

            // The braces limit the "Overload resolution" profiling zone to the selection
            // phase; the final invoke() below runs outside of it.
            {
                ProfileZoneScopedN( "Overload resolution" );

                bool ambiguous = false;

#if defined( OVL_TC_DEBUG ) && !defined( NDEBUG )
                cout << "#### Invoking " << pOvlSet->identity() << endl;
#endif
                // The argument list is prefixed with a hole ("_", "fwd") before typechecking.
                // NOTE(review): presumably a slot for the call's result — confirm.
                auto callPat = PrependToVectorTerm( args, HOLE( "_"_sid, "fwd"_sid ) );

                // Named distinctly from the per-candidate `tcc` bound by the loop below,
                // which previously shadowed it.
                TypeCheckingContext rootTCC( c );

                for( auto&& [s,ovl,tcc] : pOvlSet->typeCheck( callPat, rootTCC ) )
                {
                    // Only fully determined solutions are eligible.
                    if( tcc.numUnknownValues() )
                        continue;

                    auto subs = Substitute( s, tcc );

                    // Typechecking rules often end up stripping part of the original type,
                    // and we want to invoke the overload where these removals are minimized.
                    //
                    // Obvious example: if there is an overload that accepts a reference
                    // and one that accepts a value of the same type and we started with a
                    // reference, then we want to call the overload where the typechecking
                    // solution didn't strip the reference.
                    //
                    // So we add the weight of the original arguments to the cost,
                    // and remove the cost of the typechecking solution to account for that.
                    int32_t cost = tcc.cost();
                    cost += GetWeight( callPat );
                    cost -= GetWeight( subs );
                    tcc.setCost( cost );

#ifdef OVL_TC_DEBUG_CANDIDATES
                    cout << " ## CANDIDATE: " << tcc.score() << " " << subs << endl;
#endif
                    auto score = tcc.score();

                    // Cheap rejection before running the (potentially costly) postprocessing.
                    if( bestTCC && score < bestTCC->score() )
                        continue;

                    auto pps = Postprocess( subs, tcc );
                    if( !pps )
                        continue;

                    // A valid candidate tying with the current best: ambiguous unless
                    // something strictly better comes along later.
                    if( bestTCC && score == bestTCC->score() )
                    {
                        ambiguous = true;
                        continue;
                    }

                    // Strictly better (or first valid) candidate.
                    bestTCC = tcc;
                    bestSol = move( *pps );
                    bestOvl = &ovl;
                    ambiguous = false;
                }

                if( ambiguous )
                {
                    // TODO display details
                    DiagnosticsManager::GetInstance().emitErrorMessage( loc,
                        "ambiguous function call." );
                    return PoisonValue();
                }

                if( !bestSol )
                {
                    // TODO display details
                    DiagnosticsManager::GetInstance().emitErrorMessage( loc,
                        "function arguments mismatch." );
                    return PoisonValue();
                }

#if defined( OVL_TC_DEBUG ) && !defined( NDEBUG )
                bestTCC->DumpParamsTraces( cout );
                cout << endl;
#endif
            }

            // bestSol is set, so bestOvl and bestTCC were assigned alongside it.
            pOvlSet->addResolutionToCache( args, *bestOvl );
            return bestOvl->pInvRule->invoke( c, loc, *bestOvl->callee, args, *bestSol, *bestTCC );
        }
};
// Returns the single shared OverloadSetInvocationRule instance.
//
// The instance lives in a function-local static. It is created on the first call,
// and the same shared pointer is handed back by reference on every later call.
ptr< InvocationRule >& GetOverloadSetInvocationRule()
{
    static ptr< InvocationRule > s_overloadSetRule { make_shared< OverloadSetInvocationRule >() };
    return s_overloadSetRule;
}
// Registers the overload set invocation rule in `e`'s invocation rule set.
//
// The rule is keyed on a pattern matching any value whose type is
// ptr< OverloadSet >, whatever its content (ANYTERM).
void SetupOverloadSetInvocationRule( Env& e )
{
    const auto& ruleSet = e.invocationRuleSet();

    auto anyOverloadSetPattern = ValueToEIR( Value(
        GetValueType< ptr< OverloadSet > >(),
        ANYTERM( _ ) ) );

    ruleSet->addRule( move( anyOverloadSetPattern ), GetOverloadSetInvocationRule() );
}
}