#include "g0api/g0api.h"
#include "eir/eir.h"
#include "parse/parse.h"
#include "builtins/helpers.h"
using namespace goose;
using namespace goose::parse;
using namespace goose::g0api;
namespace
{
    // Adds to pOvlSet a builtin taking a T and returning a term built from it.
    // When T is a TypeWrapper<X>, the wrapped X (obtained through get()) is what
    // goes into the term; any other T is placed into the term directly.
    template< typename T >
    void RegisterMkTermOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        RegisterBuiltinFunc< TermWrapper ( T ) >( e, pOvlSet,
            []( const T& v ) -> TermWrapper
            {
                // Resolved at compile time: only the branch matching T is instantiated.
                if constexpr( IsTypeWrapper< T >::value )
                {
                    return TERM( v.get() );
                }
                else
                {
                    return TERM( v );
                }
            } );
    }

    // Adds to pOvlSet a builtin (TermWrapper, TermRef<T>) -> bool that tries to
    // extract a T from the term. If the term holds the matching alternative, that
    // value is assigned to tref and true is returned; otherwise tref is left
    // untouched and false is returned. For a TypeWrapper<X>, the alternative
    // looked up in the term is X (T::type) rather than the wrapper itself.
    template< typename T >
    void RegisterGetTermValueOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
    {
        RegisterBuiltinFunc< bool ( TermWrapper, TermRef< T > ) >( e, pOvlSet,
            []( const TermWrapper& t, TermRef< T >& tref )
            {
                if constexpr( IsTypeWrapper< T >::value )
                {
                    if( const auto* pStored = get_if< typename T::type >( &t.get() ) )
                    {
                        tref = *pStored;
                        return ToValue( true );
                    }
                }
                else
                {
                    if( const auto* pStored = get_if< T >( &t.get() ) )
                    {
                        tref = *pStored;
                        return ToValue( true );
                    }
                }

                // The term does not hold the requested alternative.
                return ToValue( false );
            } );
    }
}
namespace goose::g0api
{
    // Registers the EIR extensibility API into the environment e:
    // - uint8_t constants identifying delimiters and term variant alternatives,
    // - GetTermType, to classify a term,
    // - the MkTerm / GetTermValue overload sets (plus delimiter-specific variants),
    // - helpers to build and inspect LocationId, StringId, Hole, AnyTerm,
    //   VecOfLength and Vector values.
    void SetupEIRExtensibilityFuncs( Env& e )
    {
        // Constants.
        DefineConstant( e, "DelimiterOpenParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenParen ) ) ) );
        DefineConstant( e, "DelimiterOpenBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBrace ) ) ) );
        DefineConstant( e, "DelimiterOpenBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBracket ) ) ) );
        DefineConstant( e, "DelimiterCloseParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseParen ) ) ) );
        DefineConstant( e, "DelimiterCloseBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBrace ) ) ) );
        DefineConstant( e, "DelimiterCloseBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBracket ) ) ) );

        // This enum must match the order of the types in the Term variant:
        // GetTermType below returns the raw variant index, so each enumerator's
        // value is compared directly against term.index(). Internal is the last
        // entry; every alternative at or beyond its position is reported as
        // Internal.
        enum class TermType
        {
            UInt32,
            LocationId,
            String,
            StringId,
            Delimiter,
            Hole,
            AnyTerm,
            VecOfLength,
            Vec,
            BigInt,
            FixedInt,
            Internal
        };

        DefineConstant( e, "TermTypeUInt32"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::UInt32 ) ) ) );
        DefineConstant( e, "TermTypeLocationId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::LocationId ) ) ) );
        DefineConstant( e, "TermTypeString"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::String ) ) ) );
        DefineConstant( e, "TermTypeStringId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::StringId ) ) ) );
        DefineConstant( e, "TermTypeDelimiter"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Delimiter ) ) ) );
        DefineConstant( e, "TermTypeHole"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Hole ) ) ) );
        DefineConstant( e, "TermTypeAnyTerm"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::AnyTerm ) ) ) );
        DefineConstant( e, "TermTypeVecOfLength"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::VecOfLength ) ) ) );
        DefineConstant( e, "TermTypeVec"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Vec ) ) ) );
        DefineConstant( e, "TermTypeBigInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::BigInt ) ) ) );
        DefineConstant( e, "TermTypeFixedInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::FixedInt ) ) ) );
        DefineConstant( e, "TermTypeInternal"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Internal ) ) ) );

        // Functions

        // GetTermType( term ) -> the TermType value of the term's alternative.
        // The clamp happens in the index's own (size_t) type before narrowing, so
        // no alternative index can wrap around into a wrong TermType value.
        RegisterBuiltinFunc< BigInt ( TermWrapper ) >( e, "GetTermType"_sid,
            []( const TermWrapper& t )
            {
                const auto index = t.get().index();
                const auto lastKnown = static_cast< decltype( index ) >( TermType::Internal );
                return BigInt::FromU32( static_cast< uint32_t >( min( index, lastKnown ) ) );
            } );

        ////////////////////////////
        // MkTerm overloads
        ////////////////////////////
        auto MkTerm = CreateOverloadSet( e, "MkTerm"_sid );

        RegisterMkTermOverload< uint32_t >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< LocationId > >( e, MkTerm );
        RegisterMkTermOverload< string >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< StringId > >( e, MkTerm );

        // Delimiters travel as uint8_t on the language side, so they get a
        // dedicated builtin instead of a MkTerm overload.
        RegisterBuiltinFunc< TermWrapper ( uint8_t ) >( e, "MkDelimiterTerm"_sid,
            []( uint8_t d ) -> TermWrapper
            {
                return TERM( static_cast< Delimiter >( d ) );
            } );

        RegisterMkTermOverload< TypeWrapper< Hole > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< AnyTerm > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< VecOfLength > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< pvec > >( e, MkTerm );
        RegisterMkTermOverload< BigInt >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< APSInt > >( e, MkTerm );

        ////////////////////////////
        // GetTermValue overloads
        ////////////////////////////
        auto GetTermValue = CreateOverloadSet( e, "GetTermValue"_sid );

        RegisterGetTermValueOverload< uint32_t >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< LocationId > >( e, GetTermValue );
        RegisterGetTermValueOverload< string >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< StringId > >( e, GetTermValue );

        // Delimiter counterpart of MkDelimiterTerm: on success, stores the
        // delimiter as a uint8_t into tref and returns true; otherwise returns
        // false without touching tref.
        RegisterBuiltinFunc< bool ( TermWrapper, TermRef< uint8_t > ) >( e, "GetDelimiterTermValue"_sid,
            []( const TermWrapper& t, TermRef< uint8_t >& tref )
            {
                const auto* pVal = get_if< Delimiter >( &t.get() );
                if( !pVal )
                    return ToValue( false );

                tref = static_cast< uint8_t >( *pVal );
                return ToValue( true );
            } );

        RegisterGetTermValueOverload< TypeWrapper< Hole > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< AnyTerm > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< VecOfLength > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< pvec > >( e, GetTermValue );
        RegisterGetTermValueOverload< BigInt >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< APSInt > >( e, GetTermValue );

        ////////////////////////////
        // LocationId
        ////////////////////////////
        // MkSpanningLocation( loc1, loc2 ) -> location id produced by
        // Location::CreateSpanningLocation from the two input ids.
        RegisterBuiltinFunc< TypeWrapper< LocationId > ( TypeWrapper< LocationId >, TypeWrapper< LocationId > ) >( e,
            "MkSpanningLocation"_sid,
            []( const TypeWrapper< LocationId >& loc1, const TypeWrapper< LocationId >& loc2 ) -> TypeWrapper< LocationId >
            {
                return static_cast< LocationId >( Location::CreateSpanningLocation(
                    static_cast< uint32_t >( loc1.get() ),
                    static_cast< uint32_t >( loc2.get() ) ) );
            } );

        ////////////////////////////
        // StringId
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( string ) >( e, "MkStringId"_sid,
            []( string s ) -> TypeWrapper< StringId >
            {
                return s;
            } );

        RegisterBuiltinFunc< string ( TypeWrapper< StringId > ) >( e, "ToString"_sid,
            []( const TypeWrapper< StringId >& sid )
            {
                return sid.get().str();
            } );

        ////////////////////////////
        // Hole
        ////////////////////////////
        // MkHole has two overloads: name only, or name plus a kind term.
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< Hole >
            {
                return Hole( name );
            } );

        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId >, TermWrapper ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name, const TermWrapper& kind ) -> TypeWrapper< Hole >
            {
                return Hole( name, kind.get() );
            } );

        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< Hole > ) >( e, "GetHoleName"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< StringId >
            {
                return h.get().name();
            } );

        RegisterBuiltinFunc< TermWrapper ( TypeWrapper< Hole > ) >( e, "GetHoleKind"_sid,
            []( const TypeWrapper< Hole >& h ) -> TermWrapper
            {
                return h.get().kind();
            } );

        ////////////////////////////
        // AnyTerm
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< AnyTerm > ) >( e, "GetAnyTermVarName"_sid,
            []( const TypeWrapper< AnyTerm >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< AnyTerm > ( TypeWrapper< StringId > ) >( e, "MkAnyTerm"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< AnyTerm >
            {
                return AnyTerm( name );
            } );

        ////////////////////////////
        // VecOfLength
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< VecOfLength > ) >( e, "GetVecOfLengthVarName"_sid,
            []( const TypeWrapper< VecOfLength >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< VecOfLength > ( TypeWrapper< StringId > ) >( e, "MkVecOfLength"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< VecOfLength >
            {
                return VecOfLength( name );
            } );

        ////////////////////////////
        // Vector
        ////////////////////////////
        // MkVec() -> a freshly allocated, default-constructed Vector.
        RegisterBuiltinFunc< TypeWrapper< pvec > () >( e, "MkVec"_sid,
            []() -> TypeWrapper< pvec >
            {
                return make_shared< Vector >();
            } );

        // GetVecLength( vec ) -> ( minLength, isVariable ) taken from vec's length().
        RegisterBuiltinFunc< tuple< uint32_t, bool > ( TypeWrapper< pvec > ) >( e, "GetVecLength"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                auto vl = vec.get()->length();
                return make_tuple( static_cast< uint32_t >( vl.minLength() ), vl.isVariable() );
            } );

        // GetVecTerm( vec, index, out ): returns false when index is not below the
        // vector's minimum length; otherwise copies terms()[index] into out and
        // returns true.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, uint32_t, TermRef< TermWrapper > ) >( e, "GetVecTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, TermRef< TermWrapper >& out )
            {
                const auto& v = *vec.get();
                if( v.length().minLength() <= index )
                    return false;

                out = v.terms()[index];
                return true;
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->weight();
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->weightOverride();
            } );

        // GetVecRepetitionTerm( vec, out ): returns false if vec has no repetition
        // term; otherwise copies it into out and returns true. The lambda yields a
        // success flag, so the builtin is declared as returning bool (it was
        // previously, incorrectly, declared as returning int32_t).
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, TermRef< TermWrapper > ) >( e, "GetVecRepetitionTerm"_sid,
            []( const TypeWrapper< pvec >& vec, TermRef< TermWrapper >& out )
            {
                const auto& rt = vec.get()->repetitionTerm();
                if( !rt )
                    return false;

                out = *rt;
                return true;
            } );

        RegisterBuiltinFunc< bool ( TypeWrapper< pvec > ) >( e, "IsVecEmpty"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec.get()->empty();
            } );
    }
}