Commit 0da250c0 authored by Stephan Seitz


Handle the case where pystencils.astnodes.Conditional is in the assignments or sympy.codegen.rewriting does not exist
parent 605bf8bd
@@ -509,6 +509,13 @@ class SympyAssignment(Node):
         self.lhs = fast_subs(self.lhs, subs_dict)
         self.rhs = fast_subs(self.rhs, subs_dict)
 
+    def optimize(self, optimizations):
+        try:
+            from sympy.codegen.rewriting import optimize
+            self.rhs = optimize(self.rhs, optimizations)
+        except Exception:
+            pass
+
     @property
     def args(self):
         return [self._lhs_symbol, self.rhs]
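For context: sympy.codegen.rewriting.optimize applies a list of rewrite rules to an expression, which is what the new SympyAssignment.optimize method does to its right-hand side (the broad except means the step is silently skipped on sympy versions without that module). A minimal sketch of the effect, using the stock C99 rule set that pystencils reuses below:

    import sympy as sp
    from sympy.codegen.rewriting import optimize, optims_c99

    x = sp.Symbol('x')
    print(optimize(sp.exp(x) - 1, optims_c99))  # expm1(x), numerically safer near zero
    print(optimize(2**x, optims_c99))           # exp2(x)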
......
@@ -8,7 +8,7 @@ from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, Sympy
 from pystencils.cpu.cpujit import make_python_function
 from pystencils.data_types import BasicType, StructType, TypedSymbol, create_type
 from pystencils.field import Field, FieldType
-from pystencils.optimizations import optims_pystencils_cpu
+from pystencils.optimizations import optims_pystencils_cpu, optimize_assignments
 from pystencils.transformations import (
     add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
     make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info,
@@ -57,7 +57,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
     if optimizations is None:
         optimizations = optims_pystencils_cpu
-    assignments = [Assignment(a.lhs, sp.codegen.rewriting.optimize(a.rhs, optimizations)) for a in assignments]
+    assignments = optimize_assignments(assignments, optimizations)
 
     fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
@@ -119,7 +119,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList,
     if optimizations is None:
         optimizations = optims_pystencils_cpu
-    assignments = [Assignment(a.lhs, sp.codegen.rewriting.optimize(a.rhs, optimizations)) for a in assignments]
+    assignments = optimize_assignments(assignments, optimizations)
 
     fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
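Both CPU entry points now route the assignment list through optimize_assignments instead of the old per-assignment comprehension, which assumed every entry had an lhs and rhs. A hedged usage sketch, assuming the create_kernel signature shown above; the final compile() call is the usual pystencils workflow rather than part of this diff:

    import sympy as sp
    import pystencils
    from pystencils import Assignment
    from pystencils.cpu.kernelcreation import create_kernel
    from pystencils.optimizations import optims_pystencils_cpu

    x, y = pystencils.fields('x, y: float64[2d]')
    assignments = [Assignment(y.center, sp.exp(x.center) - 1)]

    # optimizations=None falls back to optims_pystencils_cpu (see the hunk above)
    ast = create_kernel(assignments, optimizations=optims_pystencils_cpu)
    kernel = ast.compile()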
......
-import sympy as sp
-from pystencils import Assignment
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils.data_types import BasicType, StructType, TypedSymbol
 from pystencils.field import Field, FieldType
 from pystencils.gpucuda.cudajit import make_python_function
 from pystencils.gpucuda.indexing import BlockIndexing
-from pystencils.optimizations import optims_pystencils_gpu
+from pystencils.optimizations import optims_pystencils_gpu, optimize_assignments
 from pystencils.transformations import (
     add_types, get_base_buffer_index, get_common_shape, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
@@ -17,7 +14,7 @@ def create_cuda_kernel(assignments, function_name="kernel", type_info=None, inde
     if optimizations is None:
         optimizations = optims_pystencils_gpu
-    assignments = [Assignment(a.lhs, sp.codegen.rewriting.optimize(a.rhs, optimizations)) for a in assignments]
+    assignments = optimize_assignments(assignments, optimizations)
 
     fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
@@ -99,7 +96,7 @@ def created_indexed_cuda_kernel(assignments, index_fields, function_name="kernel
                                 coordinate_names=('x', 'y', 'z'), indexing_creator=BlockIndexing, optimizations=None):
     if optimizations is None:
         optimizations = optims_pystencils_gpu
-    assignments = [Assignment(a.lhs, sp.codegen.rewriting.optimize(a.rhs, optimizations)) for a in assignments]
+    assignments = optimize_assignments(assignments, optimizations)
 
     fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
     all_fields = fields_read.union(fields_written)
......
@@ -5,6 +5,11 @@ See :func:`sympy.codegen.rewriting.optimize`.
 """
+import itertools
+
+from pystencils import Assignment
+from pystencils.astnodes import SympyAssignment
+
 try:
     from sympy.codegen.rewriting import optims_c99, optimize
     from sympy.codegen.rewriting import ReplaceOptim
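ReplaceOptim, imported here, is sympy's building block for such rewrite rules: it pairs a predicate that selects sub-expressions with a function that produces their replacement, and it is how evaluate_constant_terms below is defined. A small, hedged illustration of that API (plain sympy usage, not part of this commit):

    import sympy as sp
    from sympy.codegen.cfunctions import exp2
    from sympy.codegen.rewriting import ReplaceOptim, optimize

    x = sp.Symbol('x')
    # rewrite powers of two, 2**u, to the C99 function exp2(u)
    exp2_rule = ReplaceOptim(lambda p: p.is_Pow and p.base == 2,
                             lambda p: exp2(p.exp))
    print(optimize(3 + 2**x, [exp2_rule]))  # exp2(x) + 3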
@@ -20,14 +25,18 @@ try:
     optims_pystencils_gpu = [evaluate_constant_terms] + list(optims_c99)
 except ImportError:
     optims_c99 = []
     optims_pystencils_cpu = []
     optims_pystencils_gpu = []
+    HAS_REWRITING = False
 
-# Evaluates all constant terms
-evaluate_constant_terms = ReplaceOptim(
-    lambda e: e.is_constant,
-    lambda p: p.evalf()
-)
+def optimize_assignments(assignments, optimizations):
+    if HAS_REWRITING:
+        assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations))
+                       if hasattr(a, 'lhs') else a for a in assignments]
+        assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
+        for a in itertools.chain.from_iterable(assignments_nodes):
+            a.optimize(optimizations)
 
-optims_pystencils_cpu = [evaluate_constant_terms] + list(optims_c99)
-optims_pystencils_gpu = [evaluate_constant_terms] + list(optims_c99)
+    return assignments
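optimize_assignments is the piece that realizes the commit message: entries with an lhs (plain assignments) have their right-hand side rewritten, while AST nodes such as pystencils.astnodes.Conditional are kept as-is and any SympyAssignment found inside them is optimized in place through the new SympyAssignment.optimize method; if sympy.codegen.rewriting is unavailable, HAS_REWRITING is False and the list passes through unchanged. A hedged usage sketch based on the function body above (the Conditional/Block construction is assumed from the pystencils AST API and is not shown in this diff):

    import sympy as sp
    import pystencils
    from pystencils import Assignment
    from pystencils.astnodes import Block, Conditional, SympyAssignment
    from pystencils.optimizations import optims_pystencils_cpu, optimize_assignments

    x, y = pystencils.fields('x, y: float32[2d]')
    assignments = [
        Assignment(y.center, sp.exp(x.center) - 1),   # has an lhs: rhs is rewritten
        Conditional(sp.Gt(x.center, 0),               # no lhs: kept and searched
                    Block([SympyAssignment(y.center, sp.exp(x.center) - 1)])),
    ]
    assignments = optimize_assignments(assignments, optims_pystencils_cpu)
    # exp(u) - 1 becomes expm1(u) in the plain assignment and inside the Conditional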
+import pytest
 import sympy as sp
 import pystencils
+from pystencils.optimizations import HAS_REWRITING
 
 
+@pytest.mark.skipif(not HAS_REWRITING, reason="need sympy.codegen.rewriting")
 def test_sympy_optimizations():
     for target in ('cpu', 'gpu'):
         x, y, z = pystencils.fields('x, y, z: float32[2d]')
@@ -17,6 +20,7 @@ def test_sympy_optimizations():
         assert 'expm1(' in code
 
 
+@pytest.mark.skipif(not HAS_REWRITING, reason="need sympy.codegen.rewriting")
 def test_evaluate_constant_terms():
     for target in ('cpu', 'gpu'):
         x, y, z = pystencils.fields('x, y, z: float32[2d]')
@@ -32,6 +36,7 @@ def test_evaluate_constant_terms():
         print(code)
 
 
+@pytest.mark.skipif(not HAS_REWRITING, reason="need sympy.codegen.rewriting")
 def test_do_not_evaluate_constant_terms():
     optimizations = pystencils.optimizations.optims_pystencils_cpu
     optimizations.remove(pystencils.optimizations.evaluate_constant_terms)
......