Goose  Artifact [661dc45a61]

Artifact 661dc45a619e504f72cb0916c816adc1dc98ce25f15dd60c50b3c9ed21680255:

  • File bs/g0api/extensibility/eir.cpp — part of check-in [967d3ba3d7] at 2021-09-16 19:00:02 on branch trunk — More work on the g0 EIR api (user: achavasse size: 12934)

#include "g0api/g0api.h"
#include "eir/eir.h"
#include "parse/parse.h"
#include "builtins/helpers.h"

using namespace goose;
using namespace goose::parse;
using namespace goose::g0api;

namespace
{
    // Adds to pOvlSet a builtin overload that takes a T and returns it stored in
    // a Term. For TypeWrapper types, the wrapped value (obtained with .get()) is
    // stored rather than the wrapper itself.
    template< typename T >
    void RegisterMkTermOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        RegisterBuiltinFunc< TermWrapper ( T ) >( e, pOvlSet,
            []( const T& value ) -> TermWrapper
            {
                if constexpr( IsTypeWrapper< T >::value )
                    return TERM( value.get() );
                else
                    return TERM( value );
            } );
    }

    // Adds to pOvlSet a builtin overload (TermWrapper, TermRef< T >) -> bool.
    // If the term currently holds the looked-up alternative, its value is written
    // through tref and true is returned; otherwise tref is left untouched and
    // false is returned. For TypeWrapper types the alternative looked up in the
    // term is T::type instead of T.
    template< typename T >
    void RegisterGetTermValueOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        RegisterBuiltinFunc< bool ( TermWrapper, TermRef< T > ) >( e, pOvlSet,
            []( const TermWrapper& term, TermRef< T >& tref )
            {
                const auto& stored = term.get();

                // Pick which alternative to probe at compile time.
                const auto* pFound = [&stored]()
                {
                    if constexpr( IsTypeWrapper< T >::value )
                        return get_if< typename T::type >( &stored );
                    else
                        return get_if< T >( &stored );
                }();

                if( pFound )
                {
                    tref = *pFound;
                    return ToValue( true );
                }

                return ToValue( false );
            } );
    }
}

namespace goose::g0api
{
    // Registers the g0 EIR extensibility API into the environment e.
    //
    // The API consists of:
    //  - constants describing delimiter values and term types;
    //  - builtin functions to build and inspect EIR terms (Term, LocationId,
    //    StringId, Hole, AnyTerm, VecOfLength and Vector).
    void SetupEIRExtensibilityFuncs( Env& e )
    {
        // Constants.
        // Delimiter values are exposed as uint8_t constants. These are the values
        // accepted by MkDelimiterTerm and produced by GetDelimiterTermValue below.
        DefineConstant( e, "DelimiterOpenParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenParen ) ) ) );
        DefineConstant( e, "DelimiterOpenBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBrace ) ) ) );
        DefineConstant( e, "DelimiterOpenBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBracket ) ) ) );
        DefineConstant( e, "DelimiterCloseParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseParen ) ) ) );
        DefineConstant( e, "DelimiterCloseBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBrace ) ) ) );
        DefineConstant( e, "DelimiterCloseBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBracket ) ) ) );

        // This enum must match the order of the types in the Term variant.
        enum class TermType
        {
            UInt32,
            LocationId,
            String,
            StringId,
            Delimiter,
            Hole,
            AnyTerm,
            VecOfLength,
            Vec,
            BigInt,
            FixedInt,
            Internal
        };

        DefineConstant( e, "TermTypeUInt32"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::UInt32 ) ) ) );
        DefineConstant( e, "TermTypeLocationId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::LocationId ) ) ) );
        DefineConstant( e, "TermTypeString"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::String ) ) ) );
        DefineConstant( e, "TermTypeStringId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::StringId ) ) ) );
        DefineConstant( e, "TermTypeDelimiter"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Delimiter ) ) ) );
        DefineConstant( e, "TermTypeHole"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Hole ) ) ) );
        DefineConstant( e, "TermTypeAnyTerm"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::AnyTerm ) ) ) );
        DefineConstant( e, "TermTypeVecOfLength"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::VecOfLength ) ) ) );
        DefineConstant( e, "TermTypeVec"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Vec ) ) ) );
        DefineConstant( e, "TermTypeBigInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::BigInt ) ) ) );
        DefineConstant( e, "TermTypeFixedInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::FixedInt ) ) ) );
        DefineConstant( e, "TermTypeInternal"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Internal ) ) ) );

        // Functions
        // GetTermType: returns the index of the term's alternative in the Term
        // variant. Every alternative at or past Internal is reported as Internal.
        RegisterBuiltinFunc< BigInt ( TermWrapper ) >( e, "GetTermType"_sid,
            []( const TermWrapper& t )
            {
                // variant::index() is a size_t. Clamp in that type and narrow
                // afterwards, rather than implicitly truncating the index to
                // uint8_t before the comparison.
                const auto index = min< size_t >( t.get().index(), static_cast< size_t >( TermType::Internal ) );
                return BigInt::FromU32( static_cast< uint32_t >( index ) );
            } );

        ////////////////////////////
        // MkTerm overloads
        ////////////////////////////
        auto MkTerm = CreateOverloadSet( e, "MkTerm"_sid );

        RegisterMkTermOverload< uint32_t >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< LocationId > >( e, MkTerm );
        RegisterMkTermOverload< string >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< StringId > >( e, MkTerm );

        // Delimiters are passed around as uint8_t (see the Delimiter* constants),
        // so they get a dedicated function instead of a MkTerm overload.
        RegisterBuiltinFunc< TermWrapper ( uint8_t ) >( e, "MkDelimiterTerm"_sid,
            []( uint8_t d ) -> TermWrapper
            {
                return TERM( static_cast< Delimiter >( d ) );
            } );

        RegisterMkTermOverload< TypeWrapper< Hole > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< AnyTerm > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< VecOfLength > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< pvec > >( e, MkTerm );
        RegisterMkTermOverload< BigInt >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< APSInt > >( e, MkTerm );

        ////////////////////////////
        // GetTermValue overloads
        ////////////////////////////
        auto GetTermValue = CreateOverloadSet( e, "GetTermValue"_sid );
        RegisterGetTermValueOverload< uint32_t >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< LocationId > >( e, GetTermValue );
        RegisterGetTermValueOverload< string >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< StringId > >( e, GetTermValue );

        // GetDelimiterTermValue: if the term holds a Delimiter, writes it as a
        // uint8_t through tref and returns true; otherwise returns false.
        RegisterBuiltinFunc< bool ( TermWrapper, TermRef< uint8_t > ) >( e, "GetDelimiterTermValue"_sid,
            []( const TermWrapper& t, TermRef< uint8_t >& tref )
            {
                const auto* pVal = get_if< Delimiter >( &t.get() );
                if( !pVal )
                    return ToValue( false );

                tref = static_cast< uint8_t >( *pVal );
                return ToValue( true );
            } );

        RegisterGetTermValueOverload< TypeWrapper< Hole > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< AnyTerm > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< VecOfLength > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< pvec > >( e, GetTermValue );
        RegisterGetTermValueOverload< BigInt >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< APSInt > >( e, GetTermValue );

        ////////////////////////////
        // LocationId
        ////////////////////////////
        // MkSpanningLocation: combines two locations via
        // Location::CreateSpanningLocation.
        RegisterBuiltinFunc< TypeWrapper< LocationId > ( TypeWrapper< LocationId >, TypeWrapper< LocationId > ) >( e,
            "MkSpanningLocation"_sid,
            []( const TypeWrapper< LocationId >& loc1, const TypeWrapper< LocationId >& loc2 ) -> TypeWrapper< LocationId >
            {
                return static_cast< LocationId >( Location::CreateSpanningLocation(
                    static_cast< uint32_t >( loc1.get() ),
                    static_cast< uint32_t >( loc2.get() ) ) );
            } );

        ////////////////////////////
        // StringId
        ////////////////////////////
        // MkStringId: converts a string into a StringId.
        RegisterBuiltinFunc< TypeWrapper< StringId > ( string ) >( e, "MkStringId"_sid,
            []( string s ) -> TypeWrapper< StringId >
            {
                return s;
            } );

        // ToString: returns the text of a StringId.
        RegisterBuiltinFunc< string ( TypeWrapper< StringId > ) >( e, "ToString"_sid,
            []( const TypeWrapper< StringId >& sid )
            {
                return sid.get().str();
            } );

        ////////////////////////////
        // Hole
        ////////////////////////////
        // MkHole( name ): builds a Hole from a name only.
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< Hole >
            {
                return Hole( name );
            } );

        // MkHole( name, kind ): builds a Hole from a name and a kind term.
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId >, TermWrapper ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name, const TermWrapper& kind ) -> TypeWrapper< Hole >
            {
                return Hole( name, kind.get() );
            } );

        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< Hole > ) >( e, "GetHoleName"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< StringId >
            {
                return h.get().name();
            } );

        RegisterBuiltinFunc< TermWrapper ( TypeWrapper< Hole > ) >( e, "GetHoleKind"_sid,
            []( const TypeWrapper< Hole >& h ) -> TermWrapper
            {
                return h.get().kind();
            } );

        ////////////////////////////
        // AnyTerm
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< AnyTerm > ) >( e, "GetAnyTermVarName"_sid,
            []( const TypeWrapper< AnyTerm >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< AnyTerm > ( TypeWrapper< StringId > ) >( e, "MkAnyTerm"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< AnyTerm >
            {
                return AnyTerm( name );
            } );

        ////////////////////////////
        // VecOfLength
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< VecOfLength > ) >( e, "GetVecOfLengthVarName"_sid,
            []( const TypeWrapper< VecOfLength >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< VecOfLength > ( TypeWrapper< StringId > ) >( e, "MkVecOfLength"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< VecOfLength >
            {
                return VecOfLength( name );
            } );

        ////////////////////////////
        // Vector
        ////////////////////////////
        // MkVec: creates a new, default-constructed Vector.
        RegisterBuiltinFunc< TypeWrapper< pvec > () >( e, "MkVec"_sid,
            []() -> TypeWrapper< pvec >
            {
                return make_shared< Vector >();
            } );

        // GetVecLength: returns ( minimum length, whether the length is variable ).
        RegisterBuiltinFunc< tuple< uint32_t, bool > ( TypeWrapper< pvec > ) >( e, "GetVecLength"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                auto vl = vec.get()->length();
                return make_tuple( static_cast< uint32_t >( vl.minLength() ), vl.isVariable() );
            } );

        // GetVecTerm: if index is below the vector's minimum length, writes the
        // term at that index through out and returns true; otherwise returns false.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, uint32_t, TermRef< TermWrapper > ) >( e, "GetVecTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, TermRef< TermWrapper >& out )
            {
                const auto& v = *vec.get();

                if( v.length().minLength() <= index )
                    return false;

                out = v.terms()[index];
                return true;
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->weight();
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->weightOverride();
            } );

        // GetVecRepetitionTerm: if the vector has a repetition term, writes it
        // through out and returns true; otherwise returns false. The result is a
        // success flag, so it is declared bool, like GetVecTerm.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, TermRef< TermWrapper > ) >( e, "GetVecRepetitionTerm"_sid,
            []( const TypeWrapper< pvec >& vec, TermRef< TermWrapper >& out )
            {
                const auto& rt = vec.get()->repetitionTerm();

                if( !rt )
                    return false;

                out = *rt;
                return true;
            } );

        RegisterBuiltinFunc< bool ( TypeWrapper< pvec > ) >( e, "IsVecEmpty"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->empty();
            } );
    }
}