Goose  Artifact [3d7435a178]

Artifact 3d7435a178ed12e53888ca91ed65ed2f7b77e1b2ecb969fd407f9573987d6e64:

  • File bs/sema/uni-holes.cpp — part of check-in [af650a9e95] at 2019-09-22 14:37:55 on branch trunk — Project renaming. (user: achavasse size: 9489)

#include "sema.h"

namespace goose::sema
{
    // Registers the hole-related unification rules on `ruleSet`:
    //   - anonymous hole half-unification,
    //   - anonymous hole versus any term,
    //   - (named or indexed) hole half-unification,
    //   - hole versus any term,
    //   - hole versus hole.
    //
    // Every rule is a coroutine returning a UniGen that co_yields zero or more
    // { solution term, updated context } results; yielding nothing rejects the match.
    void SetupHoleUnificationRules( UnificationRuleSet& ruleSet )
    {
        // Anonymous hole half-unification: add 1 to the anon holes count,
        // yield the hole as is.
        ruleSet.addHalfUnificationRule(
            VEC( TSID( hole ), TSID( _ ) ),
        []( const Term& lhs, UnificationContext& c ) -> UniGen
        {
            c.addAnonymousHole();
            co_yield { lhs, c };
        } );

        // Anonymous hole versus anything: add 1 to the anon holes count,
        // yield the half unification of the rhs.
        ruleSet.addSymRule(
            VEC( TSID( hole ), TSID( _ ) ),
            ANYTERM( _ ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            c.addAnonymousHole();

            // The rhs is half-unified with the context flipped; each yielded context
            // is flipped back before being passed on.
            for( auto&& [s,uc] : HalfUnify( rhs, c.flip() ) )
                co_yield { s, uc.flip() };

            // Undo the flip applied to our own context for the HalfUnify call.
            c.flip();
        } );

        // Hole half-unification: convert it to a numbered hole.
        // If the name wasn't already known, create a new value for it.
        ruleSet.addHalfUnificationRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
        []( const Term& lhs, UnificationContext& c ) -> UniGen
        {
            auto lh = *HoleFromIRExpr( lhs );

            if( holds_alternative< StringId >( lh ) )
            {
                // This is a named hole: look up its name.
                const auto& name = get< StringId >( lh );

                auto holeIndex = c.getLHSHoleIndex( name );
                if( holeIndex != UnificationContext::InvalidIndex )
                {
                    // A locked hole is currently being unified: referring to it again
                    // would be recursive hole nesting, so yield no solution.
                    if( !c.isHoleLocked( holeIndex ) )
                        co_yield { MkHole( holeIndex ), c };
                }
                else
                {
                    // This is a new name: create a new value,
                    // and increment the number of unique holes in the current score.
                    auto index = c.createValue();
                    c.setLHSHoleIndex( name, index );
                    co_yield { MkHole( index ), c };
                }
            }
            else
            {
                // This is already an indexed hole: yield it as is (unless locked, see above).
                if( !c.isHoleLocked( get< uint32_t >( lh ) ) )
                    co_yield { lhs, c };
            }
        } );

        // Hole vs anything
        ruleSet.addSymRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
            ANYTERM( _ ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            auto h = *HoleFromIRExpr( lhs );
            uint32_t index = 0;

            // Remember the previous complexity count so we know how much complexity
            // is added by this particular sub-term. This is because we need
            // to be able to subtract it when updating the hole's value with a new solution.
            uint32_t oldComplexity = c.complexity();

            if( holds_alternative< uint32_t >( h ) )
                index = get< uint32_t >( h );
            else
            {
                // This is a named hole: look up its name.
                const auto& name = get< StringId >( h );

                index = c.getLHSHoleIndex( name );
                if( index == UnificationContext::InvalidIndex )
                {
                    // This is a new name: create a new value.
                    index = c.createValue();
                    c.setLHSHoleIndex( name, index );

                    for( auto&& [e,uc] : HalfUnify( rhs, c.flip() ) )
                    {
                        uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );

                        // The generator may produce several solutions: build a fresh hole
                        // term for each one instead of moving a single shared term (which
                        // would leave every solution after the first with a moved-from term).
                        co_yield { MkHole( index ), uc.flip() };
                    }

                    c.flip();
                    co_return;
                }
            }

            // Reject recursive hole nesting.
            if( c.isHoleLocked( index ) )
                co_return;

            // NOTE(review): the lock is only released on the yielded contexts (`uc`);
            // if no solution is produced, `c` keeps this hole locked. Confirm callers
            // never reuse `c` in that case.
            c.lockHole( index );

            auto& maybeVal = c.getValue( index );
            if( maybeVal )
            {
                // The hole already has a value: unify it with the rhs and store the result.
                for( auto&& [e,uc] : Unify( *maybeVal, rhs, c ) )
                {
                    uc.unlockHole( index );
                    uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );
                    co_yield { MkHole( index ), uc };
                }
            }
            else
            {
                // The hole has no value yet: its value becomes the half-unified rhs.
                for( auto&& [e,uc] : HalfUnify( rhs, c.flip() ) )
                {
                    uc.unlockHole( index );
                    uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );
                    co_yield { MkHole( index ), uc.flip() };
                }

                c.flip();
            }
        } );

        // Hole vs hole
        ruleSet.addAsymRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
            VEC( TSID( hole ), ANYTERM( _ ) ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            auto lh = *HoleFromIRExpr( lhs );
            auto rh = *HoleFromIRExpr( rhs );

            StringId lname;
            StringId rname;

            uint32_t lindex = 0;
            uint32_t rindex = 0;

            if( holds_alternative< StringId >( lh ) )
            {
                // L is a named hole: look up its name.
                lname = get< StringId >( lh );
                lindex = c.getLHSHoleIndex( lname );
            }
            else
                lindex = get< uint32_t >( lh );

            if( holds_alternative< StringId >( rh ) )
            {
                // R is a named hole: look up its name.
                rname = get< StringId >( rh );
                rindex = c.getRHSHoleIndex( rname );
            }
            else
                rindex = get< uint32_t >( rh );

            // If neither hole currently has a value index, create a new one shared by both.
            if( lindex == UnificationContext::InvalidIndex && rindex == UnificationContext::InvalidIndex )
            {
                auto index = c.createValue();
                c.setLHSHoleIndex( lname, index );
                c.setRHSHoleIndex( rname, index );

                co_yield { MkHole( index ), c };
                co_return;
            }

            // If both holes actually point to the same value, just yield it as the solution.
            if( lindex == rindex )
            {
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            // If either hole doesn't have a value index yet, assign it the other one's.
            if( lindex == UnificationContext::InvalidIndex )
            {
                c.setLHSHoleIndex( lname, rindex );
                co_yield { MkHole( rindex ), c };
                co_return;
            }

            if( rindex == UnificationContext::InvalidIndex )
            {
                c.setRHSHoleIndex( rname, lindex );
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            // Reject recursive hole nesting.
            if( c.isHoleLocked( lindex ) )
                co_return;
            if( c.isHoleLocked( rindex ) )
                co_return;

            c.lockHole( lindex );
            c.lockHole( rindex );

            // If either hole has an empty value, set it to a hole expression with the id of the value
            // stored in the other one. We can't just copy the value over as we would lose the dependency
            // relationship between the two holes.
            const auto& lval = c.getValue( lindex );
            const auto& rval = c.getValue( rindex );

            // Both values are empty: there is nothing to unify (and nothing to dereference),
            // so release the locks and make R's value refer to L's.
            if( !lval && !rval )
            {
                c.unlockHole( lindex );
                c.unlockHole( rindex );

                c.setValue( rindex, MkHole( lindex ) );
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            if( !rval )
            {
                for( auto&& [e,uc] : HalfUnify( *lval, c ) )
                {
                    uc.unlockHole( lindex );
                    uc.unlockHole( rindex );

                    uc.setValue( rindex, MkHole( lindex ) );
                    co_yield { MkHole( lindex ), uc };
                }
                co_return;
            }

            if( !lval )
            {
                for( auto&& [e,uc] : HalfUnify( *rval, c.flip() ) )
                {
                    uc.unlockHole( lindex );
                    uc.unlockHole( rindex );

                    uc.setValue( lindex, MkHole( rindex ) );
                    co_yield { MkHole( rindex ), uc.flip() };
                }
                c.flip();
                co_return;
            }

            // Both L and R have a value: unify them, store the result in lhs,
            // replace rhs with a hole expression pointing to lhs's value.

            // Remember the previous complexity count so we know how much complexity
            // is added by this particular sub-term. This is because we need
            // to be able to subtract it when updating the hole's value with a new solution.
            uint32_t oldComplexity = c.complexity();

            for( auto&& [e,uc] : Unify( *lval, *rval, c ) )
            {
                uc.unlockHole( lindex );
                uc.unlockHole( rindex );

                uc.setValue( lindex, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );
                uc.setValue( rindex, MkHole( lindex ) );
                co_yield { MkHole( lindex ), uc };
            }
        } );
    }
}