Goose  uni-tdecl.cpp at [0c3c9488af]

File bs/builtins/types/template/uni-tdecl.cpp artifact 487a999216 part of check-in 0c3c9488af


#include "builtins/builtins.h"

namespace empathy::builtins
{
    // Builds the pattern that an argument must match in order to be passed for the
    // template declaration `td`: a value pattern constrained only by td's declared
    // type, with anonymous wildcard holes ("_") in the first and third positions.
    Term BuildArgPatternFromTDecl( const TDecl& td )
    {
        const auto& declaredType = td.type();
        return ValueToIRExpr( ValuePattern( MkHole( "_"_sid ), declaredType, MkHole( "_"_sid ) ) );
    }

    // Unifies a template declaration (a TDecl value, on the LHS) with `rhs`.
    //
    // The TDecl is replaced by a value pattern of its declared type with a wildcard
    // payload, and that pattern is unified with `rhs`. Every resulting term is then
    // unified with a hole named after the declaration, and each of those solutions
    // is yielded.
    //
    // `lhs` must wrap a TDecl (asserted). The complexity adjustment made on `uc`
    // before the first unification is not undone afterwards.
    UniGen UnifyTDecl( const Term& lhs, const Term& rhs, UnificationContext& uc )
    {
        auto decl = FromValue< TDecl >( *ValueFromIRExpr( lhs ) );
        assert( decl );

        auto nameHole = MkHole( decl->name() );
        auto typePat = ValueToIRExpr( Value( decl->type(), MkHole( "_"_sid ) ) );

        // lhs is swapped for typePat before re-unifying: only the pattern's complexity
        // should be accounted for, not the structure of the tdecl itself.
        uc.subComplexity( GetComplexity( lhs ) );
        uc.addComplexity( GetComplexity( typePat ) );

        for( auto&& [matched, matchCtx] : Unify( typePat, rhs, uc ) )
        {
            // Both operands of the upcoming hole unification originate from the LHS,
            // so the RHS namespace temporarily aliases the LHS one.
            const auto rhsNamespace = matchCtx.RHSNamespaceIndex();
            matchCtx.setRHSNamespaceIndex( matchCtx.LHSNamespaceIndex() );

            // From here on, the matched term stands in for both typePat and rhs.
            matchCtx.subComplexity( GetComplexity( typePat ) + GetComplexity( rhs ) );
            matchCtx.addComplexity( GetComplexity( matched ) );

            for( auto&& [bound, boundCtx] : Unify( matched, nameHole, matchCtx ) )
            {
                // Hand back a context whose RHS namespace is the one rhs lived in.
                auto yieldCtx = boundCtx;
                yieldCtx.setRHSNamespaceIndex( rhsNamespace );
                co_yield { bound, yieldCtx };
            }

            matchCtx.setRHSNamespaceIndex( rhsNamespace );
        }
    }

    void SetupTDeclUnification( Env& e )
    {
        auto tDeclPat = ValueToIRExpr( Value( GetValueType< TDecl >(), VEC( ANYTERM( _ ), ANYTERM( _ ) ) ) );

        e.unificationRuleSet()->addHalfUnificationRule( tDeclPat,
            []( const Term& lhs, UnificationContext& c ) -> UniGen
            {
                auto tdecl = FromValue< TDecl >( *ValueFromIRExpr( lhs ) );
                assert( tdecl );
                return HalfUnify( tdecl->type(), c );
            } );

        e.unificationRuleSet()->addSymRule( tDeclPat, ANYTERM( _ ), UnifyTDecl );
        e.unificationRuleSet()->addSymRule( tDeclPat, UnifyTDecl );

        // tfunc tdecl param / tfunc arg
        auto tFuncTypePat = ValueToIRExpr( Value( TypeType(), VEC( TSID( texpr ), TSID( tfunc ),
            ANYTERM( _ ), ANYTERM( _ ), ANYTERM( _ ), ANYTERM( _ ) ) ) );

        auto tDeclTFuncPat = ParamPat( GetValueType< TDecl >(), VEC( tFuncTypePat, ANYTERM( _ ) ) );

        e.unificationRuleSet()->addSymRule(

            tDeclTFuncPat,

            ValueToIRExpr( ValuePattern(
                TSID( constant ),
                move( tFuncTypePat ),
                ANYTERM( _ ) ) ),

        []( const Term& lhs, const Term& rhs, UnificationContext& uc ) -> UniGen
        {
            auto tdecl = FromValue< TDecl >( *ValueFromIRExpr( lhs ) );
            assert( tdecl );

            auto tfuncType = FromValue< TFuncType >( *ValueFromIRExpr( tdecl->type() ) );
            assert( tfuncType );

            auto callPat = BuildArgPatternFromTDecl( *tdecl );
            auto tdeclHole = MkHole( tdecl->name() );

            auto rhsVal = *ValueFromIRExpr( rhs );

            auto constraintPat = BuildTFuncSignature( uc.context(), *tfuncType );
            assert( constraintPat );

            ConstrainedFunc cfunc( *constraintPat, GetTFuncInvocationRule(), rhsVal );
            auto cFuncTerm = ValueToIRExpr( ToValue( move( cfunc ) ) );

            // Create a new named hole namespace to isolate holes from the passed function from those in
            // the called function.
            auto savedRHSNamespaceIndex = uc.RHSNamespaceIndex();
            uc.setRHSNamespaceIndex( uc.newNamespaceIndex() );

            auto oldValueRequired = uc.isValueResolutionRequired();
            uc.setValueResolutionRequired( false );

            // We are replacing lhs with a different terms and re-unifying,
            // so update the complexity accordingly. The structure of the tdecl
            // shouldn't count, only its pattern.
            uc.subComplexity( GetComplexity( lhs ) );
            uc.addComplexity( GetComplexity( callPat ) );

            for( auto&& [s, uc] : Unify( callPat, rhs, uc ) )
            {
                // Restore the namespace
                auto localC = uc;
                localC.setRHSNamespaceIndex( savedRHSNamespaceIndex );
                localC.setValueResolutionRequired( oldValueRequired );

                // We need to unify the result with a hole named after the decl. However, since both sides of
                // this unification orignally appeared on the LHS, we need to setup RHS to alias the LHS namespace for this.
                localC.setRHSNamespaceIndex( localC.LHSNamespaceIndex() );

                for( auto&& [s,c] : Unify( cFuncTerm, tdeclHole, localC ) )
                {
                    auto localC = c;
                    localC.setRHSNamespaceIndex( savedRHSNamespaceIndex );
                    co_yield { s, localC };
                }
            }

            uc.setRHSNamespaceIndex( savedRHSNamespaceIndex );
            uc.setValueResolutionRequired( oldValueRequired );
        } );

        // tfunc tdecl param / overloadset arg
        e.unificationRuleSet()->addSymRule(

            move( tDeclTFuncPat ),

            ValueToIRExpr( ValuePattern(
                TSID( constant ),
                GetValueType< ptr< builtins::OverloadSet > >(),
                ANYTERM( _ ) ) ),

        []( const Term& lhs, const Term& rhs, UnificationContext& uc ) -> UniGen
        {
            auto tdecl = FromValue< TDecl >( *ValueFromIRExpr( lhs ) );
            assert( tdecl );

            auto tfuncType = FromValue< TFuncType >( *ValueFromIRExpr( tdecl->type() ) );
            assert( tfuncType );

            auto callPat = BuildArgPatternFromTDecl( *tdecl );
            auto tdeclHole = MkHole( tdecl->name() );

            auto rhsVal = *ValueFromIRExpr( rhs );

            auto constraintPat = BuildTFuncSignature( uc.context(), *tfuncType );
            assert( constraintPat );

            ConstrainedFunc cfunc( *constraintPat, GetOverloadSetInvocationRule(), rhsVal );
            auto cFuncTerm = ValueToIRExpr( ToValue( move( cfunc ) ) );

            // Create a new named hole namespace to isolate holes from the passed function from those in
            // the called function.
            auto savedRHSNamespaceIndex = uc.RHSNamespaceIndex();
            uc.setRHSNamespaceIndex( uc.newNamespaceIndex() );

            auto oldValueRequired = uc.isValueResolutionRequired();
            uc.setValueResolutionRequired( false );

            // We are replacing lhs with a different terms and re-unifying,
            // so update the complexity accordingly. The structure of the tdecl
            // shouldn't count, only its pattern.
            uc.subComplexity( GetComplexity( lhs ) );
            uc.addComplexity( GetComplexity( callPat ) );

            for( auto&& [s, uc] : Unify( callPat, rhs, uc ) )
            {
                // Restore the namespace
                auto localC = uc;
                localC.setRHSNamespaceIndex( savedRHSNamespaceIndex );
                localC.setValueResolutionRequired( oldValueRequired );

                // We need to unify the result with a hole named after the decl. However, since both sides of
                // this unification orignally appeared on the LHS, we need to setup RHS to alias the LHS namespace for this.
                localC.setRHSNamespaceIndex( localC.LHSNamespaceIndex() );

                for( auto&& [s,c] : Unify( cFuncTerm, tdeclHole, localC ) )
                {
                    auto localC = c;
                    localC.setRHSNamespaceIndex( savedRHSNamespaceIndex );
                    co_yield { s, localC };
                }
            }

            uc.setRHSNamespaceIndex( savedRHSNamespaceIndex );
            uc.setValueResolutionRequired( oldValueRequired );
        } );
    }
}