Skip to content
Snippets Groups Projects
Commit cd599101 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

first sketch of translation context and iteration domain

parent 4d13952d
No related branches found
No related tags found
No related merge requests found
Pipeline #60341 failed
Showing with 267 additions and 105 deletions
...@@ -31,8 +31,8 @@ all occurences of the shape and stride variables with their constant value: ...@@ -31,8 +31,8 @@ all occurences of the shape and stride variables with their constant value:
``` ```
constraints = ( constraints = (
[PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)]
+ [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)] + [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)]
) )
kernel_function.add_constraints(*constraints) kernel_function.add_constraints(*constraints)
...@@ -43,6 +43,8 @@ kernel_function.add_constraints(*constraints) ...@@ -43,6 +43,8 @@ kernel_function.add_constraints(*constraints)
from __future__ import annotations from __future__ import annotations
from types import EllipsisType
from abc import ABC from abc import ABC
import pymbolic.primitives as pb import pymbolic.primitives as pb
...@@ -56,78 +58,94 @@ from .types import ( ...@@ -56,78 +58,94 @@ from .types import (
constify, constify,
) )
from .typed_expressions import PsTypedVariable, ExprOrConstant from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
class PsLinearizedArray: class PsLinearizedArray:
"""N-dimensional contiguous array""" """Class to model N-dimensional contiguous arrays.
Memory Layout, Shape and Strides
--------------------------------
The memory layout of an array is defined by its shape and strides.
Both shape and stride entries may either be constants or special variables associated with
exactly one array.
Shape and strides may be specified at construction in the following way.
For constant entries, their value must be given as an integer.
For variable shape entries and strides, the Ellipsis `...` must be passed instead.
Internally, the passed `index_dtype` will be used to create typed constants (`PsTypedConstant`)
and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values.
"""
def __init__( def __init__(
self, self,
name: str, name: str,
element_type: PsScalarType, element_type: PsAbstractType,
dim: int, shape: tuple[int | EllipsisType, ...],
strides: tuple[int | EllipsisType, ...],
index_dtype: PsIntegerType = PsSignedIntegerType(64), index_dtype: PsIntegerType = PsSignedIntegerType(64),
): ):
self._name = name self._name = name
self._element_type = element_type
self._index_dtype = index_dtype
if len(shape) != len(strides):
raise ValueError("Shape and stride tuples must have the same length")
self._shape = tuple( self._shape: tuple[PsArrayShapeVar | PsTypedConstant, ...] = tuple(
PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim) (
PsArrayShapeVar(self, i, index_dtype)
if s == Ellipsis
else PsTypedConstant(s, index_dtype)
) )
self._strides = tuple( for i, s in enumerate(shape)
PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
) )
self._element_type = element_type self._strides: tuple[PsArrayStrideVar | PsTypedConstant, ...] = tuple(
self._dim = dim (
self._index_dtype = index_dtype PsArrayStrideVar(self, i, index_dtype)
if s == Ellipsis
else PsTypedConstant(s, index_dtype)
)
for i, s in enumerate(strides)
)
@property @property
def name(self): def name(self):
return self._name return self._name
@property @property
def shape(self): def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]:
return self._shape return self._shape
@property @property
def strides(self): def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]:
return self._strides return self._strides
@property
def dim(self):
return self._dim
@property @property
def element_type(self): def element_type(self):
return self._element_type return self._element_type
def _hashable_contents(self):
"""Contents by which to compare two instances of `PsLinearizedArray`.
Since equality checks on shape and stride variables internally check equality of their associated arrays,
if these variables would occur in here, an infinite recursion would follow.
Hence they are filtered and replaced by the ellipsis.
"""
shape_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._shape)
strides_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._strides)
return (self._name, self._element_type, self._index_dtype, shape_clean, strides_clean)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLinearizedArray): if not isinstance(other, PsLinearizedArray):
return False return False
return ( return self._hashable_contents() == other._hashable_contents()
self._name,
self._element_type,
self._dim,
self._index_dtype,
) == (
other._name,
other._element_type,
other._dim,
other._index_dtype,
)
def __hash__(self) -> int: def __hash__(self) -> int:
return hash( return hash(self._hashable_contents())
(
self._name,
self._element_type,
self._dim,
self._index_dtype,
)
)
class PsArrayAssocVar(PsTypedVariable, ABC): class PsArrayAssocVar(PsTypedVariable, ABC):
"""A variable that is associated to an array. """A variable that is associated to an array.
...@@ -166,6 +184,11 @@ class PsArrayBasePointer(PsArrayAssocVar): ...@@ -166,6 +184,11 @@ class PsArrayBasePointer(PsArrayAssocVar):
class PsArrayShapeVar(PsArrayAssocVar): class PsArrayShapeVar(PsArrayAssocVar):
"""Variable that represents an array's shape in one coordinate.
Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.shape`.
"""
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype")
...@@ -183,6 +206,11 @@ class PsArrayShapeVar(PsArrayAssocVar): ...@@ -183,6 +206,11 @@ class PsArrayShapeVar(PsArrayAssocVar):
class PsArrayStrideVar(PsArrayAssocVar): class PsArrayStrideVar(PsArrayAssocVar):
"""Variable that represents an array's stride in one coordinate.
Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.strides`.
"""
init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
__match_args__ = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype")
...@@ -217,45 +245,3 @@ class PsArrayAccess(pb.Subscript): ...@@ -217,45 +245,3 @@ class PsArrayAccess(pb.Subscript):
def dtype(self) -> PsAbstractType: def dtype(self) -> PsAbstractType:
"""Data type of this expression, i.e. the element type of the underlying array""" """Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type return self._base_ptr.array.element_type
# class PsIterationDomain:
# """A factory for arrays spanning a given iteration domain."""
# def __init__(
# self,
# id: str,
# dim: int | None = None,
# fixed_shape: tuple[int, ...] | None = None,
# index_dtype: PsIntegerType = PsSignedIntegerType(64),
# ):
# if fixed_shape is not None:
# if dim is not None and len(fixed_shape) != dim:
# raise ValueError(
# "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries."
# )
# shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape)
# elif dim is not None:
# shape = tuple(
# PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim)
# )
# else:
# raise ValueError("Either `fixed_shape` or `dim` must be specified.")
# self._domain_shape: tuple[VarOrConstant, ...] = shape
# self._index_dtype = index_dtype
# self._archetype_array: PsLinearizedArray | None = None
# self._constraints: list[PsParamConstraint] = []
# @property
# def dim(self) -> int:
# return len(self._domain_shape)
# @property
# def shape(self) -> tuple[VarOrConstant, ...]:
# return self._domain_shape
# def create_array(self, ghost_layers: int = 0):
...@@ -6,7 +6,7 @@ from dataclasses import dataclass ...@@ -6,7 +6,7 @@ from dataclasses import dataclass
from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.dependency import DependencyMapper
from .nodes import PsAstNode, PsBlock, failing_cast from .nodes import PsAstNode, PsBlock, failing_cast
from .constraints import PsParamConstraint from ..constraints import PsKernelConstraint
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
from ..exceptions import PsInternalCompilerError from ..exceptions import PsInternalCompilerError
...@@ -26,7 +26,7 @@ class PsKernelParametersSpec: ...@@ -26,7 +26,7 @@ class PsKernelParametersSpec:
params: tuple[PsTypedVariable, ...] params: tuple[PsTypedVariable, ...]
arrays: tuple[PsLinearizedArray, ...] arrays: tuple[PsLinearizedArray, ...]
constraints: tuple[PsParamConstraint, ...] constraints: tuple[PsKernelConstraint, ...]
def params_for_array(self, arr: PsLinearizedArray): def params_for_array(self, arr: PsLinearizedArray):
def pred(p: PsTypedVariable): def pred(p: PsTypedVariable):
...@@ -71,7 +71,7 @@ class PsKernelFunction(PsAstNode): ...@@ -71,7 +71,7 @@ class PsKernelFunction(PsAstNode):
self._target = target self._target = target
self._name = name self._name = name
self._constraints: list[PsParamConstraint] = [] self._constraints: list[PsKernelConstraint] = []
@property @property
def target(self) -> Target: def target(self) -> Target:
...@@ -120,7 +120,7 @@ class PsKernelFunction(PsAstNode): ...@@ -120,7 +120,7 @@ class PsKernelFunction(PsAstNode):
raise IndexError(f"Child index out of bounds: {idx}") raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c) self._body = failing_cast(PsBlock, c)
def add_constraints(self, *constraints: PsParamConstraint): def add_constraints(self, *constraints: PsKernelConstraint):
self._constraints += constraints self._constraints += constraints
def get_parameters(self) -> PsKernelParametersSpec: def get_parameters(self) -> PsKernelParametersSpec:
......
...@@ -4,11 +4,11 @@ import pymbolic.primitives as pb ...@@ -4,11 +4,11 @@ import pymbolic.primitives as pb
from pymbolic.mapper.c_code import CCodeMapper from pymbolic.mapper.c_code import CCodeMapper
from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.dependency import DependencyMapper
from ..typed_expressions import PsTypedVariable from .typed_expressions import PsTypedVariable
@dataclass @dataclass
class PsParamConstraint: class PsKernelConstraint:
condition: pb.Comparison condition: pb.Comparison
message: str = "" message: str = ""
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from ..exceptions import PsInternalCompilerError from ..exceptions import PsInternalCompilerError
from ..ast import PsKernelFunction from ..ast import PsKernelFunction
from ..ast.constraints import PsParamConstraint from ..constraints import PsKernelConstraint
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ..arrays import ( from ..arrays import (
PsLinearizedArray, PsLinearizedArray,
...@@ -285,7 +285,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -285,7 +285,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
case _: case _:
assert False, "Invalid variable encountered." assert False, "Invalid variable encountered."
def check_constraint(self, constraint: PsParamConstraint): def check_constraint(self, constraint: PsKernelConstraint):
variables = constraint.get_variables() variables = constraint.get_variables()
for var in variables: for var in variables:
......
from ...field import Field
from ..arrays import PsLinearizedArray, PsArrayBasePointer
from ..types import PsIntegerType
from ..constraints import PsKernelConstraint
from .iteration_domain import PsIterationDomain
class PsTranslationContext: class PsTranslationContext:
"""The `PsTranslationContext` manages the translation process from the SymPy frontend """The `PsTranslationContext` manages the translation process from the SymPy frontend
...@@ -27,7 +33,6 @@ class PsTranslationContext: ...@@ -27,7 +33,6 @@ class PsTranslationContext:
Domain fields can only be accessed by relative offsets, and therefore must always Domain fields can only be accessed by relative offsets, and therefore must always
be associated with an *iteration domain* that provides a spatial index tuple. be associated with an *iteration domain* that provides a spatial index tuple.
All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers. All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers.
A field and its array may be associated with multiple iteration domains.
- `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index. - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
If there is at least one indexed field present there must also exist an index source for that field If there is at least one indexed field present there must also exist an index source for that field
(loop or device indexing). (loop or device indexing).
...@@ -36,6 +41,21 @@ class PsTranslationContext: ...@@ -36,6 +41,21 @@ class PsTranslationContext:
Within a domain, a buffer may be either written to or read from, never both. Within a domain, a buffer may be either written to or read from, never both.
In the translator, frontend fields and backend arrays are managed together using the `PsFieldArrayPair` class.
"""
def __init__(self, index_dtype: PsIntegerType):
self._index_dtype = index_dtype
self._constraints: list[PsKernelConstraint] = []
@property
def index_dtype(self) -> PsIntegerType:
return self._index_dtype
def add_constraints(self, *constraints: PsKernelConstraint):
self._constraints += constraints
@property
def constraints(self) -> tuple[PsKernelConstraint, ...]:
return tuple(self._constraints)
"""
from dataclasses import dataclass
from ...field import Field
from ..arrays import PsLinearizedArray, PsArrayBasePointer
from ..types import PsIntegerType
from ..constraints import PsKernelConstraint
from .iteration_domain import PsIterationDomain
@dataclass
class PsFieldArrayPair:
field: Field
array: PsLinearizedArray
base_ptr: PsArrayBasePointer
@dataclass
class PsDomainFieldArrayPair(PsFieldArrayPair):
ghost_layers: int
interior_base_ptr: PsArrayBasePointer
domain: PsIterationDomain
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from types import EllipsisType
from ...field import Field
from ...typing import TypedSymbol, BasicType
from ..arrays import PsLinearizedArray, PsArrayBasePointer
from ..types.quick import make_type
from ..typed_expressions import PsTypedVariable, PsTypedConstant, VarOrConstant
from .field_array_pair import PsDomainFieldArrayPair
if TYPE_CHECKING:
from .context import PsTranslationContext
class PsIterationDomain:
"""Represents the n-dimensonal spatial iteration domain of a pystencils kernel.
Domain Shape
------------
A domain may have either constant or variable, n-dimensional shape, where n = 1, 2, 3.
If the shape is variable, the domain object manages variables for each shape entry.
The domain provides index variables for each dimension which may be used to access fields
associated with the domain.
In the kernel, these index variables must be provided by some index source.
Index sources differ between two major types of domains: full and sparse domains.
In a full domain, it is guaranteed that each interior point is processed by the kernel.
The index source may therefore be a full n-fold loop nest, or a device index calculation.
In a sparse domain, the iteration is controlled by an index vector, which acts as the index
source.
Arrays
------
Any number of domain arrays may be associated with each domain.
Each array is annotated with a number of ghost layers for each spatial coordinate.
### Shape Compatibility
When an array is associated with a domain, it must be ensured that the array's shape
is compatible with the domain.
The first n shape entries are considered the array's spatial shape.
These spatial shapes, after subtracting ghost layers, must all be equal, and are further
constrained by a constant domain shape.
For each spatial coordinate, shape compatibility is ensured as described by the following table.
| | Constant Array Shape | Variable Array Shape |
|---------------------------|-----------------------------|------------------------|
| **Constant Domain Shape** | Compile-Time Equality Check | Kernel Constraints |
| **Variable Domain Shape** | Invalid, Compiler Error | Kernel Constraints |
### Base Pointers and Array Accesses
In the kernel's public interface, each array is represented at least through its base pointer,
which represents the starting address of the array's data in memory.
Since the iteration domain models arrays as being surrounded by ghost layers, it provides for each
array a second, *interior* base pointer, which points to the first interior point after skipping the
ghost layers, e.g. in three dimensions with one index dimension:
```
addr(interior_base_ptr[0, 0, 0, 0]) == addr(base_ptr[gls, gls, gls, 0])
```
To access domain arrays using the domain's index variables, the interior base pointer should be used,
since the domain index variables always count up from zero.
"""
def __init__(self, ctx: PsTranslationContext, shape: tuple[int | EllipsisType, ...]):
self._ctx = ctx
if len(shape) == 0:
raise ValueError("Domain shape must be at least one-dimensional.")
if len(shape) > 3:
raise ValueError("Iteration domain can be at most three-dimensional.")
self._shape: tuple[VarOrConstant, ...] = tuple(
(
PsTypedVariable(f"domain_size_{i}", self._ctx.index_dtype)
if s == Ellipsis
else PsTypedConstant(s, self._ctx.index_dtype)
)
for i, s in enumerate(shape)
)
self._archetype_field: PsDomainFieldArrayPair | None = None
self._fields: dict[str, PsDomainFieldArrayPair] = dict()
@property
def shape(self) -> tuple[VarOrConstant, ...]:
return self._shape
def add_field(self, field: Field, ghost_layers: int) -> PsDomainFieldArrayPair:
arr_shape = tuple(
(Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
for s in field.shape
)
arr_strides = tuple(
(Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
for s in field.strides
)
# TODO: frontend should use new type system
element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type)
arr = PsLinearizedArray(field.name, element_type, arr_shape, arr_strides, self._ctx.index_dtype)
fa_pair = PsDomainFieldArrayPair(
field=field,
array=arr,
base_ptr=PsArrayBasePointer("arr_data", arr),
ghost_layers=ghost_layers,
interior_base_ptr=PsArrayBasePointer("arr_interior_data", arr),
domain=self
)
# Check shape compatibility
# TODO
for domain_s, field_s in zip(self.shape, field.shape):
if isinstance(domain_s, PsTypedConstant):
pass
raise NotImplementedError()
...@@ -206,8 +206,7 @@ class PsTypedConstant: ...@@ -206,8 +206,7 @@ class PsTypedConstant:
return PsTypedConstant(rem, self._dtype) return PsTypedConstant(rem, self._dtype)
def __neg__(self): def __neg__(self):
minus_one = PsTypedConstant(-1, self._dtype) return PsTypedConstant(- self._value, self._dtype)
return pb.Product((minus_one, self))
def __bool__(self): def __bool__(self):
return bool(self._value) return bool(self._value)
......
...@@ -10,7 +10,7 @@ from pystencils.nbackend.emission import CPrinter ...@@ -10,7 +10,7 @@ from pystencils.nbackend.emission import CPrinter
def test_basic_kernel(): def test_basic_kernel():
u_arr = PsLinearizedArray("u", Fp(64), 1) u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
u_size = u_arr.shape[0] u_size = u_arr.shape[0]
u_base = PsArrayBasePointer("u_data", u_arr) u_base = PsArrayBasePointer("u_data", u_arr)
...@@ -40,5 +40,5 @@ def test_basic_kernel(): ...@@ -40,5 +40,5 @@ def test_basic_kernel():
assert code.find("(" + params_str + ")") >= 0 assert code.find("(" + params_str + ")") >= 0
assert code.find("u_data[ctr] = u_data[ctr - 1] + u_data[ctr + 1];") >= 0 assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr + -1];") >= 0
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
from pystencils import Target from pystencils import Target
from pystencils.nbackend.ast import * from pystencils.nbackend.ast import *
from pystencils.nbackend.ast.constraints import PsParamConstraint from pystencils.nbackend.constraints import PsKernelConstraint
from pystencils.nbackend.typed_expressions import * from pystencils.nbackend.typed_expressions import *
from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
from pystencils.nbackend.types.quick import * from pystencils.nbackend.types.quick import *
...@@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load ...@@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load
def test_pairwise_addition(): def test_pairwise_addition():
idx_type = SInt(64) idx_type = SInt(64)
u = PsLinearizedArray("u", Fp(64, const=True), 2, index_dtype=idx_type) u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type)
v = PsLinearizedArray("v", Fp(64), 2, index_dtype=idx_type) v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type)
u_data = PsArrayBasePointer("u_data", u) u_data = PsArrayBasePointer("u_data", u)
v_data = PsArrayBasePointer("v_data", v) v_data = PsArrayBasePointer("v_data", v)
...@@ -42,7 +42,7 @@ def test_pairwise_addition(): ...@@ -42,7 +42,7 @@ def test_pairwise_addition():
func = PsKernelFunction(PsBlock([loop]), target=Target.CPU) func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
sizes_constraint = PsParamConstraint( sizes_constraint = PsKernelConstraint(
u.shape[0].eq(2 * v.shape[0]), u.shape[0].eq(2 * v.shape[0]),
"Array `u` must have twice the length of array `v`" "Array `u` must have twice the length of array `v`"
) )
......
...@@ -8,15 +8,18 @@ def test_variable_equality(): ...@@ -8,15 +8,18 @@ def test_variable_equality():
var2 = PsTypedVariable("x", Fp(32)) var2 = PsTypedVariable("x", Fp(32))
assert var1 == var2 assert var1 == var2
arr = PsLinearizedArray("arr", Fp(64), 3) shape = (..., ..., ...)
strides = (..., ..., ...)
arr = PsLinearizedArray("arr", Fp(64), shape, strides)
bp1 = PsArrayBasePointer("arr_data", arr) bp1 = PsArrayBasePointer("arr_data", arr)
bp2 = PsArrayBasePointer("arr_data", arr) bp2 = PsArrayBasePointer("arr_data", arr)
assert bp1 == bp2 assert bp1 == bp2
arr1 = PsLinearizedArray("arr", Fp(64), 3) arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
bp1 = PsArrayBasePointer("arr_data", arr1) bp1 = PsArrayBasePointer("arr_data", arr1)
arr2 = PsLinearizedArray("arr", Fp(64), 3) arr2 = PsLinearizedArray("arr", Fp(64), shape, strides)
bp2 = PsArrayBasePointer("arr_data", arr2) bp2 = PsArrayBasePointer("arr_data", arr2)
assert bp1 == bp2 assert bp1 == bp2
...@@ -28,6 +31,9 @@ def test_variable_equality(): ...@@ -28,6 +31,9 @@ def test_variable_equality():
def test_variable_inequality(): def test_variable_inequality():
shape = (..., ..., ...)
strides = (..., ..., ...)
var1 = PsTypedVariable("x", Fp(32)) var1 = PsTypedVariable("x", Fp(32))
var2 = PsTypedVariable("x", Fp(64)) var2 = PsTypedVariable("x", Fp(64))
assert var1 != var2 assert var1 != var2
...@@ -37,10 +43,10 @@ def test_variable_inequality(): ...@@ -37,10 +43,10 @@ def test_variable_inequality():
assert var1 != var2 assert var1 != var2
# Arrays # Arrays
arr1 = PsLinearizedArray("arr", Fp(64), 3) arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
bp1 = PsArrayBasePointer("arr_data", arr1) bp1 = PsArrayBasePointer("arr_data", arr1)
arr2 = PsLinearizedArray("arr", Fp(32), 3) arr2 = PsLinearizedArray("arr", Fp(32), shape, strides)
bp2 = PsArrayBasePointer("arr_data", arr2) bp2 = PsArrayBasePointer("arr_data", arr2)
assert bp1 != bp2 assert bp1 != bp2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment