diff --git a/src/pystencils_reco/_crazy_decorator.py b/src/pystencils_reco/_crazy_decorator.py index 93af0289726e924cf8ef5961314e8e1cd3eb1523..f68bf385ffee566615e9bdf36a0b1b860cc337ad 100644 --- a/src/pystencils_reco/_crazy_decorator.py +++ b/src/pystencils_reco/_crazy_decorator.py @@ -26,7 +26,6 @@ def crazy(function) -> pystencils_reco.AssignmentCollection: # @disk_cache_no_fallback @functools.wraps(function) def wrapper(*args, **kwargs): - import pycuda.gpuarray # TODO(seitz): remove dependency inspection = inspect.getfullargspec(function) arg_names = inspection.args annotations = inspection.annotations @@ -35,7 +34,7 @@ def crazy(function) -> pystencils_reco.AssignmentCollection: if is_array_like(a) else a for i, a in enumerate(args)} compile_kwargs = {k: create_field_from_array_like(str(k), a, annotations.get(k, None)) - if (hasattr(a, '__array__') or isinstance(a, pycuda.gpuarray.GPUArray)) and + if (hasattr(a, '__array__') or 'GPUArray' in str(a.__class__)) and not isinstance(a, sympy.Matrix) # noqa else a for (k, a) in kwargs.items()}