Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
Show changes
Showing
with 793 additions and 113 deletions
......@@ -8,13 +8,13 @@ import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
from pystencils.jit.legacy_cpu import get_compiler_config
from pystencils import CreateKernelConfig, Target, Field
import json
import sympy as sp
from pystencils.typing import BasicType
from pystencils.types import PsType
class PystencilsJsonEncoder(JsonEncoder):
......@@ -26,9 +26,9 @@ class PystencilsJsonEncoder(JsonEncoder):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
if isinstance(obj, (PsType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
if isinstance(obj, (Target, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
......
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
add_subexpressions_for_divisions,
add_subexpressions_for_field_reads,
add_subexpressions_for_sums,
apply_on_all_subexpressions,
apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
subexpression_substitution_in_main_assignments,
sympy_cse,
sympy_cse_on_assignment_list,
)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
insert_aliases,
insert_zeros,
insert_constants,
insert_constant_additions,
insert_constant_multiples,
insert_squares,
insert_symbol_times_minus_one,
)
from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
__all__ = [
"AssignmentCollection",
"SimplificationStrategy",
"sympy_cse",
"sympy_cse_on_assignment_list",
"apply_to_all_assignments",
"apply_on_all_subexpressions",
"subexpression_substitution_in_existing_subexpressions",
"subexpression_substitution_in_main_assignments",
"add_subexpressions_for_constants",
"add_subexpressions_for_divisions",
"add_subexpressions_for_sums",
"add_subexpressions_for_field_reads",
"insert_aliases",
"insert_zeros",
"insert_constants",
"insert_constant_additions",
"insert_constant_multiples",
"insert_squares",
"insert_symbol_times_minus_one",
]
......@@ -5,9 +5,9 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set,
import sympy as sp
import pystencils
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs
from ..assignment import Assignment
from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from ..sympyextensions import count_operations, fast_subs
class AssignmentCollection:
......@@ -31,6 +31,8 @@ class AssignmentCollection:
"""
__match_args__ = ("main_assignments", "subexpressions")
# ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
......@@ -116,8 +118,8 @@ class AssignmentCollection:
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
rhs_symbols.update(eq.undefined_symbols)
# elif isinstance(eq, pystencils.astnodes.Node): # TODO remove or replace
# rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
......@@ -136,10 +138,10 @@ class AssignmentCollection:
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
])
# bound_symbols_set = bound_symbols_set.union(*[
# assignment.symbols_defined for assignment in self.all_assignments
# if isinstance(assignment, pystencils.astnodes.Node)
# ]) TODO: replace?
return bound_symbols_set
......@@ -162,8 +164,9 @@ class AssignmentCollection:
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
if isinstance(assignment, pystencils.astnodes.Node)]))
return lhs_set
# return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
# if isinstance(assignment, pystencils.astnodes.Node)])) TODO
@property
def operation_count(self):
......@@ -286,12 +289,13 @@ class AssignmentCollection:
processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs
else:
......
......@@ -4,21 +4,21 @@ from collections import defaultdict
import sympy as sp
from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
from ..assignment import Assignment
from ..sympyextensions import subs_additive, is_constant, recursive_collect
from ..sympyextensions.typed_sympy import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
# TODO rewrite with SymPy AST
# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
def sort_assignments_topologically(assignments: Sequence[Union[Assignment]]) -> List[Union[Assignment]]:
"""Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs"""
edges = []
for c1, e1 in enumerate(assignments):
if hasattr(e1, 'lhs') and hasattr(e1, 'rhs'):
symbols = [e1.lhs]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
# elif isinstance(e1, Node):
# symbols = e1.symbols_defined
else:
raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
......@@ -26,8 +26,8 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
for c2, e2 in enumerate(assignments):
if isinstance(e2, Assignment) and lhs in e2.rhs.free_symbols:
edges.append((c1, c2))
elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
edges.append((c1, c2))
# elif isinstance(e2, Node) and lhs in e2.undefined_symbols:
# edges.append((c1, c2))
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
......@@ -55,7 +55,7 @@ def sympy_cse(ac, **kwargs):
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp import AssignmentCollection
ec = AssignmentCollection([], assignments)
return sympy_cse(ec).all_assignments
......@@ -163,6 +163,7 @@ def add_subexpressions_for_sums(ac):
for eq in ac.all_assignments:
search_addends(eq.rhs)
from pystencils.field import Field
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
......@@ -185,6 +186,7 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
if main_assignments:
to_iterate = chain(to_iterate, ac.main_assignments)
from pystencils.field import Field
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access))
......@@ -236,31 +238,3 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__
return f
# TODO Markus
# make this really work for Assignmentcollections
# this function should ONLY evaluate
# do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
......@@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence
import sympy as sp
from pystencils.simp.assignment_collection import AssignmentCollection
from ..simp import AssignmentCollection
class SimplificationStrategy:
......@@ -57,7 +57,7 @@ class SimplificationStrategy:
def __str__(self):
try:
import tabulate
from tabulate import tabulate
return tabulate(self.elements, headers=['Name', 'Runtime', 'Adds', 'Muls', 'Divs', 'Total'])
except ImportError:
result = "Name, Adds, Muls, Divs, Runtime\n"
......
import sympy as sp
from pystencils.sympyextensions import is_constant
from ..sympyextensions import is_constant
# Subexpression Insertion
......
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
from pystencils.simp import (
SimplificationStrategy,
insert_constants,
insert_symbol_times_minus_one,
insert_constant_multiples,
insert_constant_additions,
insert_squares,
insert_zeros,
)
def create_simplification_strategy():
......
......@@ -26,7 +26,7 @@ class SlicedGetterDataHandling:
def __getitem__(self, slice_obj):
if slice_obj is None:
slice_obj = make_slice[:, :] if self.data_handling.dim == 2 else make_slice[:, :, 0.5]
slice_obj = make_slice[:, :] if self.dh.dim == 2 else make_slice[:, :, 0.5]
return self.dh.gather_array(self.name, slice_obj).squeeze()
......
import sympy
from .defaults import DEFAULTS
import pystencils
import pystencils.astnodes
x_, y_, z_ = tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3))
x_, y_, z_ = DEFAULTS.spatial_counters
x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5
def x_vector(ndim):
return sympy.Matrix(tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim)))
return sympy.Matrix(DEFAULTS.spatial_counters[:ndim])
def x_staggered_vector(ndim):
return sympy.Matrix(tuple(
pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim)
))
return sympy.Matrix(tuple(DEFAULTS.spatial_counters[i] + 0.5 for i in range(ndim)))
"""pystencils extensions to the SymPy symbolic language."""
from .sympyextensions.integer_functions import (
bitwise_and,
bitwise_or,
bitwise_xor,
bit_shift_left,
bit_shift_right,
int_div,
int_rem,
int_power_of_2,
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
__all__ = [
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bit_shift_left",
"bit_shift_right",
"int_div",
"int_rem",
"int_power_of_2",
"round_to_multiple_towards_zero",
"ceil_to_multiple",
"div_ceil",
]
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc, tcast, DynamicType
from .pointers import mem_acc
from .math import (
prod,
remove_small_floats,
is_integer_sequence,
scalar_product,
kronecker_delta,
tanh_step_function_approximation,
multidimensional_sum,
normalize_product,
symmetric_product,
fast_subs,
is_constant,
subs_additive,
replace_second_order_products,
remove_higher_order_terms,
complete_the_square,
complete_the_squares_in_exp,
extract_most_common_factor,
recursive_collect,
summands,
simplify_by_equality,
count_operations,
count_operations_in_ast,
common_denominator,
get_symmetric_part,
SymbolCreator
)
__all__ = [
"ConditionalFieldAccess",
"TypedSymbol",
"CastFunc",
"tcast",
"mem_acc",
"remove_higher_order_terms",
"prod",
"remove_small_floats",
"is_integer_sequence",
"scalar_product",
"kronecker_delta",
"tanh_step_function_approximation",
"multidimensional_sum",
"normalize_product",
"symmetric_product",
"fast_subs",
"is_constant",
"subs_additive",
"replace_second_order_products",
"remove_higher_order_terms",
"complete_the_square",
"complete_the_squares_in_exp",
"extract_most_common_factor",
"recursive_collect",
"summands",
"simplify_by_equality",
"count_operations",
"count_operations_in_ast",
"common_denominator",
"get_symmetric_part",
"SymbolCreator",
"DynamicType"
]
import sympy as sp
class ConditionalFieldAccess(sp.Function):
"""
:class:`pystencils.Field.Access` that is only executed if a certain condition is met.
Can be used, for instance, for out-of-bound checks.
"""
def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
@property
def access(self):
return self.args[0]
@property
def outofbounds_condition(self):
return self.args[1]
@property
def outofbounds_value(self):
return self.args[2]
def __getnewargs__(self):
return self.access, self.outofbounds_condition, self.outofbounds_value
def __getnewargs_ex__(self):
return (self.access, self.outofbounds_condition, self.outofbounds_value), {}
def generic_visit(term, visitor):
from pystencils import AssignmentCollection, Assignment
if isinstance(term, AssignmentCollection):
new_main_assignments = generic_visit(term.main_assignments, visitor)
new_subexpressions = generic_visit(term.subexpressions, visitor)
return term.copy(new_main_assignments, new_subexpressions)
elif isinstance(term, list):
return [generic_visit(e, visitor) for e in term]
elif isinstance(term, Assignment):
return Assignment(term.lhs, generic_visit(term.rhs, visitor))
elif isinstance(term, sp.Matrix):
return term.applyfunc(lambda e: generic_visit(e, visitor))
else:
return visitor(term)
import sympy as sp
# noinspection PyPep8Naming
class flag_cond(sp.Function):
"""Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
depending on whether the flag is set.
Three argument version:
```
flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
```
Four argument version:
```
flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
```
"""
nargs = (3, 4)
def __new__(cls, flag_bit, mask_expression, *expressions):
# TODO Jan reintroduce checking
# flag_dtype = get_type_of_expression(flag_bit)
# if not flag_dtype.is_int():
# raise ValueError('Argument flag_bit must be of integer type.')
#
# mask_dtype = get_type_of_expression(mask_expression)
# if not mask_dtype.is_int():
# raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions)
def to_c(self, print_func):
flag_bit = self.args[0]
mask = self.args[1]
then_expression = self.args[2]
flag_bit_code = print_func(flag_bit)
mask_code = print_func(mask)
then_code = print_func(then_expression)
code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})"
if len(self.args) > 3:
else_expression = self.args[3]
else_code = print_func(else_expression)
code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})"
return code
......@@ -2,9 +2,8 @@ from typing import List, Union
import sympy as sp
from pystencils.astnodes import Node
from pystencils.simp import AssignmentCollection
from pystencils.assignment import Assignment
from ..assignment import Assignment
from ..simp import AssignmentCollection
# noinspection PyPep8Naming
......@@ -44,8 +43,6 @@ def _run(term, visitor):
def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection, Assignment]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
power = expr.exp.p
if power < 0:
......@@ -61,8 +58,6 @@ def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection,
def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection, Assignment]):
def visit(expr):
if isinstance(expr, Node):
return expr
if expr.func == sp.Mul:
div_args = []
other_args = []
......
import sympy as sp
import warnings
from pystencils.sympyextensions import is_integer_sequence
class IntegerFunctionTwoArgsMixIn(sp.Function):
is_integer = True
def __new__(cls, arg1, arg2):
args = [arg1, arg2]
return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs):
arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
return self._eval_op(arg1, arg2)
def _eval_op(self, arg1, arg2):
return self
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_right(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_left(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_and(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_or(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
"""C-style round-to-zero integer division"""
def _eval_op(self, arg1, arg2):
from ..utils import c_intdiv
return c_intdiv(arg1, arg2)
class int_rem(IntegerFunctionTwoArgsMixIn):
"""C-style round-to-zero integer remainder"""
def _eval_op(self, arg1, arg2):
from ..utils import c_rem
return c_rem(arg1, arg2)
# noinspection PyPep8Naming
# TODO: What do the *two* arguments mean?
# Apparently, the second is required but ignored?
class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class round_to_multiple_towards_zero(IntegerFunctionTwoArgsMixIn):
"""Returns the next smaller/equal in magnitude integer divisible by given
divisor.
Examples:
>>> round_to_multiple_towards_zero(9, 4)
8
>>> round_to_multiple_towards_zero(11, -4)
8
>>> round_to_multiple_towards_zero(12, 4)
12
>>> round_to_multiple_towards_zero(-9, 4)
-8
>>> round_to_multiple_towards_zero(-9, -4)
-8
"""
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1, arg2) * arg2
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# noinspection PyPep8Naming
class ceil_to_multiple(IntegerFunctionTwoArgsMixIn):
"""For positive input, returns the next greater/equal integer divisible
by given divisor. The return value is unspecified if either argument is
negative.
Examples:
>>> ceil_to_multiple(9, 4)
12
>>> ceil_to_multiple(11, 4)
12
>>> ceil_to_multiple(12, 4)
12
"""
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1 + arg2 - 1, arg2) * arg2
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# noinspection PyPep8Naming
class div_ceil(IntegerFunctionTwoArgsMixIn):
"""For positive input, integer division that is always rounded up, i.e.
``div_ceil(a, b) = ceil(div(a, b))``. The return value is unspecified if
either argument is negative.
Examples:
>>> div_ceil(9, 4)
3
>>> div_ceil(8, 4)
2
"""
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1 + arg2 - 1, arg2)
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# Deprecated functions.
# noinspection PyPep8Naming
class modulo_floor:
def __new__(cls, integer, divisor):
warnings.warn(
"`modulo_floor` is deprecated. Use `round_to_multiple_towards_zero` instead.",
DeprecationWarning,
)
return round_to_multiple_towards_zero(integer, divisor)
# noinspection PyPep8Naming
class modulo_ceil(sp.Function):
def __new__(cls, integer, divisor):
warnings.warn(
"`modulo_ceil` is deprecated. Use `ceil_to_multiple` instead.",
DeprecationWarning,
)
return ceil_to_multiple(integer, divisor)
# noinspection PyPep8Naming
class div_floor(sp.Function):
def __new__(cls, integer, divisor):
warnings.warn(
"`div_floor` is deprecated. Use `int_div` instead.",
DeprecationWarning,
)
return int_div(integer, divisor)
......@@ -10,10 +10,9 @@ from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
from ..assignment import Assignment
from .typed_sympy import TypeCast
from ..types import PsPointerType, PsVectorType
T = TypeVar('T')
......@@ -169,7 +168,7 @@ def fast_subs(expression: T, substitutions: Dict,
return substitutions[expr]
elif not hasattr(expr, 'args'):
return expr
elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
elif isinstance(expr, sp.UnevaluatedExpr):
args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
......@@ -550,7 +549,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
Returns:
dict with 'adds', 'muls' and 'divs' keys
"""
from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
from pystencils.sympyextensions.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
......@@ -566,16 +565,15 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
def check_type(e):
if only_type is None:
return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try:
base_type = get_type_of_expression(e)
# base_type = get_type_of_expression(e)
base_type = None # TODO nbackend: Fix count_operations without relying on data types
except ValueError:
return False
if isinstance(base_type, VectorType):
if isinstance(base_type, PsVectorType):
return False
if isinstance(base_type, PointerType):
if isinstance(base_type, PsPointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True
......@@ -605,7 +603,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
visit_children = False
elif t.is_integer:
pass
elif isinstance(t, CastFunc):
elif isinstance(t, TypeCast):
visit_children = False
visit(t.args[0])
elif t.func is fast_sqrt:
......@@ -641,8 +639,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
visit_children = False
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else:
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
......@@ -656,7 +652,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment
from pystencils.sympyextensions.astnodes import SympyAssignment
result = defaultdict(int)
def visit(node):
......@@ -678,15 +674,14 @@ def common_denominator(expr: sp.Expr) -> sp.Expr:
def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
"""
Returns the symmetric part of a sympy expressions.
"""Returns the symmetric part of a sympy expressions.
This function returns the symmetric part of the given expression w.r.t. the
given degrees of freedom, computed as :math:`\\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]`.
Args:
expr: sympy expression, labeled here as :math:`f`
symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`
Returns:
:math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1) ]`
"""
substitution_dict = {e: -e for e in symbols}
return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))
......
import sympy as sp
from pystencils.typing import PointerType
class DivFunc(sp.Function):
"""
DivFunc represents a division operation, since sympy represents divisions with ^-1
"""
is_Atom = True
is_real = True
def __new__(cls, *args, **kwargs):
if len(args) != 2:
raise ValueError(f'{cls} takes only 2 arguments, instead {len(args)} received!')
divisor, dividend, *other_args = args
return sp.Function.__new__(cls, divisor, dividend, *other_args, **kwargs)
def _eval_evalf(self, *args, **kwargs):
return self.divisor.evalf() / self.dividend.evalf()
@property
def divisor(self):
return self.args[0]
@property
def dividend(self):
return self.args[1]
from ..types import PsPointerType, PsType
class AddressOf(sp.Function):
......@@ -51,7 +25,20 @@ class AddressOf(sp.Function):
@property
def dtype(self):
if hasattr(self.args[0], 'dtype'):
return PointerType(self.args[0].dtype, restrict=True)
arg_type = getattr(self.args[0], 'dtype', None)
if arg_type is not None:
assert isinstance(arg_type, PsType)
return PsPointerType(arg_type, restrict=True, const=True)
else:
raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}')
class mem_acc(sp.Function):
"""Memory access through a raw pointer with an offset.
This function should be used to model offset memory accesses through raw pointers.
"""
@classmethod
def eval(cls, ptr, offset):
return None
from __future__ import annotations
from typing import cast
import sympy as sp
from enum import Enum, auto
from ..types import (
PsType,
PsNumericType,
create_type,
UserTypeSpec
)
from sympy.logic.boolalg import Boolean
from warnings import warn
def is_loop_counter_symbol(symbol):
from ..defaults import DEFAULTS
try:
return DEFAULTS.spatial_counters.index(symbol)
except ValueError:
return None
class DynamicType(Enum):
"""Dynamic data type that will be resolved during kernel creation"""
NUMERIC_TYPE = auto()
"""Use the default numeric type set for the kernel"""
INDEX_TYPE = auto()
"""Use the default index type set for the kernel.
This is guaranteed to be an interger type.
"""
class TypeAtom(sp.Atom):
"""Wrapper around a type to disguise it as a SymPy atom."""
_dtype: PsType | DynamicType
def __new__(cls, dtype: PsType | DynamicType):
obj = super().__new__(cls)
obj._dtype = dtype
return obj
def _sympystr(self, *args, **kwargs):
return str(self._dtype)
def get(self) -> PsType | DynamicType:
return self._dtype
def _hashable_content(self):
return (self._dtype,)
def __getnewargs__(self):
return (self._dtype,)
def assumptions_from_dtype(dtype: PsType | DynamicType):
"""Derives SymPy assumptions from :class:`PsAbstractType`
Args:
dtype (PsAbstractType): a pystencils data type
Returns:
A dict of SymPy assumptions
"""
assumptions = dict()
match dtype:
case DynamicType.INDEX_TYPE:
assumptions.update({"integer": True, "real": True})
case DynamicType.NUMERIC_TYPE:
assumptions.update({"real": True})
case PsNumericType():
if dtype.is_int():
assumptions.update({"integer": True})
if dtype.is_uint():
assumptions.update({"negative": False})
if dtype.is_int() or dtype.is_float():
assumptions.update({"real": True})
return assumptions
class TypedSymbol(sp.Symbol):
_dtype: PsType | DynamicType
def __new__(cls, *args, **kwds):
obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
return obj
def __new_stage2__(
cls, name: str, dtype: UserTypeSpec | DynamicType, **kwargs
): # TODO does not match signature of sp.Symbol???
# TODO: also Symbol should be allowed ---> see sympy Variable
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
assumptions = assumptions_from_dtype(dtype)
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
obj._dtype = dtype
return obj
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(sp.core.cacheit(__new_stage2__))
@property
def dtype(self) -> PsType | DynamicType:
# mypy: ignore
return self._dtype
def _hashable_content(self):
# mypy: ignore
return super()._hashable_content(), hash(self._dtype)
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), self.assumptions0
@property
def canonical(self):
return self
@property
def reversed(self):
return self
@property
def headers(self) -> set[str]:
return self.dtype.required_headers if isinstance(self.dtype, PsType) else set()
class TypeCast(sp.Function):
"""Explicitly cast an expression to a data type."""
@staticmethod
def as_numeric(expr):
return TypeCast(expr, DynamicType.NUMERIC_TYPE)
@staticmethod
def as_index(expr):
return TypeCast(expr, DynamicType.INDEX_TYPE)
@property
def expr(self) -> sp.Basic:
return self.args[0]
@property
def dtype(self) -> PsType | DynamicType:
return cast(TypeAtom, self._args[1]).get()
def __new__(cls, expr: sp.Basic, dtype: UserTypeSpec | DynamicType | TypeAtom):
tatom: TypeAtom
match dtype:
case TypeAtom():
tatom = dtype
case DynamicType():
tatom = TypeAtom(dtype)
case _:
tatom = TypeAtom(create_type(dtype))
return super().__new__(cls, expr, tatom)
@classmethod
def eval(cls, expr: sp.Basic, tatom: TypeAtom) -> TypeCast | None:
dtype = tatom.get()
if cls is not BoolCast and isinstance(dtype, PsNumericType) and dtype.is_bool():
return BoolCast(expr, tatom)
return None
def _eval_is_integer(self):
if self.dtype == DynamicType.INDEX_TYPE:
return True
if isinstance(self.dtype, PsNumericType) and self.dtype.is_int():
return True
def _eval_is_real(self):
if isinstance(self.dtype, DynamicType):
return True
if isinstance(self.dtype, PsNumericType) and (self.dtype.is_float() or self.dtype.is_int()):
return True
def _eval_is_nonnegative(self):
if isinstance(self.dtype, PsNumericType) and self.dtype.is_uint():
return True
class BoolCast(TypeCast, Boolean):
pass
tcast = TypeCast
class CastFunc(TypeCast):
def __new__(cls, *args, **kwargs):
warn(
"CastFunc is deprecated and will be removed in pystencils 2.1. "
"Use `pystencils.tcast` instead.",
FutureWarning
)
return TypeCast.__new__(cls, *args, **kwargs)
import time
from pystencils.integer_functions import modulo_ceil
from pystencils.sympyextensions.integer_functions import modulo_ceil
class TimeLoop:
......
"""
The `pystencils.types` module contains the set of classes used by pystencils
to model data types. Data types are used extensively within the code generator,
but can safely be ignored by most users unless you wish to force certain types on
symbols, generate mixed-precision kernels, et cetera.
"""
from .meta import PsType, constify, deconstify
from .types import (
PsCustomType,
PsStructType,
PsNumericType,
PsScalarType,
PsVectorType,
PsDereferencableType,
PsPointerType,
PsArrayType,
PsBoolType,
PsIntegerType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
)
from .parsing import UserTypeSpec, create_type, create_numeric_type
from .exception import PsTypeError
__all__ = [
"PsType",
"PsCustomType",
"PsStructType",
"PsDereferencableType",
"PsPointerType",
"PsArrayType",
"PsNumericType",
"PsScalarType",
"PsVectorType",
"PsIntegerType",
"PsBoolType",
"PsUnsignedIntegerType",
"PsSignedIntegerType",
"PsIeeeFloatType",
"constify",
"deconstify",
"UserTypeSpec",
"create_type",
"create_numeric_type",
"PsTypeError",
]