Skip to content
Snippets Groups Projects
Commit 943505f7 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix texture uploading

parent 4d0b1d3d
Branches
Tags
No related merge requests found
...@@ -43,9 +43,10 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -43,9 +43,10 @@ class DestructuringBindingsForFieldClass(Node):
"""Set of Field instances: fields which are accessed inside this kernel function""" """Set of Field instances: fields which are accessed inside this kernel function"""
from pystencils.interpolation_astnodes import InterpolatorAccess from pystencils.interpolation_astnodes import InterpolatorAccess
return set(o.field for o in self.atoms(ResolvedFieldAccess) | self.atoms(InterpolatorAccess)) \ return (set(o.field for o in self.atoms(ResolvedFieldAccess)
| set(itertools.chain.from_iterable((k.kernel_function.fields_accessed | self.atoms(InterpolatorAccess))
for k in self.atoms(FunctionCall)))) | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed
for k in self.atoms(FunctionCall)))))
def __init__(self, body): def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__() super(DestructuringBindingsForFieldClass, self).__init__()
...@@ -132,15 +133,12 @@ class WrapperFunction(pystencils.astnodes.KernelFunction): ...@@ -132,15 +133,12 @@ class WrapperFunction(pystencils.astnodes.KernelFunction):
def generate_kernel_call(kernel_function): def generate_kernel_call(kernel_function):
try: from pystencils.interpolation_astnodes import InterpolatorAccess
from pystencils.interpolation_astnodes import TextureAccess from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.kernelparameters import FieldPointerSymbol
textures = {a.texture for a in kernel_function.atoms(TextureAccess)} textures = {a.interpolator for a in kernel_function.atoms(InterpolatorAccess) if a.is_texture}
texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True)) texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True))
for t in textures] for t in textures]
except ImportError:
texture_uploads = []
# debug_print = CustomCodeNode( # debug_print = CustomCodeNode(
# 'std::cout << "hallo" << __PRETTY_FUNCTION__ << std::endl;\ngpuErrchk(cudaPeekAtLastError());' \ # 'std::cout << "hallo" << __PRETTY_FUNCTION__ << std::endl;\ngpuErrchk(cudaPeekAtLastError());' \
...@@ -221,8 +219,9 @@ class JinjaCppFile(Node): ...@@ -221,8 +219,9 @@ class JinjaCppFile(Node):
# TODO: possibly costly tree traversal # TODO: possibly costly tree traversal
render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)}) render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
render_dict.update({"globals": render_dict.update({"globals": sorted({
[self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)]}) self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)
}, key=str)})
return self.TEMPLATE.render(render_dict) return self.TEMPLATE.render(render_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment