Goose  match.cpp at [176ee856a6]

File bs/eir/match.cpp artifact 7464fbb2da part of check-in 176ee856a6


#include "eir.h"

namespace goose::eir
{
    // Returns how many variables this solution has bound. A solution whose variable map was
    // never allocated (m_pVars is null) has none.
    size_t MatchSolution::numVars() const
    {
        return m_pVars ? m_pVars->size() : 0;
    }

    void MatchSolution::setupVars()
    {
        if( !m_pVars )
            m_pVars = make_shared< unordered_map< StringId, any > >();
        else if( m_pVars.use_count() > 1 )
            m_pVars = make_shared< unordered_map< StringId, any > >( *m_pVars );
    }

    // Matches `expression` against every pattern stored in `patterns`, lazily yielding one
    // MatchSolution per successful match. Each result from the underlying Match<> is a pair-like
    // value; only its first element (the solution) is forwarded, the second is discarded.
    Generator< MatchSolution > Match( const Term& expression, const Trie<>& patterns )
    {
        for( auto&& [solution, unusedPayload] : Match<>( expression, patterns ) )
            co_yield move( solution );
    }

    // Forward declaration of the general term matcher (defined further below). The per-kind
    // matchers that follow (MatchVec, MatchHole) recurse into it for sub-terms, so it must be
    // declared before them.
    optional< MatchSolution > Match( MatchSolution&& currentSolution, const Term& expression, const Term& pattern );

    //--------------------------------------------------------------------------------------------
    // AnyTerm
    //--------------------------------------------------------------------------------------------
    // An AnyTerm pattern matches any expression by binding it to the pattern's variable.
    // Fails (nullopt) when setVar rejects the binding; otherwise returns the updated solution.
    optional< MatchSolution > MatchAnyTerm( MatchSolution&& currentSolution, const Term& expression, const AnyTerm& pattern )
    {
        if( !currentSolution.setVar( pattern.varName(), expression ) )
            return nullopt;

        return move( currentSolution );
    }

    //--------------------------------------------------------------------------------------------
    // VecOfLength
    //--------------------------------------------------------------------------------------------
    // A VecOfLength pattern matches any fixed-length vector expression and binds that length
    // to the pattern's variable. Non-vector expressions, variable-length vectors and rejected
    // bindings all fail with nullopt. A successful match adds 1 to the solution's complexity.
    optional< MatchSolution > MatchVecOfLength( MatchSolution&& currentSolution, const Term& expression, const VecOfLength& pattern )
    {
        const auto* pExprVec = get_if< pvec >( &expression );
        if( !pExprVec )
            return nullopt;

        const auto& exprVec = **pExprVec;
        const auto len = exprVec.length();

        // Variable-length expression vectors are intentionally unsupported: not needed, and
        // excluding them keeps matching simple.
        if( len.isVariable() )
            return nullopt;

        if( !currentSolution.setVar( pattern.varName(), len.minLength() ) )
            return nullopt;

        currentSolution.addComplexity( 1 );
        return move( currentSolution );
    }

    //--------------------------------------------------------------------------------------------
    // Vec
    //--------------------------------------------------------------------------------------------
    // Matches a vector expression against a vector pattern.
    // The pattern's fixed terms are matched one-to-one against the leading expression terms.
    // Any remaining expression terms must each match the pattern's repetition term, which must
    // then exist. Fails (nullopt) for non-vector or variable-length expressions, on a length
    // mismatch, or when any element fails to match. On success the complexity grows by 1, on
    // top of whatever the element matches contributed.
    optional< MatchSolution > MatchVec( MatchSolution&& currentSolution, const Term& expression, const pvec& pattern )
    {
        const auto* pExprVec = get_if< pvec >( &expression );
        if( !pExprVec )
            return nullopt;

        const auto& exprVec = **pExprVec;
        const auto exprLen = exprVec.length();

        // Variable-length expression vectors are intentionally unsupported: not needed, and
        // excluding them keeps matching simple.
        if( exprLen.isVariable() )
            return nullopt;

        const auto patLen = pattern->length();

        // Reject expressions shorter than the pattern's fixed part, and expressions longer than
        // a pattern that has no repetition term to absorb the extra elements.
        if( exprLen.minLength() < patLen.minLength()
         || ( !patLen.isVariable() && patLen.minLength() < exprLen.minLength() ) )
            return nullopt;

        auto termGen = exprVec.forEachTerm();
        auto exprIt = termGen.begin();

        auto sol = move( currentSolution );

        // Fixed-position pattern terms.
        for( auto&& patTerm : pattern->terms() )
        {
            assert( exprIt != termGen.end() );

            auto next = Match( move( sol ), *exprIt, patTerm );
            if( !next )
                return nullopt;

            sol = move( *next );
            ++exprIt;
        }

        assert( exprIt == termGen.end() || pattern->repetitionTerm() );

        // Trailing expression terms, each matched against the repetition term.
        for( ; exprIt != termGen.end(); ++exprIt )
        {
            auto next = Match( move( sol ), *exprIt, *pattern->repetitionTerm() );
            if( !next )
                return nullopt;

            sol = move( *next );
        }

        sol.addComplexity( 1 );
        return move( sol );
    }

    //--------------------------------------------------------------------------------------------
    // Value
    //--------------------------------------------------------------------------------------------
    // Generic matcher for plain value patterns: the expression must hold the same alternative
    // type T and compare equal to the pattern. A successful match adds 2 to the complexity.
    template< typename T >
    optional< MatchSolution > MatchValue( MatchSolution&& currentSolution, const Term& expression, const T& pattern )
    {
        // T is deduced from `const T&`, so it is already free of cv/ref qualifiers.
        if( const auto* pExprVal = get_if< T >( &expression ); pExprVal && *pExprVal == pattern )
        {
            currentSolution.addComplexity( 2 );
            return move( currentSolution );
        }

        return nullopt;
    }

    //--------------------------------------------------------------------------------------------
    // Hole
    //--------------------------------------------------------------------------------------------
    // Matches a hole expression against a hole pattern.
    // The two flavors are matched recursively first. Then an empty name on either side acts as
    // a wildcard (complexity +1); otherwise the names must be identical (complexity +2).
    // Anything else fails with nullopt.
    optional< MatchSolution > MatchHole( MatchSolution&& currentSolution, const Term& expression, const Hole& pattern )
    {
        const auto* pExprHole = get_if< Hole >( &expression );
        if( !pExprHole )
            return nullopt;

        auto flavorSol = Match( move( currentSolution ), pExprHole->flavor(), pattern.flavor() );
        if( !flavorSol )
            return nullopt;

        // A hole with an empty name means "match any hole".
        const bool wildcard = pattern.name() == ""_sid || pExprHole->name() == ""_sid;

        if( !wildcard && !( pExprHole->name() == pattern.name() ) )
            return nullopt;

        flavorSol->addComplexity( wildcard ? 1 : 2 );
        return move( *flavorSol );
    }

    //--------------------------------------------------------------------------------------------
    // LocationId
    //--------------------------------------------------------------------------------------------
    // A LocationId pattern accepts any expression and leaves the solution (bindings and
    // complexity) unchanged.
    //
    // These three overloads are deliberately plain functions, not templates. An earlier form,
    // `template< typename T >` with T absent from the parameter list, could never have T
    // deduced. It was therefore never viable, and every call fell through to the generic
    // MatchValue< T >, which demands equality and adds complexity. As non-templates they take
    // precedence over the generic template whenever the argument binds equally well.
    optional< MatchSolution > MatchValue( MatchSolution&& currentSolution, const Term& /*expression*/, const LocationId& /*pattern*/ )
    {
        return move( currentSolution );
    }

    //--------------------------------------------------------------------------------------------
    // ptr< void >
    //--------------------------------------------------------------------------------------------
    // A ptr< void > pattern accepts any expression and leaves the solution unchanged.
    optional< MatchSolution > MatchValue( MatchSolution&& currentSolution, const Term& /*expression*/, const ptr< void >& /*pattern*/ )
    {
        return move( currentSolution );
    }

    //--------------------------------------------------------------------------------------------
    // void*
    //--------------------------------------------------------------------------------------------
    // A void* pattern accepts any expression and leaves the solution unchanged.
    // Takes `void*` (not `const void*`) so that binding the pattern's `void*` alternative is an
    // identity conversion. A `const void*` parameter would need a qualification conversion and
    // would lose overload resolution to the generic template.
    // NOTE(review): assumes the Term alternative is `void*`, as the section heading says.
    optional< MatchSolution > MatchValue( MatchSolution&& currentSolution, const Term& /*expression*/, void* /*pattern*/ )
    {
        return move( currentSolution );
    }

    //--------------------------------------------------------------------------------------------
    // Terms
    //--------------------------------------------------------------------------------------------
    // General single-term matcher. Dispatches on the pattern's active alternative to the
    // dedicated matcher for that kind:
    // - pvec, Hole, VecOfLength and AnyTerm each go to their own matcher;
    // - every other alternative goes to the MatchValue overload set.
    optional< MatchSolution > Match( MatchSolution&& currentSolution, const Term& expression, const Term& pattern )
    {
        auto dispatch = [&]< typename P >( const P& pat ) -> optional< MatchSolution >
        {
            if constexpr( is_same_v< P, pvec > )
                return MatchVec( move( currentSolution ), expression, pat );
            else if constexpr( is_same_v< P, Hole > )
                return MatchHole( move( currentSolution ), expression, pat );
            else if constexpr( is_same_v< P, VecOfLength > )
                return MatchVecOfLength( move( currentSolution ), expression, pat );
            else if constexpr( is_same_v< P, AnyTerm > )
                return MatchAnyTerm( move( currentSolution ), expression, pat );
            else
                return MatchValue( move( currentSolution ), expression, pat );
        };

        return visit( dispatch, pattern );
    }

    // Matches `expression` against `pattern`, starting from a fresh, empty solution.
    optional< MatchSolution > Match( const Term& expression, const Term& pattern )
    {
        return Match( MatchSolution{}, expression, pattern );
    }
}