Skip to content
Snippets Groups Projects
Commit b20974fe authored by Stephan Seitz's avatar Stephan Seitz
Browse files

interpolation: Stuff is working let's commit quickly

parent fa782b82
No related merge requests found
...@@ -15,6 +15,7 @@ import numpy as np ...@@ -15,6 +15,7 @@ import numpy as np
try: try:
import pycuda.driver as cuda import pycuda.driver as cuda
from pycuda import gpuarray from pycuda import gpuarray
import pycuda
except Exception: except Exception:
pass pass
...@@ -35,6 +36,8 @@ def ndarray_to_tex(tex_ref, ...@@ -35,6 +36,8 @@ def ndarray_to_tex(tex_ref,
use_normalized_coordinates=False, use_normalized_coordinates=False,
read_as_integer=False): read_as_integer=False):
if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
if address_mode is None: if address_mode is None:
address_mode = cuda.address_mode.BORDER address_mode = cuda.address_mode.BORDER
if filter_mode is None: if filter_mode is None:
......
...@@ -116,7 +116,6 @@ class Interpolator(object): ...@@ -116,7 +116,6 @@ class Interpolator(object):
def _hashable_contents(self): def _hashable_contents(self):
return (str(self.address_mode), return (str(self.address_mode),
str(type(self)), str(type(self)),
self.address_mode,
self.hash_str, self.hash_str,
self.use_normalized_coordinates) self.use_normalized_coordinates)
...@@ -416,11 +415,9 @@ class TextureCachedField(Interpolator): ...@@ -416,11 +415,9 @@ class TextureCachedField(Interpolator):
read_as_integer=False read_as_integer=False
): ):
super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates) super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)
if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
if address_mode is None: if address_mode is None:
address_mode = pycuda.driver.address_mode.BORDER address_mode = 'border'
if filter_mode is None: if filter_mode is None:
filter_mode = pycuda.driver.filter_mode.LINEAR filter_mode = pycuda.driver.filter_mode.LINEAR
......
...@@ -1334,19 +1334,19 @@ def implement_interpolations(ast_node: ast.Node, ...@@ -1334,19 +1334,19 @@ def implement_interpolations(ast_node: ast.Node,
if implement_by_texture_accesses: if implement_by_texture_accesses:
for i in interpolation_accesses: for i in interpolation_accesses:
old_i = i from pystencils.interpolation_astnodes import _InterpolationSymbol
try: try:
import pycuda.driver as cuda import pycuda.driver as cuda
texture = TextureCachedField.from_interpolator(i.interpolator) texture = TextureCachedField.from_interpolator(i.interpolator)
i.symbol.interpolator = texture
if can_use_hw_interpolation(i): if can_use_hw_interpolation(i):
i.symbol.interpolator.filter_mode = cuda.filter_mode.LINEAR texture.filter_mode = cuda.filter_mode.LINEAR
else: else:
i.symbol.interpolator.filter_mode = cuda.filter_mode.POINT texture.filter_mode = cuda.filter_mode.POINT
i.symbol.interpolator.read_as_integer = True texture.read_as_integer = True
except Exception as e: except Exception as e:
raise e raise e
ast_node.subs({old_i: i}) i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture)
# from pystencils.math_optimizations import ReplaceOptim, optimize_ast # from pystencils.math_optimizations import ReplaceOptim, optimize_ast
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment