#include "verify.h"
#include "builtins/builtins.h"
using namespace goose::verify;
// Returns the unique id recorded for `blUnrollId` in `rm`.
// The first time a given (block index, unroll index) pair is seen, a fresh id is taken from
// m_nextUniqueId, stored in `rm`, and the original block index it stands for is recorded in
// m_originalBBIndices so getOriginalBBIndex() can map it back later.
uint32_t Remapper::remap( RemappingMap& rm, BlockAndUnrollingId blUnrollId )
{
    auto known = rm.find( blUnrollId );
    if( known == rm.end() )
    {
        uint32_t freshId = m_nextUniqueId;
        ++m_nextUniqueId;

        rm.emplace( blUnrollId, freshId );
        m_originalBBIndices.emplace( freshId, blUnrollId.first );
        return freshId;
    }

    return known->second;
}
// Returns the id to use for `bb` in the current unrolling context.
// While a loop is being unrolled, the block is keyed by its index and the current iteration
// of the innermost unrolled loop (the top of m_loopUnrollingStack). Outside of any unrolling,
// the block's own index is used unchanged.
uint32_t Remapper::remapBBId( const llr::BasicBlock& bb )
{
    if( !m_loopUnrollingStack.empty() )
    {
        auto& innermost = m_loopUnrollingStack.back();
        return remap( innermost.loopRemapping, make_pair( bb.index(), innermost.currentUnrollIndex ) );
    }

    return bb.index();
}
// Returns the remapped id of `succBB` as the destination of an edge leaving the block being
// processed. `currentBB` is not consulted.
//  - Not a loop header and not inside a loop: its own index, no remapping.
//  - Not a loop header but inside a loop: that loop must be on the unrolling stack (asserted);
//    the block is remapped to that loop's current iteration.
//  - Loop header of a loop being unrolled: remapped to that loop's next iteration.
//  - Loop header of a loop not being unrolled: remapped to its first iteration (unroll index 1)
//    in m_nextLoopRemapping, which beginLoopUnrolling() later moves into the loop's state.
uint32_t Remapper::remapOutgoingEdge( const llr::BasicBlock& currentBB, const llr::BasicBlock& succBB )
{
    const auto innermostFirst = m_loopUnrollingStack.rbegin();
    const auto stackEnd = m_loopUnrollingStack.rend();

    if( !succBB.isLoopHeader() )
    {
        // Ordinary function code outside of any loop keeps its id.
        if( !succBB.loopId() )
            return succBB.index();

        // A block inside a loop stays within the iteration of that loop currently being unrolled.
        auto owner = find_if( innermostFirst, stackEnd, [&]( auto&& state )
        {
            return state.loopId == succBB.loopId();
        } );
        assert( owner != stackEnd );
        return remap( owner->loopRemapping, make_pair( succBB.index(), owner->currentUnrollIndex ) );
    }

    // The destination is a loop header: look for that loop among the active unrollings,
    // innermost first.
    auto active = find_if( innermostFirst, stackEnd, [&]( auto&& state )
    {
        return state.loopId == succBB.index();
    } );

    // The loop is not being unrolled yet, so this edge enters its first iteration.
    if( active == stackEnd )
        return remap( m_nextLoopRemapping, make_pair( succBB.index(), 1 ) );

    // The loop is already being unrolled, so this edge leads into its next iteration.
    return remap( active->loopRemapping, make_pair( succBB.index(), active->currentUnrollIndex + 1 ) );
}
// Returns the loopId of the innermost loop currently being unrolled, or 0 when no loop
// unrolling is in progress.
uint32_t Remapper::getCurrentLoopId() const
{
    return m_loopUnrollingStack.empty() ? 0 : m_loopUnrollingStack.back().loopId;
}
// Starts unrolling loop `loopId` by pushing a new innermost unrolling state with unroll index 0.
// The ids accumulated in m_nextLoopRemapping are handed over to the new state as its
// remapping table. These are the first-iteration ids that remapOutgoingEdge() records with
// unroll index 1. m_nextLoopRemapping is then reset for the next loop to be entered.
// NOTE(review): first-iteration ids use unroll index 1 while the state starts at 0, so callers
// presumably call nextLoopIteration() before processing the first iteration — confirm.
void Remapper::beginLoopUnrolling( uint32_t loopId )
{
    m_loopUnrollingStack.emplace_back( LoopUnrollingState{ move( m_nextLoopRemapping ), loopId, 0 } );
    m_nextLoopRemapping.clear();
}
// Stops unrolling the innermost loop, discarding its unrolling state (including its
// remapping table).
// Precondition: a matching beginLoopUnrolling() call is still active, i.e. the unrolling
// stack is not empty.
void Remapper::endLoopUnrolling()
{
    // An unbalanced begin/end pair is a programming error; catch it in debug builds instead of
    // letting the stack size underflow.
    assert( !m_loopUnrollingStack.empty() );
    m_loopUnrollingStack.pop_back();
}
// Advances the innermost loop being unrolled to its next iteration. Subsequent calls to
// remapBBId() use the new unroll index.
// Precondition: a loop unrolling is in progress (the stack is not empty).
void Remapper::nextLoopIteration()
{
    auto& innermost = m_loopUnrollingStack.back();
    innermost.currentUnrollIndex += 1;
}
// Maps an id produced by remap() back to the index of the basic block it was created for.
// Ids that were never produced by remap() are returned unchanged. This covers blocks that
// were never remapped, because they are used with their own index.
uint32_t Remapper::getOriginalBBIndex( uint32_t remappedId ) const
{
    auto entry = m_originalBBIndices.find( remappedId );
    return entry != m_originalBBIndices.end() ? entry->second : remappedId;
}
// Returns true when every incoming edge that `bb` is expected to have, in its current
// unrolling context, has been recorded in m_edges.
// Edges are counted per distinct original predecessor block, not per recorded edge.
// NOTE(review): m_edges is presumably an ordered map keyed by (destination id, predecessor id)
// pairs, where the predecessor id may be a remapped id — confirm against its declaration.
bool Remapper::areAllPredecessorsProcessed( const llr::BasicBlock& bb )
{
auto bbid = remapBBId( bb );
// Range of all recorded edges whose destination is bbid, for any predecessor id.
auto begin = m_edges.lower_bound( { bbid, 0U } );
auto end = m_edges.upper_bound( { bbid, ~0U } );
uint32_t expected = bb.backEdges().size();
// If this is the first unroll of a loop header,
// we expect only incoming links from blocks preceding the loop.
// Otherwise, we only expect incoming links from the previous
// iteration.
// NOTE(review): this relies on backEdges() counting only the entry edges from outside the loop
// and loopEdges() counting only the edges coming from inside the loop — verify against
// llr::BasicBlock. Also note that a loop header which is not the innermost active loop takes
// the loopEdges() branch as well.
if( bb.isLoopHeader() && !m_loopUnrollingStack.empty() )
{
const auto& lus = m_loopUnrollingStack.back();
if( lus.loopId != bb.index() || lus.currentUnrollIndex != 1 )
expected = bb.loopEdges().size();
}
// Cheap early out: fewer recorded edges than expected predecessors.
if( distance( begin, end ) < expected )
return false;
// We may have more incoming edges invented by unrolling (exits from the unrolled copies)
// than were originally present in the cfg. So we have to check that all incoming
// edges in the original cfg are accounted for, just counting them is not enough.
llvm::SmallVector< uint32_t, 8 > realProcessedIds;
for( auto itEdge = begin; itEdge != end; ++itEdge )
{
// Collapse remapped predecessor ids back to their original block and deduplicate them.
uint32_t origPredBBId = getOriginalBBIndex( itEdge->first.second );
auto it = find( realProcessedIds.begin(), realProcessedIds.end(), origPredBBId );
if( it != realProcessedIds.end() )
continue;
realProcessedIds.push_back( origPredBBId );
}
return realProcessedIds.size() == expected;
}