diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 5866bb897f775f205ae984338ed201d7c6f20a3a..24872a82d2f9240990308ae484b0f3b633e54f95 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -43,9 +43,10 @@ class DestructuringBindingsForFieldClass(Node): """Set of Field instances: fields which are accessed inside this kernel function""" from pystencils.interpolation_astnodes import InterpolatorAccess - return set(o.field for o in self.atoms(ResolvedFieldAccess) | self.atoms(InterpolatorAccess)) \ - | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed - for k in self.atoms(FunctionCall)))) + return (set(o.field for o in self.atoms(ResolvedFieldAccess) + | self.atoms(InterpolatorAccess)) + | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed + for k in self.atoms(FunctionCall))))) def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() @@ -132,15 +133,12 @@ class WrapperFunction(pystencils.astnodes.KernelFunction): def generate_kernel_call(kernel_function): - try: - from pystencils.interpolation_astnodes import TextureAccess - from pystencils.kernelparameters import FieldPointerSymbol + from pystencils.interpolation_astnodes import InterpolatorAccess + from pystencils.kernelparameters import FieldPointerSymbol - textures = {a.texture for a in kernel_function.atoms(TextureAccess)} - texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True)) - for t in textures] - except ImportError: - texture_uploads = [] + textures = {a.interpolator for a in kernel_function.atoms(InterpolatorAccess) if a.is_texture} + texture_uploads = [NativeTextureBinding(t, FieldPointerSymbol(t.field.name, t.field.dtype, const=True)) + for t in textures] # debug_print = CustomCodeNode( # 'std::cout << "hallo" << __PRETTY_FUNCTION__ << std::endl;\ngpuErrchk(cudaPeekAtLastError());' \ @@ -221,8 +219,9 @@ class JinjaCppFile(Node): # TODO: possibly costly tree traversal render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)}) - render_dict.update({"globals": - [self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self)]}) + render_dict.update({"globals": sorted({ + self.printer(g) for g in pystencils.backends.cbackend.get_global_declarations(self) + }, key=str)}) return self.TEMPLATE.render(render_dict)