Skip to content
Snippets Groups Projects
Commit 943505f7 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix texture uploading

parent 4d0b1d3d
No related branches found
No related tags found
No related merge requests found
...@@ -43,9 +43,10 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -43,9 +43,10 @@ class DestructuringBindingsForFieldClass(Node):
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
from pystencils.interpolation_astnodes import InterpolatorAccess from pystencils.interpolation_astnodes import InterpolatorAccess
return set(o.field for o in self.atoms(ResolvedFieldAccess) | self.atoms(InterpolatorAccess)) \ return (set(o.field for o in self.atoms(ResolvedFieldAccess)
| set(itertools.chain.from_iterable((k.kernel_function.fields_accessed | self.atoms(InterpolatorAccess))
for k in self.atoms(FunctionCall)))) | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed
for k in self.atoms(FunctionCall)))))
def __init__(self, body): def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__() super(DestructuringBindingsForFieldClass, self).__init__()
...@@ -132,15 +133,12 @@ class WrapperFunction(pystencils.astnodes.KernelFunction): ...@@ -132,15 +133,12 @@ class WrapperFunction(pystencils.astnodes.KernelFunction):
def generate_kernel_call(kernel_function): def generate_kernel_call(kernel_function):
try: from pystencils.interpolation_astnodes import InterpolatorAccess
from pystencils.interpolation_astnodes import TextureAccess from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.kernelparameters import FieldPointerSymbol
textures = {a.texture for a in kernel_function.atoms(TextureAccess)} textures = {a.interpolator for a in kernel_function.atoms(InterpolatorAccess) if a.is_texture}
texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True)) texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True))
for t in textures] for t in textures]
except ImportError:
texture_uploads = []
# debug_print = CustomCodeNode( # debug_print = CustomCodeNode(
# 'std::cout << "hallo" << __PRETTY_FUNCTION__ << std::endl;\ngpuErrchk(cudaPeekAtLastError());' \ # 'std::cout << "hallo" << __PRETTY_FUNCTION__ << std::endl;\ngpuErrchk(cudaPeekAtLastError());' \
...@@ -221,8 +219,9 @@ class JinjaCppFile(Node): ...@@ -221,8 +219,9 @@ class JinjaCppFile(Node):
# TODO: possibly costly tree traversal # TODO: possibly costly tree traversal
render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)}) render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
render_dict.update({"globals": render_dict.update({"globals": sorted({
[self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)]}) self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
}, key=str)})
return self.TEMPLATE.render(render_dict) return self.TEMPLATE.render(render_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment