#include "g0api/g0api.h"
#include "eir/eir.h"
#include "parse/parse.h"
using namespace goose;
using namespace goose::parse;
using namespace goose::g0api;
namespace
{
template< typename T >
void RegisterMkTermOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
{
if constexpr( IsTypeWrapper< T >::value )
{
RegisterBuiltinFunc< TypeWrapper< Term > ( T ) >( e, pOvlSet,
[]( const T& v ) -> TypeWrapper< Term >
{
return TERM( v.get() );
} );
}
else
{
RegisterBuiltinFunc< TypeWrapper< Term > ( T ) >( e, pOvlSet,
[]( const T& v ) -> TypeWrapper< Term >
{
return TERM( v );
} );
}
}
template< typename T >
void RegisterGetTermValueOverload( Env& e, const ptr< OverloadSet >& pOvlSet )
{
if constexpr( IsTypeWrapper< T >::value )
{
RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
[]( const TypeWrapper< Term >& t, TermRef< T >& tref )
{
const auto* pVal = get_if< typename T::type >( &t.get() );
if( !pVal )
return false;
tref = *pVal;
return true;
} );
}
else
{
RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< T > ) >( e, pOvlSet,
[]( const TypeWrapper< Term >& t, TermRef< T >& tref )
{
const auto* pVal = get_if< T >( &t.get() );
if( !pVal )
return false;
tref = *pVal;
return true;
} );
}
}
}
namespace goose::g0api
{
    // Registers in environment e the g0 API that lets goose code inspect and
    // build EIR terms. This covers:
    // - numeric constants for delimiters, term kinds and hole behaviors;
    // - builtin functions operating on Term, LocationId, StringId, Hole,
    //   AnyTerm, VecOfLength, vectors (pvec), integers and type predicates.
    void SetupEIRExtensibilityFuncs( Env& e )
    {
        ////////////////////////////
        // Constants
        ////////////////////////////

        // Delimiter kinds, exposed as uint8_t: the representation taken by
        // MkDelimiterTerm and produced by GetDelimiterTermValue below.
        DefineConstant( e, "DelimiterOpenParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenParen ) ) ) );
        DefineConstant( e, "DelimiterOpenBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBrace ) ) ) );
        DefineConstant( e, "DelimiterOpenBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::OpenBracket ) ) ) );
        DefineConstant( e, "DelimiterCloseParen"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseParen ) ) ) );
        DefineConstant( e, "DelimiterCloseBrace"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBrace ) ) ) );
        DefineConstant( e, "DelimiterCloseBracket"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Delimiter::CloseBracket ) ) ) );

        // This enum must match the order of the types in the Term variant:
        // GetTermType returns the variant index, clamped to Internal, so every
        // alternative past FixedInt is reported as TermTypeInternal.
        enum class TermType
        {
            UInt64,
            LocationId,
            String,
            StringId,
            Delimiter,
            Hole,
            AnyTerm,
            VecOfLength,
            Vec,
            BigInt,
            FixedInt,
            Internal
        };

        DefineConstant( e, "TermTypeUInt64"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::UInt64 ) ) ) );
        DefineConstant( e, "TermTypeLocationId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::LocationId ) ) ) );
        DefineConstant( e, "TermTypeString"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::String ) ) ) );
        DefineConstant( e, "TermTypeStringId"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::StringId ) ) ) );
        DefineConstant( e, "TermTypeDelimiter"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Delimiter ) ) ) );
        DefineConstant( e, "TermTypeHole"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Hole ) ) ) );
        DefineConstant( e, "TermTypeAnyTerm"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::AnyTerm ) ) ) );
        DefineConstant( e, "TermTypeVecOfLength"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::VecOfLength ) ) ) );
        DefineConstant( e, "TermTypeVec"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Vec ) ) ) );
        DefineConstant( e, "TermTypeBigInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::BigInt ) ) ) );
        DefineConstant( e, "TermTypeFixedInt"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::FixedInt ) ) ) );
        DefineConstant( e, "TermTypeInternal"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( TermType::Internal ) ) ) );

        // Hole behaviors, exposed as uint8_t. MkHole rejects any value greater
        // than HoleBhvAny.
        DefineConstant( e, "HoleBhvStandard"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Hole::Behavior::Standard ) ) ) );
        DefineConstant( e, "HoleBhvPack"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Hole::Behavior::Pack ) ) ) );
        DefineConstant( e, "HoleBhvAny"_sid, ValueToEIR( ToValue( static_cast< uint8_t >( Hole::Behavior::Any ) ) ) );

        ////////////////////////////
        // Functions
        ////////////////////////////

        // Returns the kind of a term as one of the TermType* constants above.
        RegisterBuiltinFunc< BigInt ( TypeWrapper< Term > ) >( e, "GetTermType"_sid,
            []( const auto& t )
            {
                // Clamp before narrowing so that the variant index can never
                // wrap around when converted to a smaller integer type.
                const auto index = min< size_t >( t.get().index(), static_cast< size_t >( TermType::Internal ) );
                return BigInt::FromU32( static_cast< uint32_t >( index ) );
            } );

        ////////////////////////////
        // MkTerm overloads
        ////////////////////////////
        auto MkTerm = CreateOverloadSet( e, "MkTerm"_sid );

        RegisterMkTermOverload< uint64_t >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< LocationId > >( e, MkTerm );
        RegisterMkTermOverload< string >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< StringId > >( e, MkTerm );

        // Delimiters travel as uint8_t (see the Delimiter* constants), so they
        // get a dedicated function instead of a MkTerm overload.
        RegisterBuiltinFunc< TypeWrapper< Term > ( uint8_t ) >( e, "MkDelimiterTerm"_sid,
            []( uint8_t d ) -> TypeWrapper< Term >
            {
                return TERM( static_cast< Delimiter >( d ) );
            } );

        RegisterMkTermOverload< TypeWrapper< Hole > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< AnyTerm > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< VecOfLength > >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< pvec > >( e, MkTerm );
        RegisterMkTermOverload< BigInt >( e, MkTerm );
        RegisterMkTermOverload< TypeWrapper< APSInt > >( e, MkTerm );

        ////////////////////////////
        // GetTermValue overloads
        ////////////////////////////
        auto GetTermValue = CreateOverloadSet( e, "GetTermValue"_sid );

        RegisterGetTermValueOverload< uint64_t >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< LocationId > >( e, GetTermValue );
        RegisterGetTermValueOverload< string >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< StringId > >( e, GetTermValue );

        // Counterpart of MkDelimiterTerm: on success, writes the delimiter as
        // a uint8_t into tref and returns true; returns false if the term is
        // not a delimiter.
        RegisterBuiltinFunc< bool ( TypeWrapper< Term >, TermRef< uint8_t > ) >( e, "GetDelimiterTermValue"_sid,
            []( const auto& t, auto& tref )
            {
                const auto* pVal = get_if< Delimiter >( &t.get() );
                if( !pVal )
                    return false;

                tref = static_cast< uint8_t >( *pVal );
                return true;
            } );

        RegisterGetTermValueOverload< TypeWrapper< Hole > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< AnyTerm > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< VecOfLength > >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< pvec > >( e, GetTermValue );
        RegisterGetTermValueOverload< BigInt >( e, GetTermValue );
        RegisterGetTermValueOverload< TypeWrapper< APSInt > >( e, GetTermValue );

        ////////////////////////////
        // LocationId
        ////////////////////////////

        // Builds a location spanning from loc1 to loc2.
        RegisterBuiltinFunc< TypeWrapper< LocationId > ( TypeWrapper< LocationId >, TypeWrapper< LocationId > ) >( e,
            "MkSpanningLocation"_sid,
            []( const auto& loc1, const auto& loc2 ) -> TypeWrapper< LocationId >
            {
                return static_cast< LocationId >( Location::CreateSpanningLocation( loc1.get(), loc2.get() ) );
            } );

        ////////////////////////////
        // StringId
        ////////////////////////////

        // Converts a string into a StringId.
        RegisterBuiltinFunc< TypeWrapper< StringId > ( string ) >( e, "MkStringId"_sid,
            []( string s ) -> TypeWrapper< StringId >
            {
                return s;
            } );

        ////////////////////////////
        // Hole
        ////////////////////////////

        // MkHole( name, kind, behavior ): the behavior must be one of the
        // HoleBhv* constants. Otherwise an error is emitted at the behavior
        // argument's location and a poison value is returned.
        RegisterBuiltinFunc< Intrinsic< TypeWrapper< Hole > ( TypeWrapper< StringId >, TypeWrapper< StringId >, uint8_t ) > >( e, "MkHole"_sid,
            []( const auto& c, const Value& nameVal, const Value& kindVal, const Value& bhvVal ) -> Value
            {
                auto name = *FromValue< TypeWrapper< StringId > >( nameVal );
                auto kind = *FromValue< TypeWrapper< StringId > >( kindVal );
                auto bhv = *FromValue< uint8_t >( bhvVal );

                if( bhv > static_cast< uint8_t >( Hole::Behavior::Any ) )
                {
                    DiagnosticsManager::GetInstance().emitErrorMessage( bhvVal.locationId(),
                        "invalid hole behavior." );
                    return PoisonValue();
                }

                return ToValue( TypeWrapper< Hole >( Hole( name, kind, static_cast< Hole::Behavior >( bhv ) ) ) );
            } );

        // MkHole( name, kind ): uses Hole's default behavior.
        RegisterBuiltinFunc< TypeWrapper< Hole > ( TypeWrapper< StringId >, TypeWrapper< StringId > ) >( e, "MkHole"_sid,
            []( const auto& name, const auto& kind ) -> TypeWrapper< Hole >
            {
                return Hole( name, kind );
            } );

        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< Hole > ) >( e, "GetHoleName"_sid,
            []( const auto& h ) -> TypeWrapper< StringId >
            {
                return h.get().name();
            } );

        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Hole > ) >( e, "GetHoleKind"_sid,
            []( const auto& h ) -> TypeWrapper< Term >
            {
                return h.get().kind();
            } );

        // Returns the hole's behavior as one of the HoleBhv* constants.
        RegisterBuiltinFunc< uint8_t ( TypeWrapper< Hole > ) >( e, "GetHoleBehavior"_sid,
            []( const auto& h )
            {
                return static_cast< uint8_t >( h.get().behavior() );
            } );

        ////////////////////////////
        // AnyTerm
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< AnyTerm > ) >( e, "GetAnyTermVarName"_sid,
            []( const auto& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< AnyTerm > ( TypeWrapper< StringId > ) >( e, "MkAnyTerm"_sid,
            []( const auto& name ) -> TypeWrapper< AnyTerm >
            {
                return AnyTerm( name );
            } );

        ////////////////////////////
        // VecOfLength
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< StringId > ( TypeWrapper< VecOfLength > ) >( e, "GetVecOfLengthVarName"_sid,
            []( const auto& at ) -> TypeWrapper< StringId >
            {
                return at.get().varName();
            } );

        RegisterBuiltinFunc< TypeWrapper< VecOfLength > ( TypeWrapper< StringId > ) >( e, "MkVecOfLength"_sid,
            []( const auto& name ) -> TypeWrapper< VecOfLength >
            {
                return VecOfLength( name );
            } );

        ////////////////////////////
        // Vector
        ////////////////////////////

        // Creates a new, default-constructed vector.
        RegisterBuiltinFunc< TypeWrapper< pvec > () >( e, "MkVec"_sid,
            []() -> TypeWrapper< pvec >
            {
                return make_shared< Vector >();
            } );

        // Returns a new vector built by Vector::MakeConcat from vec1 and vec2.
        // The registered return type must be the vector itself: declaring it
        // void would silently discard the concatenation result.
        RegisterBuiltinFunc< TypeWrapper< pvec > ( TypeWrapper< pvec >, TypeWrapper< pvec > ) >( e, "MkVecConcat"_sid,
            []( const auto& vec1, const auto& vec2 ) -> TypeWrapper< pvec >
            {
                return make_shared< Vector >( Vector::MakeConcat( *vec1.get(), *vec2.get() ) );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, uint32_t ) >( e, "VecReserve"_sid,
            []( const auto& vec, uint32_t len )
            {
                vec->reserve( len );
            } );

        // Overwrites the term at the given index.
        // NOTE(review): unlike GetVecTerm, the index is not bounds-checked
        // here; confirm that callers always pass an existing index.
        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, uint32_t, TypeWrapper< Term > ) >( e, "SetTerm"_sid,
            []( const auto& vec, uint32_t index, const auto& t )
            {
                vec->terms()[index] = t.get();
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeight"_sid,
            []( const auto& vec, int32_t w )
            {
                vec->setWeight( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, int32_t ) >( e, "SetVecWeightOverride"_sid,
            []( const auto& vec, int32_t w )
            {
                vec->setWeightOverride( w );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecAppend"_sid,
            []( const auto& vec, const auto& t )
            {
                vec->append( t.get() );
            } );

        RegisterBuiltinFunc< void ( TypeWrapper< pvec >, TypeWrapper< Term > ) >( e, "VecSetRepetition"_sid,
            []( const auto& vec, const auto& t )
            {
                vec->setRepetitionTerm( t.get() );
            } );

        // Returns ( minimum length, whether the length is variable ).
        RegisterBuiltinFunc< tuple< uint32_t, bool > ( TypeWrapper< pvec > ) >( e, "GetVecLength"_sid,
            []( const auto& vec )
            {
                auto vl = vec->length();
                return make_tuple( static_cast< uint32_t >( vl.minLength() ), vl.isVariable() );
            } );

        // Writes the term at index into out and returns true, or returns false
        // if index is not below the vector's minimum length.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, uint32_t, TermRef< TypeWrapper< Term > > ) >( e, "GetVecTerm"_sid,
            []( const auto& vec, uint32_t index, auto& out )
            {
                const auto& v = *vec.get();
                if( v.length().minLength() <= index )
                    return false;

                out = v.terms()[index];
                return true;
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeight"_sid,
            []( const auto& vec )
            {
                return vec->weight();
            } );

        RegisterBuiltinFunc< int32_t ( TypeWrapper< pvec > ) >( e, "GetVecWeightOverride"_sid,
            []( const auto& vec )
            {
                return vec->weightOverride();
            } );

        // Writes the repetition term into out and returns true, or returns
        // false if the vector has none. This is a found/not-found predicate,
        // so it is registered as returning bool, like GetVecTerm.
        RegisterBuiltinFunc< bool ( TypeWrapper< pvec >, TermRef< TypeWrapper< Term > > ) >( e, "GetVecRepetitionTerm"_sid,
            []( const auto& vec, auto& out )
            {
                const auto& rt = vec->repetitionTerm();
                if( !rt )
                    return false;

                out = *rt;
                return true;
            } );

        RegisterBuiltinFunc< bool ( TypeWrapper< pvec > ) >( e, "IsVecEmpty"_sid,
            []( const auto& vec )
            {
                return vec->empty();
            } );

        ////////////////////////////
        // Helpers
        ////////////////////////////
        RegisterBuiltinFunc< TypeWrapper< Term > ( TypeWrapper< Term >, TypeWrapper< Term > ) >( e, "AppendToVectorTerm"_sid,
            []( const auto& vec, const auto& t ) -> TypeWrapper< Term >
            {
                return AppendToVectorTerm( vec, t );
            } );

        ////////////////////////////
        // Integers
        ////////////////////////////

        // ToFixedInt( value, signed ): converts a BigInt to a fixed-size
        // integer, signed if the bool is true and unsigned otherwise. Requesting
        // an unsigned conversion of a negative value emits an error and yields
        // a poison value.
        RegisterBuiltinFunc< Intrinsic< TypeWrapper< APSInt > ( BigInt, bool ) > >( e, "ToFixedInt"_sid,
            []( auto&& c, const Value& biVal, const Value& sVal ) -> Value
            {
                auto bi = *FromValue< BigInt >( biVal );
                auto s = *FromValue< bool >( sVal );

                if( !s && bi.isNegative() )
                {
                    DiagnosticsManager::GetInstance().emitErrorMessage( biVal.locationId(),
                        "this is negative and can't be converted to an unsigned int." );
                    return PoisonValue();
                }

                auto ai = bi.getAPSInt();
                if( !s )
                    ai.setIsSigned( false );

                // Wrap the converted integer itself, not the signedness flag.
                return ToValue( TypeWrapper< APSInt >( ai ) );
            } );

        auto ToBigInt = CreateOverloadSet( e, "ToBigInt"_sid );

        RegisterBuiltinFunc< Eager< BigInt > ( TypeWrapper< APSInt > ) >( e, ToBigInt,
            []( const auto& fInt )
            {
                return fInt;
            } );

        RegisterBuiltinFunc< Eager< BigInt > ( char32_t ) >( e, ToBigInt,
            []( char32_t c )
            {
                return BigInt::FromU32( c );
            } );

        ////////////////////////////
        // Propositions
        ////////////////////////////

        // Parses the type's predicates in context c. On success, writes them
        // into out and returns true. Returns false if parsing fails or the type
        // has no predicates.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< Context > >, TypeWrapper< Value >, TermRef< TypeWrapper< ptr< Propositions > > > ) >( e, "GetTypePredicates"_sid,
            []( const auto& c, const auto& type, auto& out ) -> bool
            {
                if( !ParseTypePredicates( *c.get(), type.get() ) )
                    return false;

                // Query with the unwrapped Value, as for ParseTypePredicates.
                auto ppPreds = GetTypePredicates( type.get() );
                if( !ppPreds || !( *ppPreds ) )
                    return false;

                out = *ppPreds;
                return true;
            } );

        RegisterBuiltinFunc< uint32_t ( TypeWrapper< ptr< Propositions > > ) >( e, "GetPropositionsCount"_sid,
            []( const auto& preds )
            {
                return static_cast< uint32_t >( preds->props().size() );
            } );

        // Writes the proposition at index into out and returns true, or returns
        // false if index is out of range.
        RegisterBuiltinFunc< bool ( TypeWrapper< ptr< Propositions > >, uint32_t, TermRef< TypeWrapper< Value > > ) >( e, "GetProposition"_sid,
            []( const auto& preds, uint32_t index, auto& out )
            {
                if( preds->props().size() <= index )
                    return false;

                out = preds->props()[index];
                return true;
            } );
    }
}