#ifndef GOOSE_SEMA_TC_CONTEXT_H #define GOOSE_SEMA_TC_CONTEXT_H namespace goose::sema { #ifndef NDEBUG struct TCRuleInfo; #endif class TypeCheckingContext { public: static constexpr uint32_t InvalidIndex = numeric_limits< uint32_t >::max(); TypeCheckingContext( const Context& c ); TypeCheckingContext( Context&& c ); const auto& context() const { return m_context; } const auto& env() const { return m_context.env(); } const auto& rules() const { return env()->typeCheckingRuleSet(); } uint32_t getLHSHoleIndex( StringId name, uint32_t repetitionIndex ) const; uint32_t getRHSHoleIndex( StringId name, uint32_t repetitionIndex ) const; uint32_t createValue( bool required = false ); void setLHSHoleIndex( StringId name, uint32_t repetitionIndex, uint32_t index ); void setRHSHoleIndex( StringId name, uint32_t repetitionIndex, uint32_t index ); void eraseLHSName( StringId name ); void eraseRHSName( StringId name ); class SubContext { public: SubContext( uint32_t nsIndex ) : namespaceIndex( nsIndex ) {} uint32_t repetitionIndex( uint32_t depth ) const; void setRepetitionIndex( uint32_t depth, uint32_t index ); uint32_t namespaceIndex = 0; uint32_t currentRepetitionDepth = 0; private: using RepIndicesVec = llvm::SmallVector< uint32_t, 4 >; shared_ptr< RepIndicesVec > m_repetitionIndices; }; const SubContext& LHSSubContext() const { return m_lhsSubContext; } const SubContext& RHSSubContext() const { return m_rhsSubContext; } SubContext& LHSSubContext() { return m_lhsSubContext; } SubContext& RHSSubContext() { return m_rhsSubContext; } uint32_t newNamespaceIndex() { return m_nextNamespaceIndex++; } // By default, any encountered hole will be considered as required, ie // they will count towards numUnknownValues() if we can't solve them. // This function allows to temporarily disable this, so that any hole // encountered from that point on will not count towards unresolved holes, // unless they also appear in a section where holes are required. void setValueResolutionRequired( bool required ) { m_valuesAreRequired = required; } bool isValueResolutionRequired() const { return m_valuesAreRequired; } const optional< Term >& getValue( uint32_t index ) const { assert( m_pCow->values.size() > index ); return m_pCow->values[index].m_term; } template< typename T > void setValue( uint32_t index, T&& val ) { assert( m_pCow->values.size() > index ); if( m_pCow->values[index].m_required && !m_pCow->values[index].m_term ) --m_numUnknownValues; CoW( m_pCow )->values[index] = { forward< T >( val ), true }; } TypeCheckingContext& flip() { swap( m_lhsSubContext, m_rhsSubContext ); return *this; } uint32_t numUnknownValues() const { return m_numUnknownValues; } int32_t cost() const { return m_cost; } void addCost( int32_t c ) { m_cost +=c; } void setCost( int32_t cost ) { m_cost = cost; } void addAnonymousHole() { ++m_numAnonymousHoles; } auto score() const { return TypeCheckingScore( m_cost, m_pCow->holeDict.size() + m_numAnonymousHoles ); } // Used to detect and reject recursive hole nesting. bool isHoleLocked( uint32_t index ) const; void lockHole( uint32_t index ); void unlockHole( uint32_t index ); #ifndef NDEBUG void TCRuleTrace( const TCRuleInfo* pRule ) const; void PushRuleTraceForParam() const; void DumpParamsTraces( ostream& out ) const; #endif private: void setValueRequired( uint32_t index ); Context m_context; struct StoredValue { optional< Term > m_term; bool m_required = false; }; SubContext m_lhsSubContext{ 1 }; SubContext m_rhsSubContext{ 2 }; uint32_t m_nextNamespaceIndex = 3; uint32_t m_numUnknownValues = 0; int32_t m_cost = 0; uint32_t m_numAnonymousHoles = 0; using HoleKey = tuple< StringId, uint32_t, uint32_t >; struct Cow { vector< StoredValue > values; map< HoleKey, uint32_t > holeDict; unordered_set< uint32_t > lockedHoles; #ifndef NDEBUG // In debug, keep track of which "path" was taken through the various // rules to end up with the result. using TCTrace = vector< const TCRuleInfo* >; mutable TCTrace currentTypeCheckingTrace; mutable vector< TCTrace > paramsTypeCheckingTrace; #endif }; mutable ptr< Cow > m_pCow = make_shared< Cow >(); bool m_valuesAreRequired = true; }; } #endif