#include "sema.h"
namespace goose::sema
{
    // Registers into ruleSet the unification rules that deal with holes:
    // anonymous holes, named holes and already-numbered holes, against anything
    // and against each other.
    void SetupHoleUnificationRules( UnificationRuleSet& ruleSet )
    {
        // Anonymous hole half-unification: add 1 to the anon holes count,
        // yield the hole as is.
        ruleSet.addHalfUnificationRule(
            VEC( TSID( hole ), TSID( _ ) ),
        []( const Term& lhs, UnificationContext& c ) -> UniGen
        {
            c.addAnonymousHole();
            co_yield { lhs, c };
        } );

        // Anonymous hole versus anything: add 1 to the anon holes count,
        // yield the half unification of the rhs.
        // The context is flipped for the duration of the rhs half-unification,
        // each solution's context is flipped back before being yielded, and c itself
        // is flipped back once the solutions are exhausted.
        ruleSet.addSymRule(
            VEC( TSID( hole ), TSID( _ ) ),
            ANYTERM( _ ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            c.addAnonymousHole();

            for( auto&& [s,uc] : HalfUnify( rhs, c.flip() ) )
                co_yield { s, uc.flip() };
            c.flip();
        } );

        // Hole half-unification: convert it to a numbered hole.
        // If the name wasn't already known, allocate a new value for it.
        // Holes that are currently locked (being resolved further up) yield nothing.
        ruleSet.addHalfUnificationRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
        []( const Term& lhs, UnificationContext& c ) -> UniGen
        {
            auto lh = *HoleFromIRExpr( lhs );

            if( holds_alternative< StringId >( lh ) )
            {
                // This is a named hole: look up its name.
                const auto& name = get< StringId >( lh );
                auto holeIndex = c.getLHSHoleIndex( name );

                if( holeIndex != UnificationContext::InvalidIndex )
                {
                    if( !c.isHoleLocked( holeIndex ) )
                        co_yield { MkHole( holeIndex ), c };
                }
                else
                {
                    // This is a new name: create a new value,
                    // and increment the number of unique holes in the current score.
                    auto index = c.createValue();
                    c.setLHSHoleIndex( name, index );
                    co_yield { MkHole( index ), c };
                }
            }
            else
            {
                // This is already an indexed hole: yield it as is.
                if( !c.isHoleLocked( get< uint32_t >( lh ) ) )
                    co_yield { lhs, c };
            }
        } );

        // Hole vs anything: bind the hole's value to the (half-)unification of rhs
        // with whatever the hole currently holds, and yield the numbered hole itself.
        ruleSet.addSymRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
            ANYTERM( _ ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            auto h = *HoleFromIRExpr( lhs );
            uint32_t index = 0;

            // Remember the previous complexity count so we know how much complexity
            // is added by this particular sub-term. This is because we need
            // to be able to subtract it when updating the hole's value with a new solution.
            uint32_t oldComplexity = c.complexity();

            if( holds_alternative< uint32_t >( h ) )
                index = get< uint32_t >( h );
            else
            {
                // This is a named hole: look up its name.
                const auto& name = get< StringId >( h );
                index = c.getLHSHoleIndex( name );

                if( index == UnificationContext::InvalidIndex )
                {
                    // This is a new name: create a new value.
                    index = c.createValue();
                    c.setLHSHoleIndex( name, index );

                    auto holeExpr = MkHole( index );

                    for( auto&& [e,uc] : HalfUnify( rhs, c.flip() ) )
                    {
                        uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );

                        // Yield a copy: the generator may produce several solutions,
                        // and moving holeExpr would leave later ones with a moved-from term.
                        co_yield { holeExpr, uc.flip() };
                    }

                    c.flip();
                    co_return;
                }
            }

            // Reject recursive hole nesting.
            if( c.isHoleLocked( index ) )
                co_return;

            c.lockHole( index );

            auto holeExpr = MkHole( index );

            auto& maybeVal = c.getValue( index );
            if( maybeVal )
            {
                for( auto&& [e,uc] : Unify( *maybeVal, rhs, c ) )
                {
                    uc.unlockHole( index );
                    uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );

                    // Copy rather than move: see above, there may be several solutions.
                    co_yield { holeExpr, uc };
                }
            }
            else
            {
                for( auto&& [e,uc] : HalfUnify( rhs, c.flip() ) )
                {
                    uc.unlockHole( index );
                    uc.setValue( index, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );

                    // Copy rather than move: see above, there may be several solutions.
                    co_yield { holeExpr, uc.flip() };
                }

                c.flip();
            }
        } );

        // Hole vs hole: make both holes resolve to the same value.
        ruleSet.addAsymRule(
            VEC( TSID( hole ), ANYTERM( _ ) ),
            VEC( TSID( hole ), ANYTERM( _ ) ),
        []( const Term& lhs, const Term& rhs, UnificationContext& c ) -> UniGen
        {
            auto lh = *HoleFromIRExpr( lhs );
            auto rh = *HoleFromIRExpr( rhs );

            StringId lname;
            StringId rname;
            uint32_t lindex = 0;
            uint32_t rindex = 0;

            if( holds_alternative< StringId >( lh ) )
            {
                // L is a named hole: look up its name.
                lname = get< StringId >( lh );
                lindex = c.getLHSHoleIndex( lname );
            }
            else
                lindex = get< uint32_t >( lh );

            if( holds_alternative< StringId >( rh ) )
            {
                // R is a named hole: look up its name.
                rname = get< StringId >( rh );
                rindex = c.getRHSHoleIndex( rname );
            }
            else
                rindex = get< uint32_t >( rh );

            // If neither hole has been assigned a value index yet, create a new one
            // and bind both names to it.
            if( lindex == UnificationContext::InvalidIndex && rindex == UnificationContext::InvalidIndex )
            {
                auto index = c.createValue();
                c.setLHSHoleIndex( lname, index );
                c.setRHSHoleIndex( rname, index );
                co_yield { MkHole( index ), c };
                co_return;
            }

            // If both holes actually point to the same value, just yield it as the solution.
            if( lindex == rindex )
            {
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            // If either hole doesn't have a value index yet, give it the other one's.
            if( lindex == UnificationContext::InvalidIndex )
            {
                c.setLHSHoleIndex( lname, rindex );
                co_yield { MkHole( rindex ), c };
                co_return;
            }

            if( rindex == UnificationContext::InvalidIndex )
            {
                c.setRHSHoleIndex( rname, lindex );
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            // Reject recursive hole nesting.
            if( c.isHoleLocked( lindex ) )
                co_return;
            if( c.isHoleLocked( rindex ) )
                co_return;

            c.lockHole( lindex );
            c.lockHole( rindex );

            const auto& lval = c.getValue( lindex );
            const auto& rval = c.getValue( rindex );

            // Neither hole holds a value yet: there is nothing to unify. Just make R's
            // value refer to L so both holes are tied together from now on.
            // This must be checked first: the two single-empty cases below dereference
            // the other hole's value, which would be empty here.
            if( !lval && !rval )
            {
                c.unlockHole( lindex );
                c.unlockHole( rindex );
                c.setValue( rindex, MkHole( lindex ) );
                co_yield { MkHole( lindex ), c };
                co_return;
            }

            // If either hole has an empty value, set it to a hole expression with the id of the value
            // stored in the other one. We can't just copy the value over as we would lose the dependency
            // relationship between the two holes.
            if( !rval )
            {
                for( auto&& [e,uc] : HalfUnify( *lval, c ) )
                {
                    uc.unlockHole( lindex );
                    uc.unlockHole( rindex );
                    uc.setValue( rindex, MkHole( lindex ) );
                    co_yield { MkHole( lindex ), uc };
                }

                co_return;
            }

            if( !lval )
            {
                for( auto&& [e,uc] : HalfUnify( *rval, c.flip() ) )
                {
                    uc.unlockHole( lindex );
                    uc.unlockHole( rindex );
                    uc.setValue( lindex, MkHole( rindex ) );
                    co_yield { MkHole( rindex ), uc.flip() };
                }

                c.flip();
                co_return;
            }

            // Both L and R have a value: unify them, store the result in lhs,
            // replace rhs with a hole expression pointing to lhs's value.
            // Remember the previous complexity count so we know how much complexity
            // is added by this particular sub-term. This is because we need
            // to be able to subtract it when updating the hole's value with a new solution.
            uint32_t oldComplexity = c.complexity();

            for( auto&& [e,uc] : Unify( *lval, *rval, c ) )
            {
                uc.unlockHole( lindex );
                uc.unlockHole( rindex );
                uc.setValue( lindex, SetComplexity( move( e ), uc.complexity() - oldComplexity ) );
                uc.setValue( rindex, MkHole( lindex ) );
                co_yield { MkHole( lindex ), uc };
            }
        } );
    }
}