diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py index 1c6d96df93c151872b27e616c32032121869c6a7..a05a2980096669c2e96bb6715114b99a1bdc3b23 100644 --- a/src/pystencilssfg/lang/cpp/std_mdspan.py +++ b/src/pystencilssfg/lang/cpp/std_mdspan.py @@ -29,6 +29,7 @@ class StdMdspan(SrcField): extents: tuple[int | str, ...], extents_type: PsType = PsUnsignedIntegerType(64), ref: bool = False, + const: bool = False, experimental: bool = True, ): T = create_type(T) @@ -39,9 +40,9 @@ class StdMdspan(SrcField): ) if experimental: - dtype = self._template_experimental(T=T, extents=extents_str) + dtype = self._template_experimental(T=T, extents=extents_str, const=const) else: - dtype = self._template(T=T, extents=extents_str) + dtype = self._template(T=T, extents=extents_str, const=const) if ref: dtype = Ref(dtype) @@ -73,7 +74,7 @@ class StdMdspan(SrcField): @staticmethod def from_field( - field: Field, extents_type: PsType = PsUnsignedIntegerType(64), ref: bool = False + field: Field, extents_type: PsType = PsUnsignedIntegerType(64), ref: bool = False, const: bool = False, ): """Creates a `std::mdspan` instance for a given pystencils field.""" from pystencils.field import layout_string_to_tuple @@ -98,6 +99,7 @@ class StdMdspan(SrcField): tuple(extents), extents_type=extents_type, ref=ref, + const=const ).var(field.name) diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py index f16803762acc12a6737439976a6c644ff89c79d4..861a4c4bb1ea81b0cbaaef4cb683316274ab2edd 100644 --- a/src/pystencilssfg/lang/cpp/std_span.py +++ b/src/pystencilssfg/lang/cpp/std_span.py @@ -42,12 +42,12 @@ class StdSpan(SrcField): return Extraction() @staticmethod - def from_field(field: Field, ref: bool = False): + def from_field(field: Field, ref: bool = False, const: bool = False): if field.spatial_dimensions > 1 or field.index_shape not in ((), (1,)): raise ValueError( "Only one-dimensional fields with trivial index dimensions can be mapped onto `std::span`" ) - return StdSpan(field.dtype, ref, False).var(field.name) + return StdSpan(field.dtype, ref=ref, const=const).var(field.name) def std_span_ref(field: Field): diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py index d3994f82305570e76c80bf49ee3efd2c30568150..5696b32dd55c6092db0ff298fdcbd79fa7df69f5 100644 --- a/src/pystencilssfg/lang/cpp/std_vector.py +++ b/src/pystencilssfg/lang/cpp/std_vector.py @@ -55,13 +55,13 @@ class StdVector(SrcVector, SrcField): return AugExpr.format("{}.at({})", self, coordinate) @staticmethod - def from_field(field: Field, ref: bool = True): + def from_field(field: Field, ref: bool = True, const: bool = False): if field.spatial_dimensions > 1 or field.index_shape not in ((), (1,)): raise ValueError( f"Cannot create std::vector from more-than-one-dimensional field {field}." ) - return StdVector(field.dtype, unsafe=False, ref=ref).var(field.name) + return StdVector(field.dtype, unsafe=False, ref=ref, const=const).var(field.name) def std_vector_ref(field: Field):