#ifndef GOOSE_SEMA_TC_CONTEXT_H
#define GOOSE_SEMA_TC_CONTEXT_H
namespace goose::sema
{
class TypeCheckingContext
{
public:
static constexpr uint32_t InvalidIndex = numeric_limits< uint32_t >::max();
TypeCheckingContext( const Context& c );
TypeCheckingContext( Context&& c );
const auto& context() const { return m_context; }
const auto& env() const { return m_context.env(); }
const auto& rules() const { return env()->typeCheckingRuleSet(); }
uint32_t getLHSHoleIndex( const StringId& name ) const;
uint32_t getRHSHoleIndex( const StringId& name ) const;
uint32_t createValue();
void setLHSHoleIndex( const StringId& name, uint32_t index );
void setRHSHoleIndex( const StringId& name, uint32_t index );
void eraseLHSName( const StringId& name );
void eraseRHSName( const StringId& name );
uint32_t LHSNamespaceIndex() const { return m_currentLHSNamespaceIndex; }
uint32_t RHSNamespaceIndex() const { return m_currentRHSNamespaceIndex; }
void setLHSNamespaceIndex( uint32_t index ) { m_currentLHSNamespaceIndex = index; }
void setRHSNamespaceIndex( uint32_t index ) { m_currentRHSNamespaceIndex = index; }
uint32_t newNamespaceIndex() { return m_nextNamespaceIndex++; }
// By default, any encountered hole will be considered as required, ie
// they will count towards numUnknownValues() if we can't solve them.
// This function allows to temporarily disable this, so that any hole
// encountered from that point on will not count towards unresolved holes,
// unless they also appear in a section where holes are required.
void setValueResolutionRequired( bool required )
{
m_valuesAreRequired = required;
}
bool isValueResolutionRequired() const
{
return m_valuesAreRequired;
}
const optional< Term >& getValue( uint32_t index ) const
{
assert( m_pCow->values.size() > index );
return m_pCow->values[index].m_term;
}
template< typename T >
void setValue( uint32_t index, T&& val )
{
assert( m_pCow->values.size() > index );
if( m_pCow->values[index].m_required && !m_pCow->values[index].m_term )
--m_numUnknownValues;
if( m_pCow->values[index].m_term )
m_complexity -= GetComplexity( *m_pCow->values[index].m_term );
CoW( m_pCow )->values[index] = { forward< T >( val ), true };
m_complexity += GetComplexity( *m_pCow->values[index].m_term );
}
TypeCheckingContext& flip()
{
swap( m_currentLHSNamespaceIndex, m_currentRHSNamespaceIndex );
return *this;
}
uint32_t numUnknownValues() const { return m_numUnknownValues; }
uint32_t complexity() const { return m_complexity; }
void addComplexity( uint32_t c ) { m_complexity +=c; }
void subComplexity( uint32_t c ) { m_complexity -=c; }
void setComplexity( uint32_t complexity ) { m_complexity = complexity; }
void addAnonymousHole() { ++m_numAnonymousHoles; }
auto score() const { return TypeCheckingScore( m_complexity, m_pCow->holeDict.size() + m_numAnonymousHoles ); }
// Used to detect and reject recursive hole nesting.
bool isHoleLocked( uint32_t index ) const;
void lockHole( uint32_t index );
void unlockHole( uint32_t index );
private:
void setValueRequired( uint32_t index );
Context m_context;
struct StoredValue
{
optional< Term > m_term;
bool m_required = false;
};
uint32_t m_currentLHSNamespaceIndex = 1;
uint32_t m_currentRHSNamespaceIndex = 2;
uint32_t m_nextNamespaceIndex = 3;
uint32_t m_numUnknownValues = 0;
uint32_t m_complexity = 0;
uint32_t m_numAnonymousHoles = 0;
using HoleName = pair< StringId, uint32_t >;
struct Cow
{
vector< StoredValue > values;
unordered_map< HoleName, uint32_t > holeDict;
unordered_set< uint32_t > lockedHoles;
};
ptr< Cow > m_pCow = make_shared< Cow >();
bool m_valuesAreRequired = true;
};
}
#endif