Skip to content
Snippets Groups Projects
Commit b0aae533 authored by Martin Bauer's avatar Martin Bauer
Browse files

Boundary Handling fix to work with ParallelDatahandling again

parent f4abda02
Branches
Tags
No related merge requests found
...@@ -8,6 +8,8 @@ from pystencils.boundaries.createindexlist import ( ...@@ -8,6 +8,8 @@ from pystencils.boundaries.createindexlist import (
create_boundary_index_array, numpy_data_type_for_boundary_object) create_boundary_index_array, numpy_data_type_for_boundary_object)
from pystencils.cache import memorycache from pystencils.cache import memorycache
from pystencils.data_types import TypedSymbol, create_type from pystencils.data_types import TypedSymbol, create_type
from pystencils.datahandling import ParallelDataHandling
from pystencils.datahandling.pycuda import PyCudaArrayHandler
from pystencils.field import Field from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
...@@ -96,11 +98,17 @@ class BoundaryHandling: ...@@ -96,11 +98,17 @@ class BoundaryHandling:
def to_gpu(gpu_version, cpu_version): def to_gpu(gpu_version, cpu_version):
gpu_version = gpu_version.boundary_object_to_index_list gpu_version = gpu_version.boundary_object_to_index_list
cpu_version = cpu_version.boundary_object_to_index_list cpu_version = cpu_version.boundary_object_to_index_list
if isinstance(self.data_handling, ParallelDataHandling):
array_handler = PyCudaArrayHandler()
else:
array_handler = self.data_handling.array_handler
for obj, cpu_arr in cpu_version.items(): for obj, cpu_arr in cpu_version.items():
if obj not in gpu_version or gpu_version[obj].shape != cpu_arr.shape: if obj not in gpu_version or gpu_version[obj].shape != cpu_arr.shape:
gpu_version[obj] = self.data_handling.array_handler.to_gpu(cpu_arr) gpu_version[obj] = array_handler.to_gpu(cpu_arr)
else: else:
self.data_handling.array_handler.upload(gpu_version[obj], cpu_arr) array_handler.upload(gpu_version[obj], cpu_arr)
class_ = self.IndexFieldBlockData class_ = self.IndexFieldBlockData
class_.to_cpu = to_cpu class_.to_cpu = to_cpu
...@@ -332,7 +340,7 @@ class BoundaryHandling: ...@@ -332,7 +340,7 @@ class BoundaryHandling:
self.kernel = kernel self.kernel = kernel
class IndexFieldBlockData: class IndexFieldBlockData:
def __init__(self): def __init__(self, *args, **kwargs):
self.boundary_object_to_index_list = {} self.boundary_object_to_index_list = {}
self.boundary_object_to_data_setter = {} self.boundary_object_to_data_setter = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment