Skip to content
Snippets Groups Projects
Commit 21fb888a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

move integer_functions and fast_approximations into sympyextensions module

parent 6bc218df
Branches
Tags
No related merge requests found
...@@ -53,9 +53,8 @@ class FreezeExpressions: ...@@ -53,9 +53,8 @@ class FreezeExpressions:
- Augmented Assignments - Augmented Assignments
- AddressOf - AddressOf
- Conditionals (+ frontend class)
- Relations (sp.Relational) - Relations (sp.Relational)
- pystencils.integer_functions - pystencils.sympyextensions.integer_functions
- pystencils.sympyextensions.bit_masks - pystencils.sympyextensions.bit_masks
- GPU fast approximations (pystencils.fast_approximation) - GPU fast approximations (pystencils.fast_approximation)
- ConditionalFieldAccess - ConditionalFieldAccess
......
...@@ -4,7 +4,7 @@ from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE ...@@ -4,7 +4,7 @@ from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
from pystencils.sympyextensions import TypedSymbol from pystencils.sympyextensions import TypedSymbol
from pystencils.types import create_type from pystencils.types import create_type
from pystencils.field import Field from pystencils.field import Field
from pystencils.integer_functions import bitwise_and from pystencils.sympyextensions.integer_functions import bitwise_and
def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False): def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False):
......
...@@ -4,7 +4,7 @@ import uuid ...@@ -4,7 +4,7 @@ import uuid
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment from sympy.codegen.ast import Assignment, AugmentedAssignment
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
import numpy as np import numpy as np
......
# TODO #47 move to a module functions
import numpy as np
import sympy as sp import sympy as sp
from pystencils.sympyextensions import CastFunc
from pystencils.types import create_type
from pystencils.sympyextensions import is_integer_sequence from pystencils.sympyextensions import is_integer_sequence
...@@ -11,22 +6,7 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): ...@@ -11,22 +6,7 @@ class IntegerFunctionTwoArgsMixIn(sp.Function):
is_integer = True is_integer = True
def __new__(cls, arg1, arg2): def __new__(cls, arg1, arg2):
args = [] args = [arg1, arg2]
for a in (arg1, arg2):
if isinstance(a, sp.Number) or isinstance(a, int):
args.append(CastFunc(a, create_type("int")))
elif isinstance(a, np.generic):
args.append(CastFunc(a, a.dtype))
else:
args.append(a)
for a in args:
try:
dtype = get_type_of_expression(a)
if not dtype.is_int():
raise ValueError("Argument to integer function is not an int but " + str(dtype))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args) return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs): def _eval_evalf(self, *pargs, **kwargs):
...@@ -100,11 +80,12 @@ class modulo_floor(sp.Function): ...@@ -100,11 +80,12 @@ class modulo_floor(sp.Function):
else: else:
return super().__new__(cls, integer, divisor) return super().__new__(cls, integer, divisor)
def to_c(self, print_func): # TODO: Implement this in FreezeExpressions
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) # def to_c(self, print_func):
assert dtype.is_int() # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), # assert dtype.is_int()
print_func(self.args[1]), dtype=dtype) # return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
# print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -132,11 +113,12 @@ class modulo_ceil(sp.Function): ...@@ -132,11 +113,12 @@ class modulo_ceil(sp.Function):
else: else:
return super().__new__(cls, integer, divisor) return super().__new__(cls, integer, divisor)
def to_c(self, print_func): # TODO: Implement this in FreezeExpressions
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) # def to_c(self, print_func):
assert dtype.is_int() # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" # assert dtype.is_int()
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) # code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -162,11 +144,12 @@ class div_ceil(sp.Function): ...@@ -162,11 +144,12 @@ class div_ceil(sp.Function):
else: else:
return super().__new__(cls, integer, divisor) return super().__new__(cls, integer, divisor)
def to_c(self, print_func): # TODO: Implement this in FreezeExpressions
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) # def to_c(self, print_func):
assert dtype.is_int() # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )" # assert dtype.is_int()
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) # code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -192,8 +175,9 @@ class div_floor(sp.Function): ...@@ -192,8 +175,9 @@ class div_floor(sp.Function):
else: else:
return super().__new__(cls, integer, divisor) return super().__new__(cls, integer, divisor)
def to_c(self, print_func): # TODO: Implement this in FreezeExpressions
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) # def to_c(self, print_func):
assert dtype.is_int() # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
code = "(({dtype})({0}) / ({dtype})({1}))" # assert dtype.is_int()
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) # code = "(({dtype})({0}) / ({dtype})({1}))"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
...@@ -549,7 +549,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], ...@@ -549,7 +549,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
Returns: Returns:
dict with 'adds', 'muls' and 'divs' keys dict with 'adds', 'muls' and 'divs' keys
""" """
from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division from pystencils.sympyextensions.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0, result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0} 'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
......
import time import time
from pystencils.integer_functions import modulo_ceil from pystencils.sympyextensions.integer_functions import modulo_ceil
class TimeLoop: class TimeLoop:
......
...@@ -2,7 +2,7 @@ import pytest ...@@ -2,7 +2,7 @@ import pytest
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils.fast_approximation import ( from pystencils.sympyextensions.fast_approximation import (
fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts)
......
...@@ -2,7 +2,7 @@ import pytest ...@@ -2,7 +2,7 @@ import pytest
import sympy as sp import sympy as sp
import numpy as np import numpy as np
import pystencils as ps import pystencils as ps
from pystencils.fast_approximation import fast_division from pystencils.sympyextensions.fast_approximation import fast_division
@pytest.mark.parametrize('dtype', ["float64", "float32"]) @pytest.mark.parametrize('dtype', ["float64", "float32"])
......
...@@ -15,7 +15,7 @@ from pystencils.sympyextensions import scalar_product ...@@ -15,7 +15,7 @@ from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta from pystencils.sympyextensions import kronecker_delta
from pystencils import Assignment from pystencils import Assignment
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts) insert_fast_divisions, insert_fast_sqrts)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment