Goose  Artifact [80b40af42f]

Artifact 80b40af42f042ba03e56d83c12c828c6d7299214319f485918c79b7d6fe0f900:

  • File bs/g0api/extensibility/eir.cpp — part of check-in [1fe5a1ac2b] at 2021-10-31 18:52:16 on branch trunk —
    • Reverted most of the horribly complicated changes done since [b3ff9af3c2fe4925], other than the ability to hash all terms, values, instructions and predicates
    • Solved the problem for which tests were added in [b3ff9af3c2fe4925] in a much simpler way
    (user: zlodo size: 16237)

#include "g0api/g0api.h"
#include "eir/eir.h"
#include "parse/parse.h"
#include "builtins/helpers.h"

using namespace goose;
using namespace goose::parse;
using namespace goose::g0api;

namespace
{
    // Adds to the overload set pOvlSet a builtin that takes a T and returns
    // the Term built from it with TERM(). When T is a TypeWrapper, the Term
    // is built from the wrapped value (v.get()) rather than from the wrapper.
    template< typename T >
    void RegisterMkTermOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        RegisterBuiltinFunc< TypeWrapper< Term > ( T ) >( e, pOvlSet,
            []( const T& v ) -> TypeWrapper< Term >
            {
                // Resolved at compile time: only one of the two returns is
                // instantiated for a given T.
                if constexpr( IsTypeWrapper< T >::value )
                    return TERM( v.get() );
                else
                    return TERM( v );
            } );
    }

    template< typename T >
    void RegisterGetTermValueOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        if constexpr( IsTypeWrapper< T >::value )
        {
            RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
                []( const TypeWrapper< Term >& t, TermRef< T >& tref )
                {
                    const auto* pVal = get_if< typename T::type >( &t.get() );
                    if( !pVal )
                        return ToValue( false );

                    tref = *pVal;
                    return ToValue( true );
                } );
        }
        else
        {
            RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
                []( const TypeWrapper< Term >& t, TermRef< T >& tref )
                {
                    const auto* pVal = get_if< T >( &t.get() );
                    if( !pVal )
                        return ToValue( false );

                    tref = *pVal;
                    return ToValue( true );
                } );
        }
    }
}

namespace goose::g0api
{
    // Registers into the environment e the constants and builtin functions
    // that expose EIR terms:
    //   - the Delimiter* and TermType* constants (defined as uint8_t values),
    //   - GetTermType, the MkTerm and GetTermValue overload sets and their
    //     delimiter-specific counterparts,
    //   - constructors and accessors for LocationId, StringId, Hole, AnyTerm,
    //     VecOfLength and Vector,
    //   - the AppendToVectorTerm helper,
    //   - accessors for type predicates.
    void SetupEIRExtensibilityFuncs( Env& e )
    {
        // Constants.
        DefineConstant( e, "DelimiterOpenParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenParen ) ) ) );
        DefineConstant( e, "DelimiterOpenBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBrace ) ) ) );
        DefineConstant( e, "DelimiterOpenBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBracket ) ) ) );
        DefineConstant( e, "DelimiterCloseParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseParen ) ) ) );
        DefineConstant( e, "DelimiterCloseBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBrace ) ) ) );
        DefineConstant( e, "DelimiterCloseBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBracket ) ) ) );

        // This enum must match the order of the types in the Term variant.
        enum class TermType
        {
            UInt64,
            LocationId,
            String,
            StringId,
            Delimiter,
            Hole,
            AnyTerm,
            VecOfLength,
            Vec,
            BigInt,
            FixedInt,
            Internal
        };

        DefineConstant( e, "TermTypeUInt64"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::UInt64 ) ) ) );
        DefineConstant( e, "TermTypeLocationId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::LocationId ) ) ) );
        DefineConstant( e, "TermTypeString"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::String ) ) ) );
        DefineConstant( e, "TermTypeStringId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::StringId ) ) ) );
        DefineConstant( e, "TermTypeDelimiter"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Delimiter ) ) ) );
        DefineConstant( e, "TermTypeHole"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Hole ) ) ) );
        DefineConstant( e, "TermTypeAnyTerm"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::AnyTerm ) ) ) );
        DefineConstant( e, "TermTypeVecOfLength"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::VecOfLength ) ) ) );
        DefineConstant( e, "TermTypeVec"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Vec ) ) ) );
        DefineConstant( e, "TermTypeBigInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::BigInt ) ) ) );
        DefineConstant( e, "TermTypeFixedInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::FixedInt ) ) ) );
        DefineConstant( e, "TermTypeInternal"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Internal ) ) ) );

        // Functions

        // Yields the index of the alternative held by the term, clamped so
        // that any index past TermType::Internal is reported as Internal.
        RegisterBuiltinFunc< BigInt ( TypeWrapper< Term > ) >( e, "GetTermType"_sid,
            []( const TypeWrapper< Term >& t )
            {
                return BigInt::FromU32( min< uint8_t >( t.get().index(), static_cast< uint8_t >( TermType::Internal ) ) );
            } );

        ////////////////////////////
        // MkTerm overloads
        ////////////////////////////
        auto MkTerm = CreateOverloadSet( e, "MkTerm"_sid );

        RegisterMkTermOverload< uint64_t >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< LocationId > >( e, MkTerm );
        RegisterMkTermOverload< string >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< StringId > >( e, MkTerm );

        // Delimiters are exchanged as plain uint8_t, so they get a dedicated
        // builtin name instead of a MkTerm overload.
        // NOTE(review): d is cast to Delimiter without any range validation
        // - confirm callers only pass the Delimiter* constants.
        RegisterBuiltinFunc< TypeWrapper< Term > ( uint8_t ) >( e, "MkDelimiterTerm"_sid,
            []( uint8_t d ) -> TypeWrapper< Term >
            {
                return TERM( static_cast< Delimiter >( d ) );
            } );

        RegisterMkTermOverload< TypeWrapper< Hole > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< AnyTerm > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< VecOfLength > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< pvec > >( e, MkTerm );
        RegisterMkTermOverload< BigInt >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< APSInt > >( e, MkTerm );

        ////////////////////////////
        // GetTermValue overloads
        ////////////////////////////
        auto GetTermValue = CreateOverloadSet( e, "GetTermValue"_sid );
        RegisterGetTermValueOverload< uint64_t >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< LocationId > >( e, GetTermValue );
        RegisterGetTermValueOverload< string >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< StringId > >( e, GetTermValue );

        // Delimiter counterpart of GetTermValue: writes the delimiter as a
        // uint8_t through tref when the term holds a Delimiter.
        RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< uint8_t > ) >( e, "GetDelimiterTermValue"_sid,
            []( const TypeWrapper< Term >& t, TermRef< uint8_t >& tref )
            {
                const auto* pVal = get_if< Delimiter >( &t.get() );
                if( !pVal )
                    return ToValue( false );

                tref = static_cast< uint8_t >( *pVal );
                return ToValue( true );
            } );

        RegisterGetTermValueOverload< TypeWrapper< Hole > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< AnyTerm > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< VecOfLength > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< pvec > >( e, GetTermValue );
        RegisterGetTermValueOverload< BigInt >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< APSInt > >( e, GetTermValue );

        ////////////////////////////
        // LocationId
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< LocationId > ( TypeWrapper< LocationId >, TypeWrapper< LocationId > ) >( e,
            "MkSpanningLocation"_sid,
            []( const TypeWrapper< LocationId >& loc1, const TypeWrapper< LocationId >& loc2 ) -> TypeWrapper< LocationId >
            {
                return static_cast< LocationId >( Location::CreateSpanningLocation( loc1.get(), loc2.get() ) );
            } );

        ////////////////////////////
        // StringId
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( string ) >( e, "MkStringId"_sid,
            []( string s ) -> TypeWrapper< StringId >
            {
                return s;
            } );

        ////////////////////////////
        // Hole
        ////////////////////////////

        // MkHole is registered twice under the same name: once with only a
        // name, once with a name and a kind term.
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< Hole >
            {
                return Hole( name );
            } );

        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId >, TypeWrapper< Term > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name, const TypeWrapper< Term >& kind ) -> TypeWrapper< Hole >
            {
                return Hole( name, kind.get() );
            } );

        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< Hole > ) >( e, "GetHoleName"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< StringId >
            {
                return h.get().name();
            } );

        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Hole > ) >( e, "GetHoleKind"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< Term >
            {
                return h.get().kind();
            } );

        ////////////////////////////
        // AnyTerm
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< AnyTerm > ) >( e, "GetAnyTermVarName"_sid,
            []( const TypeWrapper< AnyTerm >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< AnyTerm > ( TypeWrapper< StringId > ) >( e, "MkAnyTerm"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< AnyTerm >
            {
                return AnyTerm( name );
            } );

        ////////////////////////////
        // VecOfLength
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< VecOfLength > ) >( e, "GetVecOfLengthVarName"_sid,
            []( const TypeWrapper< VecOfLength >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< VecOfLength > ( TypeWrapper< StringId > ) >( e, "MkVecOfLength"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< VecOfLength >
            {
                return VecOfLength( name );
            } );

        ////////////////////////////
        // Vector
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< pvec > () >( e, "MkVec"_sid,
            []() -> TypeWrapper< pvec >
            {
                return make_shared< Vector >();
            } );

        // Builds a new Vector from Vector::MakeConcat( vec1, vec2 ) and
        // returns it; the registered signature therefore has to return a
        // pvec for the result to reach the caller.
        RegisterBuiltinFunc< TypeWrapper< pvec > ( TypeWrapper< pvec >, TypeWrapper< pvec > ) >( e, "MkVecConcat"_sid,
            []( const TypeWrapper< pvec >& vec1, const TypeWrapper< pvec >& vec2 ) -> TypeWrapper< pvec >
            {
                return make_shared< Vector >( Vector::MakeConcat( *vec1.get(), *vec2.get() ) );
            } );

        // NOTE(review): index is not range-checked before the write -
        // confirm callers guarantee it designates an existing term.
        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, uint32_t, TypeWrapper< Term > ) >( e, "SetTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, const TypeWrapper< Term >& t )
            {
                vec->terms()[index] = t.get();
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec, int32_t w )
            {
                vec->setWeight( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec, int32_t w )
            {
                vec->setWeightOverride( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecAppend"_sid,
            []( const TypeWrapper< pvec >& vec, const TypeWrapper< Term >& t )
            {
                vec->append( t.get() );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecSetRepetition"_sid,
            []( const TypeWrapper< pvec >& vec, const TypeWrapper< Term >& t )
            {
                vec->setRepetitionTerm( t.get() );
            } );

        // Yields the pair ( minimum length, whether the length is variable ).
        RegisterBuiltinFunc< tuple< uint32_t, bool > ( TypeWrapper< pvec > ) >( e, "GetVecLength"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                auto vl = vec->length();
                return make_tuple( static_cast< uint32_t >( vl.minLength() ), vl.isVariable() );
            } );

        // Yields false, leaving out untouched, when index is not below the
        // vector's minimum length; otherwise stores the term and yields true.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, uint32_t, TermRef< TypeWrapper< Term > > ) >( e, "GetVecTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, TermRef< TypeWrapper< Term > >& out )
            {
                const auto& v = *vec.get();

                if( v.length().minLength() <= index )
                    return false;

                out = v.terms()[index];
                return true;
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->weight();
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->weightOverride();
            } );

        // Success flag plus output reference, like GetVecTerm: yields false
        // when the vector has no repetition term, otherwise stores it in out
        // and yields true.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, TermRef< TypeWrapper< Term > > ) >( e, "GetVecRepetitionTerm"_sid,
            []( const TypeWrapper< pvec >& vec, TermRef< TypeWrapper< Term > >& out )
            {
                const auto& rt = vec->repetitionTerm();

                if( !rt )
                    return false;

                out = *rt;
                return true;
            } );

        RegisterBuiltinFunc< bool ( TypeWrapper< pvec > ) >( e, "IsVecEmpty"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->empty();
            } );

        ////////////////////////////
        // Helpers
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Term >, TypeWrapper< Term > ) >( e, "AppendToVectorTerm"_sid,
            []( const TypeWrapper< Term >& vec, TypeWrapper< Term > t ) -> TypeWrapper< Term >
            {
                return AppendToVectorTerm( vec, t );
            } );

        ////////////////////////////
        // Predicates
        ////////////////////////////

        // Yields false if ParseTypePredicates fails for the type, or if no
        // non-null predicates object is obtained for it; otherwise stores the
        // predicates in out and yields true.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< Context > >, TypeWrapper< Value >, TermRef< TypeWrapper< ptr< TypePredicates > > > ) >( e, "GetTypePredicates"_sid,
            []( const TypeWrapper< ptr< Context > >& c, const TypeWrapper< Value >& type, TermRef< TypeWrapper< ptr< TypePredicates > > >& out )
            {
                if( !ParseTypePredicates( *c.get(), type.get() ) )
                    return false;

                auto ppPreds = GetTypePredicates( type );

                if( !ppPreds || !( *ppPreds ) )
                    return false;

                out = *ppPreds;
                return true;
            } );

        RegisterBuiltinFunc< uint32_t ( TypeWrapper< ptr< TypePredicates > > ) >( e, "GetPredicatesCount"_sid,
            []( const TypeWrapper< ptr< TypePredicates > >& preds )
            {
                // Explicit conversion so the lambda's return type matches the
                // registered uint32_t signature.
                return static_cast< uint32_t >( preds->m_predicates.size() );
            } );

        // Yields false, leaving out untouched, when index is out of range.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< TypePredicates > >, uint32_t, TermRef< TypeWrapper< Value > > ) >( e, "GetPredicate"_sid,
            []( const TypeWrapper< ptr< TypePredicates > >& preds, uint32_t index, TermRef< TypeWrapper< Value > >& out )
            {
                if( preds->m_predicates.size() <= index )
                    return false;

                out = preds->m_predicates[index];
                return true;
            } );
    }
}