#include "g0api/g0api.h"
#include "eir/eir.h"
#include "parse/parse.h"
#include "builtins/helpers.h"
using namespace goose;
using namespace goose::parse;
using namespace goose::g0api;
namespace
{
template< typename T >
void RegisterMkTermOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
{
if constexpr( IsTypeWrapper< T >::value )
{
RegisterBuiltinFunc< TypeWrapper< Term > ( T ) >( e, pOvlSet,
[]( const T& v ) -> TypeWrapper< Term >
{
return TERM( v.get() );
} );
}
else
{
RegisterBuiltinFunc< TypeWrapper< Term > ( T ) >( e, pOvlSet,
[]( const T& v ) -> TypeWrapper< Term >
{
return TERM( v );
} );
}
}
template< typename T >
void RegisterGetTermValueOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
{
if constexpr( IsTypeWrapper< T >::value )
{
RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
[]( const TypeWrapper< Term >& t, TermRef< T >& tref )
{
const auto* pVal = get_if< typename T::type >( &t.get() );
if( !pVal )
return ToValue( false );
tref = *pVal;
return ToValue( true );
} );
}
else
{
RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
[]( const TypeWrapper< Term >& t, TermRef< T >& tref )
{
const auto* pVal = get_if< T >( &t.get() );
if( !pVal )
return ToValue( false );
tref = *pVal;
return ToValue( true );
} );
}
}
}
namespace goose::g0api
{
    // Registers into the environment `e` the constants and builtin functions of the
    // g0 API that build and inspect EIR terms: delimiter and term-type constants,
    // the MkTerm / GetTermValue overload sets, and accessors for LocationId,
    // StringId, Hole, AnyTerm, VecOfLength, vectors and type predicates.
    void SetupEIRExtensibilityFuncs( Env& e )
    {
        // Defines a named constant whose value is the given uint8_t.
        auto DefineU8Constant = [&]( const auto& name, uint8_t value )
        {
            DefineConstant( e, name, ValueToEIR( ToValue( value ) ) );
        };

        ////////////////////////////
        // Constants
        ////////////////////////////
        DefineU8Constant( "DelimiterOpenParen"_sid, static_cast< uint8_t >( Delimiter::OpenParen ) );
        DefineU8Constant( "DelimiterOpenBrace"_sid, static_cast< uint8_t >( Delimiter::OpenBrace ) );
        DefineU8Constant( "DelimiterOpenBracket"_sid, static_cast< uint8_t >( Delimiter::OpenBracket ) );
        DefineU8Constant( "DelimiterCloseParen"_sid, static_cast< uint8_t >( Delimiter::CloseParen ) );
        DefineU8Constant( "DelimiterCloseBrace"_sid, static_cast< uint8_t >( Delimiter::CloseBrace ) );
        DefineU8Constant( "DelimiterCloseBracket"_sid, static_cast< uint8_t >( Delimiter::CloseBracket ) );

        // This enum must match the order of the types in the Term variant.
        // Every variant alternative at or after `Internal` is reported as `Internal`
        // by GetTermType below.
        enum class TermType : uint8_t
        {
            UInt64,
            LocationId,
            String,
            StringId,
            Delimiter,
            Hole,
            AnyTerm,
            VecOfLength,
            Vec,
            BigInt,
            FixedInt,
            Internal
        };

        DefineU8Constant( "TermTypeUInt64"_sid, static_cast< uint8_t >( TermType::UInt64 ) );
        DefineU8Constant( "TermTypeLocationId"_sid, static_cast< uint8_t >( TermType::LocationId ) );
        DefineU8Constant( "TermTypeString"_sid, static_cast< uint8_t >( TermType::String ) );
        DefineU8Constant( "TermTypeStringId"_sid, static_cast< uint8_t >( TermType::StringId ) );
        DefineU8Constant( "TermTypeDelimiter"_sid, static_cast< uint8_t >( TermType::Delimiter ) );
        DefineU8Constant( "TermTypeHole"_sid, static_cast< uint8_t >( TermType::Hole ) );
        DefineU8Constant( "TermTypeAnyTerm"_sid, static_cast< uint8_t >( TermType::AnyTerm ) );
        DefineU8Constant( "TermTypeVecOfLength"_sid, static_cast< uint8_t >( TermType::VecOfLength ) );
        DefineU8Constant( "TermTypeVec"_sid, static_cast< uint8_t >( TermType::Vec ) );
        DefineU8Constant( "TermTypeBigInt"_sid, static_cast< uint8_t >( TermType::BigInt ) );
        DefineU8Constant( "TermTypeFixedInt"_sid, static_cast< uint8_t >( TermType::FixedInt ) );
        DefineU8Constant( "TermTypeInternal"_sid, static_cast< uint8_t >( TermType::Internal ) );

        ////////////////////////////
        // Functions
        ////////////////////////////
        RegisterBuiltinFunc< BigInt ( TypeWrapper< Term > ) >( e, "GetTermType"_sid,
            []( const TypeWrapper< Term >& t )
            {
                // Clamp while still in size_t: narrowing the variant index to uint8_t
                // first would wrap indices >= 256 instead of mapping them to Internal.
                const auto index = min< size_t >( t.get().index(), static_cast< size_t >( TermType::Internal ) );
                return BigInt::FromU32( static_cast< uint32_t >( index ) );
            } );

        ////////////////////////////
        // MkTerm overloads
        ////////////////////////////
        auto MkTerm = CreateOverloadSet( e, "MkTerm"_sid );
        RegisterMkTermOverload< uint64_t >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< LocationId > >( e, MkTerm );
        RegisterMkTermOverload< string >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< StringId > >( e, MkTerm );

        // Delimiters are exposed as raw uint8_t values, so they get a dedicated
        // builtin rather than an MkTerm overload.
        RegisterBuiltinFunc< TypeWrapper< Term > ( uint8_t ) >( e, "MkDelimiterTerm"_sid,
            []( uint8_t d ) -> TypeWrapper< Term >
            {
                return TERM( static_cast< Delimiter >( d ) );
            } );

        RegisterMkTermOverload< TypeWrapper< Hole > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< AnyTerm > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< VecOfLength > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< pvec > >( e, MkTerm );
        RegisterMkTermOverload< BigInt >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< APSInt > >( e, MkTerm );

        ////////////////////////////
        // GetTermValue overloads
        ////////////////////////////
        auto GetTermValue = CreateOverloadSet( e, "GetTermValue"_sid );
        RegisterGetTermValueOverload< uint64_t >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< LocationId > >( e, GetTermValue );
        RegisterGetTermValueOverload< string >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< StringId > >( e, GetTermValue );

        // Counterpart of MkDelimiterTerm: yields the delimiter as a uint8_t.
        RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< uint8_t > ) >( e, "GetDelimiterTermValue"_sid,
            []( const TypeWrapper< Term >& t, TermRef< uint8_t >& tref )
            {
                const auto* pVal = get_if< Delimiter >( &t.get() );
                if( !pVal )
                    return ToValue( false );
                tref = static_cast< uint8_t >( *pVal );
                return ToValue( true );
            } );

        RegisterGetTermValueOverload< TypeWrapper< Hole > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< AnyTerm > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< VecOfLength > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< pvec > >( e, GetTermValue );
        RegisterGetTermValueOverload< BigInt >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< APSInt > >( e, GetTermValue );

        ////////////////////////////
        // LocationId
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< LocationId > ( TypeWrapper< LocationId >, TypeWrapper< LocationId > ) >( e,
            "MkSpanningLocation"_sid,
            []( const TypeWrapper< LocationId >& loc1, const TypeWrapper< LocationId >& loc2 ) -> TypeWrapper< LocationId >
            {
                return static_cast< LocationId >( Location::CreateSpanningLocation( loc1.get(), loc2.get() ) );
            } );

        ////////////////////////////
        // StringId
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( string ) >( e, "MkStringId"_sid,
            []( string s ) -> TypeWrapper< StringId >
            {
                return s;
            } );

        ////////////////////////////
        // Hole
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< Hole >
            {
                return Hole( name );
            } );

        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId >, TypeWrapper< Term > ) >( e, "MkHole"_sid,
            []( const TypeWrapper< StringId >& name, const TypeWrapper< Term >& kind ) -> TypeWrapper< Hole >
            {
                return Hole( name, kind.get() );
            } );

        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< Hole > ) >( e, "GetHoleName"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< StringId >
            {
                return h.get().name();
            } );

        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Hole > ) >( e, "GetHoleKind"_sid,
            []( const TypeWrapper< Hole >& h ) -> TypeWrapper< Term >
            {
                return h.get().kind();
            } );

        ////////////////////////////
        // AnyTerm
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< AnyTerm > ) >( e, "GetAnyTermVarName"_sid,
            []( const TypeWrapper< AnyTerm >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< AnyTerm > ( TypeWrapper< StringId > ) >( e, "MkAnyTerm"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< AnyTerm >
            {
                return AnyTerm( name );
            } );

        ////////////////////////////
        // VecOfLength
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< VecOfLength > ) >( e, "GetVecOfLengthVarName"_sid,
            []( const TypeWrapper< VecOfLength >& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< VecOfLength > ( TypeWrapper< StringId > ) >( e, "MkVecOfLength"_sid,
            []( const TypeWrapper< StringId >& name ) -> TypeWrapper< VecOfLength >
            {
                return VecOfLength( name );
            } );

        ////////////////////////////
        // Vector
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< pvec > () >( e, "MkVec"_sid,
            []() -> TypeWrapper< pvec >
            {
                return make_shared< Vector >();
            } );

        // Returns a new vector holding the concatenation of vec1 and vec2.
        // (Previously registered with a void signature, which discarded the result.)
        RegisterBuiltinFunc< TypeWrapper< pvec > ( TypeWrapper< pvec >, TypeWrapper< pvec > ) >( e, "MkVecConcat"_sid,
            []( const TypeWrapper< pvec >& vec1, const TypeWrapper< pvec >& vec2 ) -> TypeWrapper< pvec >
            {
                return make_shared< Vector >( Vector::MakeConcat( *vec1.get(), *vec2.get() ) );
            } );

        // Overwrites the term at `index`. Indices come from g0 code, so an
        // out-of-range index is ignored instead of writing past the end of the
        // terms container (undefined behavior).
        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, uint32_t, TypeWrapper< Term > ) >( e, "SetTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, const TypeWrapper< Term >& t )
            {
                auto&& terms = vec->terms();
                if( index >= terms.size() )
                    return;
                terms[index] = t.get();
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec, int32_t w )
            {
                vec->setWeight( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec, int32_t w )
            {
                vec->setWeightOverride( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecAppend"_sid,
            []( const TypeWrapper< pvec >& vec, const TypeWrapper< Term >& t )
            {
                vec->append( t.get() );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecSetRepetition"_sid,
            []( const TypeWrapper< pvec >& vec, const TypeWrapper< Term >& t )
            {
                vec->setRepetitionTerm( t.get() );
            } );

        // Returns ( minimum length, whether the length is variable ).
        RegisterBuiltinFunc< tuple< uint32_t, bool > ( TypeWrapper< pvec > ) >( e, "GetVecLength"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                auto vl = vec->length();
                return make_tuple( static_cast< uint32_t >( vl.minLength() ), vl.isVariable() );
            } );

        // Writes the term at `index` to `out` and returns true, or returns false
        // if `index` is not below the vector's minimum length.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, uint32_t, TermRef< TypeWrapper< Term > > ) >( e, "GetVecTerm"_sid,
            []( const TypeWrapper< pvec >& vec, uint32_t index, TermRef< TypeWrapper< Term > >& out )
            {
                const auto& v = *vec.get();
                if( v.length().minLength() <= index )
                    return false;
                out = v.terms()[index];
                return true;
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeight"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->weight();
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeightOverride"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->weightOverride();
            } );

        // Writes the repetition term to `out` and returns true, or returns false
        // when the vector has none. (Previously registered as returning int32_t
        // even though the lambda yields a bool.)
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, TermRef< TypeWrapper< Term > > ) >( e, "GetVecRepetitionTerm"_sid,
            []( const TypeWrapper< pvec >& vec, TermRef< TypeWrapper< Term > >& out )
            {
                const auto& rt = vec->repetitionTerm();
                if( !rt )
                    return false;
                out = *rt;
                return true;
            } );

        RegisterBuiltinFunc< bool ( TypeWrapper< pvec > ) >( e, "IsVecEmpty"_sid,
            []( const TypeWrapper< pvec >& vec )
            {
                return vec->empty();
            } );

        ////////////////////////////
        // Helpers
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Term >, TypeWrapper< Term > ) >( e, "AppendToVectorTerm"_sid,
            []( const TypeWrapper< Term >& vec, const TypeWrapper< Term >& t ) -> TypeWrapper< Term >
            {
                return AppendToVectorTerm( vec, t );
            } );

        ////////////////////////////
        // Predicates
        ////////////////////////////
        // Parses the predicates of `type` in context `c`. On success, writes them
        // to `out` and returns true; returns false if parsing fails or no
        // predicates are attached to the type.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< Context > >, TypeWrapper< Value >, TermRef< TypeWrapper< ptr< TypePredicates > > > ) >( e, "GetTypePredicates"_sid,
            []( const TypeWrapper< ptr< Context > >& c, const TypeWrapper< Value >& type, TermRef< TypeWrapper< ptr< TypePredicates > > >& out )
            {
                if( !ParseTypePredicates( *c.get(), type.get() ) )
                    return false;

                auto ppPreds = GetTypePredicates( type );
                if( !ppPreds || !( *ppPreds ) )
                    return false;

                out = *ppPreds;
                return true;
            } );

        RegisterBuiltinFunc< uint32_t ( TypeWrapper< ptr< TypePredicates > > ) >( e, "GetPredicatesCount"_sid,
            []( const TypeWrapper< ptr< TypePredicates > >& preds )
            {
                return static_cast< uint32_t >( preds->m_predicates.size() );
            } );

        // Writes predicate number `index` to `out` and returns true, or returns
        // false if `index` is out of range.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< TypePredicates > >, uint32_t, TermRef< TypeWrapper< Value > > ) >( e, "GetPredicate"_sid,
            []( const TypeWrapper< ptr< TypePredicates > >& preds, uint32_t index, TermRef< TypeWrapper< Value > >& out )
            {
                if( preds->m_predicates.size() <= index )
                    return false;
                out = preds->m_predicates[index];
                return true;
            } );
    }
}