Skip to content
Snippets Groups Projects

Various extensions to the vectorizer

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/vec-extensions into v2.0-dev
Files
3
@@ -156,6 +156,10 @@ class VectorizationContext:
)
return PsVectorType(scalar_type, self._lanes)
def axis_ctr_dependees(self, symbols: set[PsSymbol]) -> set[PsSymbol]:
"""Returns all symbols in `symbols` that depend on the axis counter."""
return symbols & (self.vectorized_symbols.keys() | {self.axis.counter})
@dataclass
class Affine:
@@ -303,16 +307,13 @@ class AstVectorizer:
return PsAssignment(lhs_vec, rhs_vec)
case PsLoop(counter, start, stop, step, body):
# Check that loop bounds are lane-invariant
free_symbols = (
self._collect_symbols(start)
| self._collect_symbols(stop)
| self._collect_symbols(step)
)
# Check that loop bounds are lane-invariant
vec_dependencies = free_symbols & (
vc.vectorized_symbols.keys() | {vc.axis.counter}
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize loop depending on vectorized symbols:\n"
@@ -469,14 +470,11 @@ class AstVectorizer:
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
case PsSubscript(array, index):
# Check that array expression and indices are lane-invariant
free_symbols = self._collect_symbols(array).union(
*[self._collect_symbols(i) for i in index]
)
# Check that array expression and indices are lane-invariant
vec_dependencies = free_symbols & (
vc.vectorized_symbols.keys() | {vc.axis.counter}
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize array subscript depending on vectorized symbols:\n"