Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
d2dd3dfa
Commit
d2dd3dfa
authored
4 months ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
also use ThreadIdxMapping for sparse kernels
parent
3d81f031
No related branches found
No related tags found
1 merge request
!449
GPU Indexing Schemes and Launch Configurations
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/pystencils/backend/platforms/cuda.py
+59
-29
59 additions, 29 deletions
src/pystencils/backend/platforms/cuda.py
with
59 additions
and
29 deletions
src/pystencils/backend/platforms/cuda.py
+
59
−
29
View file @
d2dd3dfa
...
@@ -49,7 +49,7 @@ GRID_DIM = [
...
@@ -49,7 +49,7 @@ GRID_DIM = [
class
ThreadToIndexMapping
(
ABC
):
class
ThreadToIndexMapping
(
ABC
):
@abstractmethod
@abstractmethod
def
__call__
(
self
,
ispace
:
Full
IterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
def
__call__
(
self
,
ispace
:
IterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
"""
Map the current thread index onto a point in the given iteration space.
"""
Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter
Implementations of this method must return a declaration for each dimension counter
...
@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
...
@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
"""
3D globally linearized mapping, where each thread is assigned a work item according to
"""
3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid.
"""
its location in the global launch grid.
"""
def
__call__
(
self
,
ispace
:
FullIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
def
__call__
(
self
,
ispace
:
IterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
match
ispace
:
case
FullIterationSpace
():
return
self
.
_dense_mapping
(
ispace
)
case
SparseIterationSpace
():
return
self
.
_sparse_mapping
(
ispace
)
case
_
:
assert
False
,
"
unexpected iteration space
"
def
_dense_mapping
(
self
,
ispace
:
FullIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
if
ispace
.
rank
>
3
:
if
ispace
.
rank
>
3
:
raise
MaterializationError
(
raise
MaterializationError
(
f
"
Cannot handle
{
ispace
.
rank
}
-dimensional iteration space
"
f
"
Cannot handle
{
ispace
.
rank
}
-dimensional iteration space
"
...
@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
...
@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
return
idx_map
return
idx_map
def
_sparse_mapping
(
self
,
ispace
:
SparseIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
sparse_ctr
=
PsExpression
.
make
(
ispace
.
sparse_counter
)
thread_idx
=
self
.
_linear_thread_idx
(
0
)
idx_map
:
dict
[
PsSymbol
,
PsExpression
]
=
{
ispace
.
sparse_counter
:
PsCast
(
deconstify
(
sparse_ctr
.
get_dtype
()),
thread_idx
)
}
return
idx_map
def
_linear_thread_idx
(
self
,
coord
:
int
):
def
_linear_thread_idx
(
self
,
coord
:
int
):
block_size
=
BLOCK_DIM
[
coord
]
block_size
=
BLOCK_DIM
[
coord
]
block_idx
=
BLOCK_IDX
[
coord
]
block_idx
=
BLOCK_IDX
[
coord
]
...
@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
...
@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
THREAD_IDX
[
0
],
THREAD_IDX
[
0
],
]
]
def
__call__
(
self
,
ispace
:
FullIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
def
__call__
(
self
,
ispace
:
IterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
match
ispace
:
case
FullIterationSpace
():
return
self
.
_dense_mapping
(
ispace
)
case
SparseIterationSpace
():
return
self
.
_sparse_mapping
(
ispace
)
case
_
:
assert
False
,
"
unexpected iteration space
"
def
_dense_mapping
(
self
,
ispace
:
FullIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
if
ispace
.
rank
>
4
:
if
ispace
.
rank
>
4
:
raise
MaterializationError
(
raise
MaterializationError
(
f
"
Cannot handle
{
ispace
.
rank
}
-dimensional iteration space
"
f
"
Cannot handle
{
ispace
.
rank
}
-dimensional iteration space
"
...
@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
...
@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
return
idx_map
return
idx_map
def
_sparse_mapping
(
self
,
ispace
:
SparseIterationSpace
)
->
dict
[
PsSymbol
,
PsExpression
]:
sparse_ctr
=
PsExpression
.
make
(
ispace
.
sparse_counter
)
thread_idx
=
self
.
_indices_in_loop_order
[
-
1
]
idx_map
:
dict
[
PsSymbol
,
PsExpression
]
=
{
ispace
.
sparse_counter
:
PsCast
(
deconstify
(
sparse_ctr
.
get_dtype
()),
thread_idx
)
}
return
idx_map
class
CudaPlatform
(
GenericGpu
):
class
CudaPlatform
(
GenericGpu
):
"""
Platform for CUDA-based GPUs.
"""
"""
Platform for CUDA-based GPUs.
"""
...
@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
...
@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
super
().
__init__
(
ctx
)
super
().
__init__
(
ctx
)
self
.
_omit_range_check
=
omit_range_check
self
.
_omit_range_check
=
omit_range_check
self
.
_thread_mapping
=
thread_mapping
self
.
_thread_mapping
=
(
thread_mapping
if
thread_mapping
is
not
None
else
Linear3DMapping
()
)
self
.
_typify
=
Typifier
(
ctx
)
self
.
_typify
=
Typifier
(
ctx
)
...
@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
...
@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
)
->
PsBlock
:
)
->
PsBlock
:
dimensions
=
ispace
.
dimensions_in_loop_order
()
dimensions
=
ispace
.
dimensions_in_loop_order
()
# TODO move to codegen
ctr_mapping
=
self
.
_thread_mapping
(
ispace
)
# if not self._manual_launch_grid:
# try:
# threads_range = self.threads_from_ispace(ispace)
# except MaterializationError as e:
# warn(
# str(e.args[0])
# + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
# UserWarning,
# )
# threads_range = None
# else:
# threads_range = None
idx_mapper
=
(
self
.
_thread_mapping
if
self
.
_thread_mapping
is
not
None
else
Linear3DMapping
()
)
ctr_mapping
=
idx_mapper
(
ispace
)
indexing_decls
=
[]
indexing_decls
=
[]
conds
=
[]
conds
=
[]
...
@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
...
@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
factory
=
AstFactory
(
self
.
_ctx
)
factory
=
AstFactory
(
self
.
_ctx
)
ispace
.
sparse_counter
.
dtype
=
constify
(
ispace
.
sparse_counter
.
get_dtype
())
ispace
.
sparse_counter
.
dtype
=
constify
(
ispace
.
sparse_counter
.
get_dtype
())
sparse_ctr
=
PsExpression
.
make
(
ispace
.
sparse_counter
)
sparse_ctr_expr
=
PsExpression
.
make
(
ispace
.
sparse_counter
)
thread_idx
=
BLOCK_IDX
[
0
]
*
BLOCK_DIM
[
0
]
+
THREAD_IDX
[
0
]
ctr_mapping
=
self
.
_thread_mapping
(
ispace
)
sparse_idx_decl
=
self
.
_typify
(
sparse_idx_decl
=
self
.
_typify
(
PsDeclaration
(
sparse_ctr
,
PsCast
(
sparse_ctr
.
get_dtype
(),
thread_idx
)
)
PsDeclaration
(
sparse_ctr
_expr
,
ctr_mapping
[
ispace
.
sparse_counter
]
)
)
)
mappings
=
[
mappings
=
[
...
@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
...
@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
PsLookup
(
PsLookup
(
PsBufferAcc
(
PsBufferAcc
(
ispace
.
index_list
.
base_pointer
,
ispace
.
index_list
.
base_pointer
,
(
sparse_ctr
,
factory
.
parse_index
(
0
)),
(
sparse_ctr
_expr
.
clone
()
,
factory
.
parse_index
(
0
)),
),
),
coord
.
name
,
coord
.
name
,
),
),
...
@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
...
@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
if
not
self
.
_omit_range_check
:
if
not
self
.
_omit_range_check
:
stop
=
PsExpression
.
make
(
ispace
.
index_list
.
shape
[
0
])
stop
=
PsExpression
.
make
(
ispace
.
index_list
.
shape
[
0
])
condition
=
PsLt
(
sparse_ctr
,
stop
)
condition
=
PsLt
(
sparse_ctr
_expr
.
clone
()
,
stop
)
ast
=
PsBlock
([
sparse_idx_decl
,
PsConditional
(
condition
,
body
)])
ast
=
PsBlock
([
sparse_idx_decl
,
PsConditional
(
condition
,
body
)])
else
:
else
:
body
.
statements
=
[
sparse_idx_decl
]
+
body
.
statements
body
.
statements
=
[
sparse_idx_decl
]
+
body
.
statements
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment