Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
4a8659e8
Commit
4a8659e8
authored
8 years ago
by
Jan Hoenig
Browse files
Options
Downloads
Patches
Plain Diff
it actually somehow comiles
parent
b444ae25
No related branches found
No related tags found
No related merge requests found
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
astnodes.py
+6
-0
6 additions, 0 deletions
astnodes.py
backends/llvm.py
+54
-19
54 additions, 19 deletions
backends/llvm.py
llvm/__init__.py
+1
-0
1 addition, 0 deletions
llvm/__init__.py
llvm/jit.py
+6
-6
6 additions, 6 deletions
llvm/jit.py
types.py
+3
-0
3 additions, 0 deletions
types.py
with
70 additions
and
25 deletions
astnodes.py
+
6
−
0
View file @
4a8659e8
...
@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr):
...
@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
repr
(
self
.
value
)
return
repr
(
self
.
value
)
def
__float__
(
self
):
return
float
(
self
.
value
)
def
__int__
(
self
):
return
int
(
self
.
value
)
This diff is collapsed.
Click to expand it.
backends/llvm.py
+
54
−
19
View file @
4a8659e8
import
llvmlite.ir
as
ir
import
llvmlite.ir
as
ir
import
functools
from
sympy.printing.printer
import
Printer
from
sympy.printing.printer
import
Printer
from
sympy
import
S
from
sympy
import
S
# S is numbers?
# S is numbers?
from
pystencils.llvm.control_flow
import
Loop
from
pystencils.llvm.control_flow
import
Loop
from
..types
import
DataType
from
..astnodes
import
Indexed
def
generateLLVM
(
ast_node
):
def
generateLLVM
(
ast_node
):
...
@@ -25,6 +28,7 @@ class LLVMPrinter(Printer):
...
@@ -25,6 +28,7 @@ class LLVMPrinter(Printer):
self
.
fp_type
=
ir
.
DoubleType
()
self
.
fp_type
=
ir
.
DoubleType
()
self
.
fp_pointer
=
self
.
fp_type
.
as_pointer
()
self
.
fp_pointer
=
self
.
fp_type
.
as_pointer
()
self
.
integer
=
ir
.
IntType
(
64
)
self
.
integer
=
ir
.
IntType
(
64
)
self
.
integer_pointer
=
self
.
integer
.
as_pointer
()
self
.
void
=
ir
.
VoidType
()
self
.
void
=
ir
.
VoidType
()
self
.
module
=
module
self
.
module
=
module
self
.
builder
=
builder
self
.
builder
=
builder
...
@@ -35,8 +39,13 @@ class LLVMPrinter(Printer):
...
@@ -35,8 +39,13 @@ class LLVMPrinter(Printer):
def
_add_tmp_var
(
self
,
name
,
value
):
def
_add_tmp_var
(
self
,
name
,
value
):
self
.
tmp_var
[
name
]
=
value
self
.
tmp_var
[
name
]
=
value
def
_print_Number
(
self
,
n
,
**
kwargs
):
def
_print_Number
(
self
,
n
):
return
ir
.
Constant
(
self
.
fp_type
,
n
)
if
n
.
dtype
==
DataType
(
"
int
"
):
return
ir
.
Constant
(
self
.
integer
,
int
(
n
))
elif
n
.
dtype
==
DataType
(
"
double
"
):
return
ir
.
Constant
(
self
.
fp_type
,
float
(
n
))
else
:
raise
NotImplementedError
(
"
Numbers can only have int and double
"
,
n
)
def
_print_Float
(
self
,
expr
):
def
_print_Float
(
self
,
expr
):
return
ir
.
Constant
(
self
.
fp_type
,
expr
.
p
)
return
ir
.
Constant
(
self
.
fp_type
,
expr
.
p
)
...
@@ -81,16 +90,23 @@ class LLVMPrinter(Printer):
...
@@ -81,16 +90,23 @@ class LLVMPrinter(Printer):
def
_print_Mul
(
self
,
expr
):
def
_print_Mul
(
self
,
expr
):
nodes
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
nodes
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
e
=
nodes
[
0
]
e
=
nodes
[
0
]
if
expr
.
dtype
==
DataType
(
'
double
'
):
mul
=
self
.
builder
.
fmul
else
:
# int TODO others?
mul
=
self
.
builder
.
mul
for
node
in
nodes
[
1
:]:
for
node
in
nodes
[
1
:]:
e
=
self
.
builder
.
f
mul
(
e
,
node
)
e
=
mul
(
e
,
node
)
return
e
return
e
def
_print_Add
(
self
,
expr
):
def
_print_Add
(
self
,
expr
):
nodes
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
nodes
=
[
self
.
_print
(
a
)
for
a
in
expr
.
args
]
e
=
nodes
[
0
]
e
=
nodes
[
0
]
if
expr
.
dtype
==
DataType
(
'
double
'
):
add
=
self
.
builder
.
fadd
else
:
# int TODO others?
add
=
self
.
builder
.
add
for
node
in
nodes
[
1
:]:
for
node
in
nodes
[
1
:]:
print
(
e
,
node
)
e
=
add
(
e
,
node
)
e
=
self
.
builder
.
fadd
(
e
,
node
)
return
e
return
e
def
_print_KernelFunction
(
self
,
function
):
def
_print_KernelFunction
(
self
,
function
):
...
@@ -118,6 +134,7 @@ class LLVMPrinter(Printer):
...
@@ -118,6 +134,7 @@ class LLVMPrinter(Printer):
block
=
fn
.
append_basic_block
(
name
=
"
entry
"
)
block
=
fn
.
append_basic_block
(
name
=
"
entry
"
)
self
.
builder
=
ir
.
IRBuilder
(
block
)
self
.
builder
=
ir
.
IRBuilder
(
block
)
self
.
_print
(
function
.
body
)
self
.
_print
(
function
.
body
)
self
.
builder
.
ret_void
()
self
.
fn
=
fn
self
.
fn
=
fn
return
fn
return
fn
...
@@ -129,29 +146,47 @@ class LLVMPrinter(Printer):
...
@@ -129,29 +146,47 @@ class LLVMPrinter(Printer):
with
Loop
(
self
.
builder
,
self
.
_print
(
loop
.
start
),
self
.
_print
(
loop
.
stop
),
self
.
_print
(
loop
.
step
),
with
Loop
(
self
.
builder
,
self
.
_print
(
loop
.
start
),
self
.
_print
(
loop
.
stop
),
self
.
_print
(
loop
.
step
),
loop
.
loopCounterName
,
loop
.
loopCounterSymbol
.
name
)
as
i
:
loop
.
loopCounterName
,
loop
.
loopCounterSymbol
.
name
)
as
i
:
self
.
_add_tmp_var
(
loop
.
loopCounterSymbol
,
i
)
self
.
_add_tmp_var
(
loop
.
loopCounterSymbol
,
i
)
# TODO remove tmp var
self
.
_print
(
loop
.
body
)
self
.
_print
(
loop
.
body
)
def
_print_SympyAssignment
(
self
,
assignment
):
def
_print_SympyAssignment
(
self
,
assignment
):
expr
=
self
.
_print
(
assignment
.
rhs
)
expr
=
self
.
_print
(
assignment
.
rhs
)
lhs
=
assignment
.
lhs
if
isinstance
(
lhs
,
Indexed
):
ptr
=
self
.
_print
(
lhs
.
base
.
label
)
index
=
self
.
_print
(
lhs
.
args
[
1
])
gep
=
self
.
builder
.
gep
(
ptr
,
[
index
])
return
self
.
builder
.
store
(
expr
,
gep
)
self
.
func_arg_map
[
assignment
.
lhs
.
name
]
=
expr
return
expr
def
_print_Conversion
(
self
,
conversion
):
def
_print_Conversion
(
self
,
conversion
):
node
=
self
.
_print
(
conversion
.
args
[
0
])
to_dtype
=
conversion
.
dtype
to_dtype
=
conversion
.
dtype
from_dtype
=
conversion
.
args
[
0
].
dtype
from_dtype
=
conversion
.
args
[
0
].
dtype
print
(
to_dtype
,
from_dtype
)
# (From, to)
# fp -> int: fptosi
decision
=
{
# int -> fp: sitofp
(
DataType
(
"
int
"
),
DataType
(
"
double
"
)):
functools
.
partial
(
self
.
builder
.
sitofp
,
node
,
self
.
fp_type
),
# ptr -> int: ptrtoint
(
DataType
(
"
double
"
),
DataType
(
"
int
"
)):
functools
.
partial
(
self
.
builder
.
fptosi
,
node
,
self
.
integer
),
# int -> ptr: inttoptr
(
DataType
(
"
double *
"
),
DataType
(
"
int
"
)):
functools
.
partial
(
self
.
builder
.
ptrtoint
,
node
,
self
.
integer
),
# ?bitcast, ?addrspacecast
(
DataType
(
"
int
"
),
DataType
(
"
double *
"
)):
functools
.
partial
(
self
.
builder
.
inttoptr
,
node
,
self
.
fp_pointer
),
(
DataType
(
"
double * __restrict__
"
),
DataType
(
"
int
"
)):
functools
.
partial
(
self
.
builder
.
ptrtoint
,
node
,
self
.
integer
),
(
DataType
(
"
int
"
),
DataType
(
"
double * __restrict__
"
)):
functools
.
partial
(
self
.
builder
.
inttoptr
,
node
,
self
.
fp_pointer
),
(
DataType
(
"
const double * __restrict__
"
),
DataType
(
"
int
"
)):
functools
.
partial
(
self
.
builder
.
ptrtoint
,
node
,
self
.
integer
),
(
DataType
(
"
int
"
),
DataType
(
"
const double * __restrict__
"
)):
functools
.
partial
(
self
.
builder
.
inttoptr
,
node
,
self
.
fp_pointer
),
}
# TODO float, const, restrict
# TODO bitcast, addrspacecast
return
decision
[(
from_dtype
,
to_dtype
)]()
def
_print_Indexed
(
self
,
indexed
):
def
_print_Indexed
(
self
,
indexed
):
pass
ptr
=
self
.
_print
(
indexed
.
base
.
label
)
index
=
self
.
_print
(
indexed
.
args
[
1
])
gep
=
self
.
builder
.
gep
(
ptr
,
[
index
])
return
self
.
builder
.
load
(
gep
,
name
=
indexed
.
base
.
label
.
name
)
# Should have a list of math library functions to validate this.
# TODO function calls
# Should have a list of math library functions to validate this.
# TODO delete this -> NO this should be a function call
def
_print_Function
(
self
,
expr
):
def
_print_Function
(
self
,
expr
):
name
=
expr
.
func
.
__name__
name
=
expr
.
func
.
__name__
e0
=
self
.
_print
(
expr
.
args
[
0
])
e0
=
self
.
_print
(
expr
.
args
[
0
])
...
@@ -163,5 +198,5 @@ class LLVMPrinter(Printer):
...
@@ -163,5 +198,5 @@ class LLVMPrinter(Printer):
return
self
.
builder
.
call
(
fn
,
[
e0
],
name
)
return
self
.
builder
.
call
(
fn
,
[
e0
],
name
)
def
emptyPrinter
(
self
,
expr
):
def
emptyPrinter
(
self
,
expr
):
raise
TypeError
(
"
Unsupported type for LLVM JIT conversion: %s
"
raise
TypeError
(
"
Unsupported type for LLVM JIT conversion:
%s
%s
"
%
type
(
expr
))
%
type
(
expr
)
,
expr
)
This diff is collapsed.
Click to expand it.
llvm/__init__.py
+
1
−
0
View file @
4a8659e8
from
.kernelcreation
import
createKernel
from
.kernelcreation
import
createKernel
from
.jit
import
compileLLVM
\ No newline at end of file
This diff is collapsed.
Click to expand it.
llvm/jit.py
+
6
−
6
View file @
4a8659e8
import
llvmlite.binding
as
llvm
import
llvmlite.binding
as
llvm
import
logging.config
import
logging.config
logger
=
logging
.
getLogger
(
__name__
)
def
compileLLVM
(
module
):
return
Eval
().
compile
(
module
)
class
Eval
(
object
):
class
Eval
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -63,9 +69,3 @@ class Eval(object):
...
@@ -63,9 +69,3 @@ class Eval(object):
# result = fptr(2, 3)
# result = fptr(2, 3)
# print(result)
# print(result)
return
0
return
0
if
__name__
==
"
__main__
"
:
logger
=
logging
.
getLogger
(
__name__
)
else
:
logger
=
logging
.
getLogger
(
__name__
)
This diff is collapsed.
Click to expand it.
types.py
+
3
−
0
View file @
4a8659e8
...
@@ -70,6 +70,9 @@ class DataType(object):
...
@@ -70,6 +70,9 @@ class DataType(object):
if
self
.
dtype
>
other
.
dtype
:
if
self
.
dtype
>
other
.
dtype
:
return
True
return
True
def
__hash__
(
self
):
return
hash
(
repr
(
self
))
def
get_type_from_sympy
(
node
):
def
get_type_from_sympy
(
node
):
# Rational, NumberSymbol?
# Rational, NumberSymbol?
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment