#ifndef GOOSE_SEMA_TC_CONTEXT_H
#define GOOSE_SEMA_TC_CONTEXT_H
namespace goose::sema
{
#ifndef NDEBUG
// Forward declaration only: within this header, TCRuleInfo is referenced
// exclusively through pointers, by the debug-only (non NDEBUG) rule tracing
// members of TypeCheckingContext.
struct TCRuleInfo;
#endif
// State carried around while type checking: wraps a Context together with
// one SubContext per side (LHS/RHS), the stored values, the hole dictionary,
// the set of locked holes and the accumulated cost.
class TypeCheckingContext
{
    public:
        // Largest representable uint32_t, reserved as a "no index" marker.
        // NOTE(review): presumably what the get*HoleIndex() lookups return
        // on a miss -- confirm against the implementation file.
        static constexpr uint32_t InvalidIndex = numeric_limits< uint32_t >::max();

        // Build a type checking context around a copied / moved Context.
        TypeCheckingContext( const Context& c );
        TypeCheckingContext( Context&& c );

        // Read-only access to the wrapped context, its environment, and
        // the environment's type checking rule set.
        const auto& context() const { return m_context; }
        const auto& env() const { return m_context.env(); }
        const auto& rules() const { return env()->typeCheckingRuleSet(); }

        // Hole index lookups for each side, keyed by hole name and repetition index.
        uint32_t getLHSHoleIndex( StringId name, uint32_t repetitionIndex ) const;
        uint32_t getRHSHoleIndex( StringId name, uint32_t repetitionIndex ) const;

        // Defined out of line.
        // NOTE(review): presumably appends a new value slot and returns its
        // index -- confirm against the implementation file.
        uint32_t createValue( bool required = false );

        // Associate an index with a hole name / repetition index, for each side.
        void setLHSHoleIndex( StringId name, uint32_t repetitionIndex, uint32_t index );
        void setRHSHoleIndex( StringId name, uint32_t repetitionIndex, uint32_t index );

        // Remove a hole name, for each side.
        void eraseLHSName( StringId name );
        void eraseRHSName( StringId name );

        // Per-side state: a namespace index, the current repetition depth,
        // and repetition indices addressed by depth.
        class SubContext
        {
            public:
                SubContext( uint32_t nsIndex ) :
                    namespaceIndex( nsIndex )
                {}

                // Accessors for the repetition index stored at a given depth.
                uint32_t repetitionIndex( uint32_t depth ) const;
                void setRepetitionIndex( uint32_t depth, uint32_t index );

                uint32_t namespaceIndex = 0;
                uint32_t currentRepetitionDepth = 0;

            private:
                using RepIndicesVec = llvm::SmallVector< uint32_t, 4 >;

                // Held through a shared_ptr, so copies of a SubContext start
                // out pointing at the same vector.
                shared_ptr< RepIndicesVec > m_repetitionIndices;
        };

        // Const and mutable access to each side's sub-context.
        const SubContext& LHSSubContext() const { return m_lhsSubContext; }
        const SubContext& RHSSubContext() const { return m_rhsSubContext; }
        SubContext& LHSSubContext() { return m_lhsSubContext; }
        SubContext& RHSSubContext() { return m_rhsSubContext; }

        // Hands out the next unused namespace index and advances the counter.
        uint32_t newNamespaceIndex()
        {
            const uint32_t allocated = m_nextNamespaceIndex;
            ++m_nextNamespaceIndex;
            return allocated;
        }

        // By default, every hole that is encountered is considered required,
        // i.e. it counts towards numUnknownValues() if it can't be solved.
        // This setter makes it possible to temporarily turn that off, so that
        // holes encountered from that point on don't count as unresolved,
        // unless they also appear in a section where holes are required.
        void setValueResolutionRequired( bool required )
        {
            m_valuesAreRequired = required;
        }

        bool isValueResolutionRequired() const
        {
            return m_valuesAreRequired;
        }

        // Returns the (possibly empty) term stored in slot 'index'.
        // 'index' must designate an existing slot (asserted).
        const optional< Term >& getValue( uint32_t index ) const
        {
            assert( index < m_pCow->values.size() );
            return m_pCow->values[index].m_term;
        }

        // Stores 'val' in slot 'index', which must already exist (asserted).
        // If the slot was flagged as required and had no term yet, it stops
        // counting as unknown. The slot is flagged as required afterwards.
        template< typename T >
        void setValue( uint32_t index, T&& val )
        {
            assert( index < m_pCow->values.size() );

            // Evaluated before going through CoW(), which is handed m_pCow itself.
            const bool wasUnresolvedRequirement =
                m_pCow->values[index].m_required && !m_pCow->values[index].m_term;

            if( wasUnresolvedRequirement )
                --m_numUnknownValues;

            CoW( m_pCow )->values[index] = { forward< T >( val ), true };
        }

        // Exchanges the LHS and RHS sub-contexts. Returns *this to allow chaining.
        TypeCheckingContext& flip()
        {
            swap( m_lhsSubContext, m_rhsSubContext );
            return *this;
        }

        // Number of values currently counted as unknown.
        uint32_t numUnknownValues() const { return m_numUnknownValues; }

        // Accumulated cost: read, increment, overwrite.
        int32_t cost() const { return m_cost; }
        void addCost( int32_t c ) { m_cost += c; }
        void setCost( int32_t cost ) { m_cost = cost; }

        // Counts one more anonymous hole; only feeds into score().
        void addAnonymousHole() { ++m_numAnonymousHoles; }

        // Score built from the cost and from the total number of holes
        // (hole dictionary entries plus anonymous holes).
        auto score() const { return TypeCheckingScore( m_cost, m_pCow->holeDict.size() + m_numAnonymousHoles ); }

        // Hole locking, used to detect and reject recursive hole nesting.
        bool isHoleLocked( uint32_t index ) const;
        void lockHole( uint32_t index );
        void unlockHole( uint32_t index );

#ifndef NDEBUG
        // Debug-only helpers to record and dump the rules that were applied.
        void TCRuleTrace( const TCRuleInfo* pRule ) const;
        void PushRuleTraceForParam() const;
        void DumpParamsTraces( ostream& out ) const;
#endif

    private:
        void setValueRequired( uint32_t index );

        Context m_context;

        // A value slot: an optional term, plus whether the slot is required.
        struct StoredValue
        {
            optional< Term > m_term;
            bool m_required = false;
        };

        // The sides start out with namespace indices 1 and 2, so the first
        // index handed out by newNamespaceIndex() is 3.
        SubContext m_lhsSubContext{ 1 };
        SubContext m_rhsSubContext{ 2 };
        uint32_t m_nextNamespaceIndex = 3;

        uint32_t m_numUnknownValues = 0;
        int32_t m_cost = 0;
        uint32_t m_numAnonymousHoles = 0;

        // NOTE(review): the StringId is the hole name; the two uint32_t are
        // presumably the namespace index and the repetition index -- confirm
        // against the implementation file.
        using HoleKey = tuple< StringId, uint32_t, uint32_t >;

        // State held behind a shared pointer. setValue() only modifies it
        // through the CoW() helper (defined elsewhere).
        struct Cow
        {
            vector< StoredValue > values;
            map< HoleKey, uint32_t > holeDict;
            unordered_set< uint32_t > lockedHoles;

#ifndef NDEBUG
            // Debug builds only: remember which "path" was followed through
            // the various rules to reach the result.
            using TCTrace = vector< const TCRuleInfo* >;
            mutable TCTrace currentTypeCheckingTrace;
            mutable vector< TCTrace > paramsTypeCheckingTrace;
#endif
        };

        mutable ptr< Cow > m_pCow = make_shared< Cow >();

        // See setValueResolutionRequired().
        bool m_valuesAreRequired = true;
};
}
#endif