Skip to content
Snippets Groups Projects
Commit 92231944 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'bugfix-fields_accessed-for-InterpolatorAccess' into 'master'

Bugfix fields accessed for interpolator access

See merge request pycodegen/pystencils!62
parents 89f158d0 32d636c4
Branches
Tags
1 merge request!62Bugfix fields accessed for interpolator access
Pipeline #18570 passed with warnings
...@@ -211,17 +211,18 @@ class KernelFunction(Node): ...@@ -211,17 +211,18 @@ class KernelFunction(Node):
return self._body, return self._body,
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def fields_accessed(self) -> Set[Field]:
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess)) from pystencils.interpolation_astnodes import InterpolatorAccess
return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess)))
@property @property
def fields_written(self) -> Set['ResolvedFieldAccess']: def fields_written(self) -> Set[Field]:
assignments = self.atoms(SympyAssignment) assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)} return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property @property
def fields_read(self) -> Set['ResolvedFieldAccess']: def fields_read(self) -> Set[Field]:
assignments = self.atoms(SympyAssignment) assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')] return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments)) for a in assignments))
......
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
""" """
""" """
import itertools
from os.path import dirname, join from os.path import dirname, join
import numpy as np import numpy as np
import pytest
import sympy import sympy
import pycuda.autoinit # NOQA import pycuda.autoinit # NOQA
...@@ -215,19 +217,20 @@ def test_rotate_interpolation_size_change(): ...@@ -215,19 +217,20 @@ def test_rotate_interpolation_size_change():
pyconrad.imshow(out, "small out " + address_mode) pyconrad.imshow(out, "small out " + address_mode)
def test_field_interpolated(): @pytest.mark.parametrize('address_mode, target',
itertools.product(['border', 'wrap', 'clamp', 'mirror'], ['cpu', 'gpu']))
def test_field_interpolated(address_mode, target):
x_f, y_f = pystencils.fields('x,y: float64 [2d]') x_f, y_f = pystencils.fields('x,y: float64 [2d]')
for address_mode in ['border', 'wrap', 'clamp', 'mirror']: assignments = pystencils.AssignmentCollection({
assignments = pystencils.AssignmentCollection({ y_f.center(): x_f.interpolated_access([0.5 * x_ + 2.7, 0.25 * y_ + 7.2], address_mode=address_mode)
y_f.center(): x_f.interpolated_access([0.5 * x_ + 2.7, 0.25 * y_ + 7.2], address_mode=address_mode) })
}) print(assignments)
print(assignments) ast = pystencils.create_kernel(assignments)
ast = pystencils.create_kernel(assignments) print(ast)
print(ast) print(pystencils.show_code(ast))
print(pystencils.show_code(ast)) kernel = ast.compile()
kernel = ast.compile()
out = np.zeros_like(lenna) out = np.zeros_like(lenna)
kernel(x=lenna, y=out) kernel(x=lenna, y=out)
pyconrad.imshow(out, "out " + address_mode) pyconrad.imshow(out, "out " + address_mode)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment