Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
hyteg
pystencils
Commits
8327805d
Commit
8327805d
authored
1 year ago
by
Daniel Bauer
Browse files
Options
Downloads
Patches
Plain Diff
reviewer feedback
parent
091b2497
Branches
Branches containing commit
No related tags found
No related merge requests found
Pipeline
#66727
passed
1 year ago
Stage: Code Quality
Stage: Unit Tests
Stage: legacy_test
Stage: docs
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/pystencils/backend/transformations/eliminate_branches.py
+88
-89
88 additions, 89 deletions
src/pystencils/backend/transformations/eliminate_branches.py
with
88 additions
and
89 deletions
src/pystencils/backend/transformations/eliminate_branches.py
+
88
−
89
View file @
8327805d
...
...
@@ -52,12 +52,12 @@ class EliminateBranches:
of enclosing loops and enclosing conditionals into its analysis.
Args:
no
_isl (bool, optional):
dis
able islpy based analysis
use
_isl (bool, optional):
en
able islpy based analysis
(default: True)
"""
def
__init__
(
self
,
ctx
:
KernelCreationContext
,
no
_isl
:
bool
=
Fals
e
)
->
None
:
def
__init__
(
self
,
ctx
:
KernelCreationContext
,
use
_isl
:
bool
=
Tru
e
)
->
None
:
self
.
_ctx
=
ctx
self
.
_
no
_isl
=
no
_isl
self
.
_
use
_isl
=
use
_isl
self
.
_elim_constants
=
EliminateConstants
(
ctx
,
extract_constant_exprs
=
False
)
def
__call__
(
self
,
node
:
PsAstNode
)
->
PsAstNode
:
...
...
@@ -104,8 +104,8 @@ class EliminateBranches:
self
,
conditional
:
PsConditional
,
ec
:
BranchElimContext
)
->
PsConditional
|
PsBlock
|
None
:
condition_simplified
=
self
.
_elim_constants
(
conditional
.
condition
)
if
not
self
.
_
no
_isl
:
condition_simplified
=
self
.
_isl_s
y
mplify_condition
(
if
self
.
_
use
_isl
:
condition_simplified
=
self
.
_isl_s
i
mplify_condition
(
condition_simplified
,
ec
)
...
...
@@ -117,7 +117,7 @@ class EliminateBranches:
return
conditional
def
_isl_s
y
mplify_condition
(
def
_isl_s
i
mplify_condition
(
self
,
condition
:
PsExpression
,
ec
:
BranchElimContext
)
->
PsExpression
:
"""
If installed, use ISL to simplify the passed condition to true or
...
...
@@ -127,102 +127,101 @@ class EliminateBranches:
try
:
import
islpy
as
isl
except
ImportError
:
return
condition
def
printer
(
expr
:
PsExpression
):
match
expr
:
case
PsSymbolExpr
(
symbol
):
return
symbol
.
name
case
PsConstantExpr
(
constant
):
dtype
=
constant
.
get_dtype
()
if
not
isinstance
(
dtype
,
(
PsIntegerType
,
PsBoolType
)):
raise
IslAnalysisError
(
"
Only scalar integer and bool constant may appear in isl expressions.
"
)
return
str
(
constant
.
value
)
case
PsAdd
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
+
{
printer
(
op2
)
}
)
"
case
PsSub
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
-
{
printer
(
op2
)
}
)
"
case
PsMul
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
*
{
printer
(
op2
)
}
)
"
case
PsDiv
(
op1
,
op2
)
|
PsIntDiv
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
/
{
printer
(
op2
)
}
)
"
case
PsAnd
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
and
{
printer
(
op2
)
}
)
"
case
PsOr
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
or
{
printer
(
op2
)
}
)
"
case
PsEq
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
=
{
printer
(
op2
)
}
)
"
case
PsNe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
!=
{
printer
(
op2
)
}
)
"
case
PsGt
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
>
{
printer
(
op2
)
}
)
"
case
PsGe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
>=
{
printer
(
op2
)
}
)
"
case
PsLt
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
<
{
printer
(
op2
)
}
)
"
case
PsLe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
<=
{
printer
(
op2
)
}
)
"
case
PsNeg
(
operand
):
return
f
"
(-
{
printer
(
operand
)
}
)
"
case
PsNot
(
operand
):
return
f
"
(not
{
printer
(
operand
)
}
)
"
case
PsCast
(
_
,
operand
):
return
printer
(
operand
)
def
printer
(
expr
:
PsExpression
):
match
expr
:
case
PsSymbolExpr
(
symbol
):
return
symbol
.
name
case
_
:
case
PsConstantExpr
(
constant
):
dtype
=
constant
.
get_dtype
()
if
not
isinstance
(
dtype
,
(
PsIntegerType
,
PsBoolType
)):
raise
IslAnalysisError
(
f
"
Not supported by isl or don
'
t know how to print
{
expr
}
"
"
Only scalar integer and bool constant may appear in isl expressions.
"
)
dofs
=
collect_undefined_symbols
(
condition
)
outer_conditions
=
[]
for
loop
in
ec
.
enclosing_loops
:
if
not
(
isinstance
(
loop
.
step
,
PsConstantExpr
)
and
loop
.
step
.
constant
.
value
==
1
):
return
str
(
constant
.
value
)
case
PsAdd
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
+
{
printer
(
op2
)
}
)
"
case
PsSub
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
-
{
printer
(
op2
)
}
)
"
case
PsMul
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
*
{
printer
(
op2
)
}
)
"
case
PsDiv
(
op1
,
op2
)
|
PsIntDiv
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
/
{
printer
(
op2
)
}
)
"
case
PsAnd
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
and
{
printer
(
op2
)
}
)
"
case
PsOr
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
or
{
printer
(
op2
)
}
)
"
case
PsEq
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
=
{
printer
(
op2
)
}
)
"
case
PsNe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
!=
{
printer
(
op2
)
}
)
"
case
PsGt
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
>
{
printer
(
op2
)
}
)
"
case
PsGe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
>=
{
printer
(
op2
)
}
)
"
case
PsLt
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
<
{
printer
(
op2
)
}
)
"
case
PsLe
(
op1
,
op2
):
return
f
"
(
{
printer
(
op1
)
}
<=
{
printer
(
op2
)
}
)
"
case
PsNeg
(
operand
):
return
f
"
(-
{
printer
(
operand
)
}
)
"
case
PsNot
(
operand
):
return
f
"
(not
{
printer
(
operand
)
}
)
"
case
PsCast
(
_
,
operand
):
return
printer
(
operand
)
case
_
:
raise
IslAnalysisError
(
"
Loops with strides != 1 are not yet supported.
"
f
"
Not supported by isl or don
'
t know how to print
{
expr
}
"
)
dofs
.
add
(
loop
.
counter
.
symbol
)
dofs
.
update
(
collect_undefined_symbols
(
loop
.
start
))
dofs
.
update
(
collect_undefined_symbols
(
loop
.
stop
))
dofs
=
collect_undefined_symbols
(
condition
)
outer_conditions
=
[]
loop_start_str
=
printer
(
loop
.
start
)
loop_stop_str
=
printer
(
loop
.
stop
)
ctr_name
=
loop
.
counter
.
symbol
.
name
outer_conditions
.
append
(
f
"
{
ctr_name
}
>=
{
loop_start_str
}
and
{
ctr_name
}
<
{
loop_stop_str
}
"
for
loop
in
ec
.
enclosing_loops
:
if
not
(
isinstance
(
loop
.
step
,
PsConstantExpr
)
and
loop
.
step
.
constant
.
value
==
1
):
raise
IslAnalysisError
(
"
Loops with strides != 1 are not yet supported.
"
)
for
cond
in
ec
.
enclosing_conditions
:
dofs
.
update
(
collect_undefined_symbols
(
cond
))
outer_conditions
.
append
(
printer
(
cond
))
dofs
.
add
(
loop
.
counter
.
symbol
)
dofs
.
update
(
collect_undefined_symbols
(
loop
.
start
))
dofs
.
update
(
collect_undefined_symbols
(
loop
.
stop
))
loop_start_str
=
printer
(
loop
.
start
)
loop_stop_str
=
printer
(
loop
.
stop
)
ctr_name
=
loop
.
counter
.
symbol
.
name
outer_conditions
.
append
(
f
"
{
ctr_name
}
>=
{
loop_start_str
}
and
{
ctr_name
}
<
{
loop_stop_str
}
"
)
dofs_str
=
"
,
"
.
join
(
dof
.
name
for
dof
in
dofs
)
outer_conditions_str
=
"
and
"
.
join
(
outer_conditions
)
condition
_str
=
printer
(
cond
ition
)
for
cond
in
ec
.
enclosing_conditions
:
dofs
.
update
(
collect_undefined_symbols
(
cond
)
)
outer_
condition
s
.
append
(
printer
(
cond
)
)
outer_set
=
isl
.
BasicSet
(
f
"
{{ [
{
dofs_str
}
] :
{
outer_conditions_str
}
}}
"
)
inner_set
=
isl
.
BasicSet
(
f
"
{{ [
{
dofs_str
}
] :
{
condition_str
}
}}
"
)
dofs_str
=
"
,
"
.
join
(
dof
.
name
for
dof
in
dofs
)
outer_conditions_str
=
"
and
"
.
join
(
outer_conditions
)
condition_str
=
printer
(
condition
)
if
inner_set
.
is_empty
():
return
PsExpression
.
make
(
PsConstant
(
False
)
)
outer_set
=
isl
.
BasicSet
(
f
"
{{ [
{
dofs_str
}
] :
{
outer_conditions_str
}
}}
"
)
inner_set
=
isl
.
BasicSet
(
f
"
{{ [
{
dofs_str
}
] :
{
condition_str
}
}}
"
)
intersection
=
outer_set
.
intersect
(
inner_set
)
if
intersection
.
is_empty
():
return
PsExpression
.
make
(
PsConstant
(
False
))
elif
intersection
==
outer_set
:
return
PsExpression
.
make
(
PsConstant
(
True
))
else
:
return
condition
if
inner_set
.
is_empty
():
return
PsExpression
.
make
(
PsConstant
(
False
))
except
ImportError
:
intersection
=
outer_set
.
intersect
(
inner_set
)
if
intersection
.
is_empty
():
return
PsExpression
.
make
(
PsConstant
(
False
))
elif
intersection
==
outer_set
:
return
PsExpression
.
make
(
PsConstant
(
True
))
else
:
return
condition
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment